Skip to content

Commit

Permalink
Persist model definition in model metadata (#1527)
Browse files Browse the repository at this point in the history
* Add MethodComponentContext to ModelMetadata

Signed-off-by: Ryan Bogan <[email protected]>

* Add changelog

Signed-off-by: Ryan Bogan <[email protected]>

* Address PR Comments

Signed-off-by: Ryan Bogan <[email protected]>

* Address PR Comments

Signed-off-by: Ryan Bogan <[email protected]>

* Change fromString

Signed-off-by: Ryan Bogan <[email protected]>

* Address PR Comments

Signed-off-by: Ryan Bogan <[email protected]>

* Address PR Comments

Signed-off-by: Ryan Bogan <[email protected]>

* Address PR Comments

Signed-off-by: Ryan Bogan <[email protected]>

* Fix spotless

Signed-off-by: Ryan Bogan <[email protected]>

---------

Signed-off-by: Ryan Bogan <[email protected]>
  • Loading branch information
ryanbogan authored Mar 15, 2024
1 parent 2959d06 commit 4734d88
Show file tree
Hide file tree
Showing 23 changed files with 834 additions and 121 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Optize Faiss Query With Filters: Reduce iteration and memory for id filter [#1402](https://github.com/opensearch-project/k-NN/pull/1402)
* Detect AVX2 Dynamically on the System [#1502](https://github.com/opensearch-project/k-NN/pull/1502)
* Validate zero vector when using cosine metric [#1501](https://github.com/opensearch-project/k-NN/pull/1501)
* Persist model definition in model metadata [#1527] (https://github.com/opensearch-project/k-NN/pull/1527)
### Bug Fixes
* Disable sdc table for HNSWPQ read-only indices [#1518](https://github.com/opensearch-project/k-NN/pull/1518)
* Switch SpaceType.INNERPRODUCT's vector similarity function to MAXIMUM_INNER_PRODUCT [#1532](https://github.com/opensearch-project/k-NN/pull/1532)
Expand Down
1 change: 1 addition & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ public class KNNConstants {
public static final String MODEL_DESCRIPTION = "description";
public static final String MODEL_ERROR = "error";
public static final String MODEL_NODE_ASSIGNMENT = "training_node_assignment";
public static final String MODEL_METHOD_COMPONENT_CONTEXT = "model_definition";
public static final String PARAM_SIZE = "size";
public static final Integer SEARCH_MODEL_MIN_SIZE = 1;
public static final Integer SEARCH_MODEL_MAX_SIZE = 1000;
Expand Down
3 changes: 3 additions & 0 deletions src/main/java/org/opensearch/knn/index/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,16 @@
public class IndexUtil {

public static final String MODEL_NODE_ASSIGNMENT_KEY = KNNConstants.MODEL_NODE_ASSIGNMENT;
public static final String MODEL_METHOD_COMPONENT_CONTEXT_KEY = KNNConstants.MODEL_METHOD_COMPONENT_CONTEXT;

private static final Version MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED = Version.V_2_11_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT = Version.V_2_12_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT = Version.V_2_13_0;
private static final Map<String, Version> minimalRequiredVersionMap = new HashMap<String, Version>() {
{
put("ignore_unmapped", MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED);
put(MODEL_NODE_ASSIGNMENT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT);
put(MODEL_METHOD_COMPONENT_CONTEXT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT);
}
};

Expand Down
202 changes: 202 additions & 0 deletions src/main/java/org/opensearch/knn/index/MethodComponentContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,29 @@

package org.opensearch.knn.index;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import org.apache.commons.lang.math.NumberUtils;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.xcontent.ToXContentFragment;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.MapperParsingException;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.opensearch.knn.indices.ModelMetadata;

import static org.opensearch.knn.common.KNNConstants.NAME;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
Expand All @@ -41,6 +46,13 @@
@RequiredArgsConstructor
public class MethodComponentContext implements ToXContentFragment, Writeable {

// EMPTY method component context can only occur if a model originated on a cluster before 2.13.0 and the cluster is then upgraded to
// 2.13.0
public static final MethodComponentContext EMPTY = new MethodComponentContext("", Collections.emptyMap());

private static final String DELIMITER = ";";
private static final String DELIMITER_PLACEHOLDER = "$%$";

@Getter
private final String name;
private final Map<String, Object> parameters;
Expand Down Expand Up @@ -161,6 +173,15 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
return builder;
}

public static MethodComponentContext fromXContent(XContentParser xContentParser) throws IOException {
// If it is a fresh parser, move to the first token
if (xContentParser.currentToken() == null) {
xContentParser.nextToken();
}
Map<String, Object> parsedMap = xContentParser.map();
return MethodComponentContext.parse(parsedMap);
}

@Override
public boolean equals(Object obj) {
if (this == obj) return true;
Expand Down Expand Up @@ -193,6 +214,187 @@ public Map<String, Object> getParameters() {
return parameters;
}

/**
*
* Provides a String representation of MethodComponentContext
* Sample return:
* {name=ivf;parameters=[nlist=4;type=fp16;encoder={name=sq;parameters=[nprobes=2;clip=false;]};]}
*
* @return string representation
*/
public String toClusterStateString() {
StringBuilder stringBuilder = new StringBuilder();
stringBuilder.append("{name=").append(name).append(DELIMITER);
stringBuilder.append("parameters=[");
if (Objects.nonNull(parameters)) {
for (Map.Entry<String, Object> entry : parameters.entrySet()) {
stringBuilder.append(entry.getKey()).append("=");
Object objectValue = entry.getValue();
String value;
if (objectValue instanceof MethodComponentContext) {
value = ((MethodComponentContext) objectValue).toClusterStateString();
} else {
value = entry.getValue().toString();
}
// Model Metadata uses a delimiter to split the input string in its fromString method
// https://github.com/opensearch-project/k-NN/blob/2.12/src/main/java/org/opensearch/knn/indices/ModelMetadata.java#L265
// If any of the values in the method component context contain this delimiter,
// then the method will not work correctly. Therefore, we replace the delimiter with an uncommon
// sequence that is very unlikely to appear in the value itself.
// https://github.com/opensearch-project/k-NN/issues/1337
value = value.replace(ModelMetadata.DELIMITER, DELIMITER_PLACEHOLDER);
stringBuilder.append(value).append(DELIMITER);
}
}
stringBuilder.append("]}");
return stringBuilder.toString();
}

/**
* This method converts a string created by the toClusterStateString() method of MethodComponentContext
* to a MethodComponentContext object.
*
* @param in a string representation of MethodComponentContext
* @return a MethodComponentContext object
*/
public static MethodComponentContext fromClusterStateString(String in) {
String stringToParse = unwrapString(in, '{', '}');

// Parse name from string
String[] nameAndParameters = stringToParse.split(DELIMITER, 2);
checkExpectedArrayLength(nameAndParameters, 2);
String name = parseName(nameAndParameters[0]);
String parametersString = nameAndParameters[1];
Map<String, Object> parameters = parseParameters(parametersString);
return new MethodComponentContext(name, parameters);
}

private static String parseName(String candidateNameString) {
// Expecting candidateNameString to look like "name=ivf"
checkStringNotEmpty(candidateNameString);
String[] nameKeyAndValue = candidateNameString.split("=");
checkStringMatches(nameKeyAndValue[0], "name");
if (nameKeyAndValue.length == 1) {
return "";
}
checkExpectedArrayLength(nameKeyAndValue, 2);
return nameKeyAndValue[1];
}

private static Map<String, Object> parseParameters(String candidateParameterString) {
checkStringNotEmpty(candidateParameterString);
String[] parametersKeyAndValue = candidateParameterString.split("=", 2);
checkStringMatches(parametersKeyAndValue[0], "parameters");
if (parametersKeyAndValue.length == 1) {
return Collections.emptyMap();
}
checkExpectedArrayLength(parametersKeyAndValue, 2);
return parseParametersValue(parametersKeyAndValue[1]);
}

private static Map<String, Object> parseParametersValue(String candidateParameterValueString) {
// Expected input is [nlist=4;type=fp16;encoder={name=sq;parameters=[nprobes=2;clip=false;]};]
checkStringNotEmpty(candidateParameterValueString);
candidateParameterValueString = unwrapString(candidateParameterValueString, '[', ']');
Map<String, Object> parameters = new HashMap<>();
while (!candidateParameterValueString.isEmpty()) {
String[] keyAndValueToParse = candidateParameterValueString.split("=", 2);
if (keyAndValueToParse.length == 1 && keyAndValueToParse[0].charAt(0) == ';') {
break;
}
String key = keyAndValueToParse[0];
ValueAndRestToParse parsed = parseParameterValueAndRestToParse(keyAndValueToParse[1]);
parameters.put(key, parsed.getValue());
candidateParameterValueString = parsed.getRestToParse();
}

return parameters;
}

private static ValueAndRestToParse parseParameterValueAndRestToParse(String candidateParameterValueAndRestToParse) {
if (candidateParameterValueAndRestToParse.charAt(0) == '{') {
int endOfNestedMap = findClosingPosition(candidateParameterValueAndRestToParse, '{', '}');
String nestedMethodContext = candidateParameterValueAndRestToParse.substring(0, endOfNestedMap + 1);
Object nestedParse = fromClusterStateString(nestedMethodContext);
String restToParse = candidateParameterValueAndRestToParse.substring(endOfNestedMap + 1);
return new ValueAndRestToParse(nestedParse, restToParse);
}

String[] stringValueAndRestToParse = candidateParameterValueAndRestToParse.split(DELIMITER, 2);
String stringValue = stringValueAndRestToParse[0];
Object value;
if (NumberUtils.isNumber(stringValue)) {
value = Integer.parseInt(stringValue);
} else if (stringValue.equals("true") || stringValue.equals("false")) {
value = Boolean.parseBoolean(stringValue);
} else {
stringValue = stringValue.replace(DELIMITER_PLACEHOLDER, ModelMetadata.DELIMITER);
value = stringValue;
}

return new ValueAndRestToParse(value, stringValueAndRestToParse[1]);
}

private static String unwrapString(String in, char expectedStart, char expectedEnd) {
if (in.length() < 2) {
throw new IllegalArgumentException("Invalid string.");
}

if (in.charAt(0) != expectedStart || in.charAt(in.length() - 1) != expectedEnd) {
throw new IllegalArgumentException("Invalid string." + in);
}
return in.substring(1, in.length() - 1);
}

private static int findClosingPosition(String in, char expectedStart, char expectedEnd) {
int nestedLevel = 0;
for (int i = 0; i < in.length(); i++) {
if (in.charAt(i) == expectedStart) {
nestedLevel++;
continue;
}

if (in.charAt(i) == expectedEnd) {
nestedLevel--;
}

if (nestedLevel == 0) {
return i;
}
}

throw new IllegalArgumentException("Invalid string. No end to the nesting");
}

private static void checkStringNotEmpty(String string) {
if (string.isEmpty()) {
throw new IllegalArgumentException("Unable to parse MethodComponentContext");
}
}

private static void checkStringMatches(String string, String expected) {
if (!Objects.equals(string, expected)) {
throw new IllegalArgumentException("Unexpected key in MethodComponentContext. Expected 'name' or 'parameters'");
}
}

private static void checkExpectedArrayLength(String[] array, int expectedLength) {
if (null == array) {
throw new IllegalArgumentException("Error parsing MethodComponentContext. Array is null.");
}

if (array.length != expectedLength) {
throw new IllegalArgumentException("Error parsing MethodComponentContext. Array is not expected length.");
}
}

@AllArgsConstructor
@Getter
private static class ValueAndRestToParse {
private final Object value;
private final String restToParse;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(this.name);
Expand Down
11 changes: 11 additions & 0 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,14 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.common.exception.DeleteModelWhenInTrainStateException;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.plugin.transport.DeleteModelResponse;
import org.opensearch.knn.plugin.transport.GetModelResponse;
import org.opensearch.knn.plugin.transport.RemoveModelFromCacheAction;
Expand Down Expand Up @@ -288,6 +292,13 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do
put(KNNConstants.MODEL_DESCRIPTION, modelMetadata.getDescription());
put(KNNConstants.MODEL_ERROR, modelMetadata.getError());
put(KNNConstants.MODEL_NODE_ASSIGNMENT, modelMetadata.getNodeAssignment());

MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext();
if (!methodComponentContext.getName().isEmpty()) {
XContentBuilder builder = XContentFactory.jsonBuilder().startObject();
builder = methodComponentContext.toXContent(builder, ToXContent.EMPTY_PARAMS).endObject();
put(KNNConstants.MODEL_METHOD_COMPONENT_CONTEXT, builder.toString());
}
}
};

Expand Down
Loading

0 comments on commit 4734d88

Please sign in to comment.