[inference_metadata_fields] Clear inference results on explicit nulls #119145

Open · wants to merge 10 commits into base: inference_metadata_fields
@@ -39,6 +39,7 @@
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
@@ -50,6 +51,7 @@
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
@@ -67,6 +69,8 @@
*/
public class ShardBulkInferenceActionFilter implements MappedActionFilter {
protected static final int DEFAULT_BATCH_SIZE = 512;
private static final Object EXPLICIT_NULL = new Object();
private static final ChunkedInference EMPTY_CHUNKED_INFERENCE = new EmptyChunkedInference();

private final ClusterService clusterService;
private final InferenceServiceRegistry inferenceServiceRegistry;
@@ -393,11 +397,22 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
for (var entry : response.responses.entrySet()) {
var fieldName = entry.getKey();
var responses = entry.getValue();
var model = responses.get(0).model();
Model model = null;

InferenceFieldMetadata inferenceFieldMetadata = fieldInferenceMap.get(fieldName);
if (inferenceFieldMetadata == null) {
throw new IllegalStateException("No inference field metadata for field [" + fieldName + "]");
}

// ensure that the order in the original field is consistent in case of multiple inputs
Collections.sort(responses, Comparator.comparingInt(FieldInferenceResponse::inputOrder));
Map<String, List<SemanticTextField.Chunk>> chunkMap = new LinkedHashMap<>();
for (var resp : responses) {
// Get the first non-null model from the response list
if (model == null) {
model = resp.model;
}

Review comment (Member): Do we need to do any additional validation here, to verify that model, if it exists, is compatible with resp.model? (A possible shape for such a check is sketched at the end of this hunk.)

var lst = chunkMap.computeIfAbsent(resp.sourceField, k -> new ArrayList<>());
lst.addAll(
SemanticTextField.toSemanticTextFieldChunks(
@@ -409,21 +424,26 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
)
);
}

List<String> inputs = responses.stream()
.filter(r -> r.sourceField().equals(fieldName))
.map(r -> r.input)
.collect(Collectors.toList());

// The model can be null if we are only processing update requests that clear inference results. This is ok because we will
// merge in the field's existing model settings on the data node.
var result = new SemanticTextField(
useLegacyFormat,
fieldName,
useLegacyFormat ? inputs : null,
new SemanticTextField.InferenceResult(
model.getInferenceEntityId(),
new SemanticTextField.ModelSettings(model),
inferenceFieldMetadata.getInferenceId(),
model != null ? new SemanticTextField.ModelSettings(model) : null,
chunkMap
),
indexRequest.getContentType()
);

if (useLegacyFormat) {
SemanticTextUtils.insertValue(fieldName, newDocMap, result);
} else {
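
On the review question above: one possible shape for such a guard, sketched with a local stand-in interface rather than the real Model type (getInferenceEntityId is the only accessor assumed, and it already appears elsewhere in this diff). Whether the check is actually needed depends on invariants not visible here.

// Hypothetical stand-in for the single Model accessor this sketch relies on.
interface HasInferenceId {
    String getInferenceEntityId();
}

final class ModelCompatibility {
    // Throws if two non-null models for the same field reference different inference endpoints.
    static void ensureCompatible(HasInferenceId first, HasInferenceId candidate) {
        if (first != null
            && candidate != null
            && first.getInferenceEntityId().equals(candidate.getInferenceEntityId()) == false) {
            throw new IllegalStateException(
                "Conflicting models for one field: ["
                    + first.getInferenceEntityId() + "] vs [" + candidate.getInferenceEntityId() + "]"
            );
        }
    }

    public static void main(String[] args) {
        HasInferenceId a = () -> "endpoint-a";
        HasInferenceId b = () -> "endpoint-b";
        ensureCompatible(a, a);    // ok: same endpoint
        ensureCompatible(null, a); // ok: no model seen yet
        ensureCompatible(a, b);    // throws IllegalStateException
    }
}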
@@ -490,7 +510,8 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
} else {
var inferenceMetadataFieldsValue = XContentMapValues.extractValue(
InferenceMetadataFieldsMapper.NAME + "." + field,
docMap
docMap,
EXPLICIT_NULL
);
if (inferenceMetadataFieldsValue != null) {
// Inference has already been computed
@@ -500,9 +521,22 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu

int order = 0;
for (var sourceField : entry.getSourceFields()) {
// TODO: Detect when the field is provided with an explicit null value
var valueObj = XContentMapValues.extractValue(sourceField, docMap);
if (valueObj == null) {
var valueObj = XContentMapValues.extractValue(sourceField, docMap, EXPLICIT_NULL);
if (useLegacyFormat == false && isUpdateRequest && valueObj == EXPLICIT_NULL) {
/**
* It's an update request, and the source field is explicitly set to null,
* so we need to propagate this information to the inference fields metadata
* to overwrite any inference previously computed on the field.
* This ensures that the field is treated as intentionally cleared,
* preventing any unintended carryover of prior inference results.
*/

Review comment (Member): Nice comment here!

var slot = ensureResponseAccumulatorSlot(itemIndex);
slot.addOrUpdateResponse(
new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
);
continue;
}

Review comment (Member): We're incrementing order here, and we're still incrementing it later on line 566 in the existing code. Do we need to reset the value of order before we iterate through values here? It's a bit confusing on a read-through.

if (valueObj == null || valueObj == EXPLICIT_NULL) {
if (isUpdateRequest && useLegacyFormat) {
addInferenceResponseFailure(
item.id(),
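
For context, the EXPLICIT_NULL sentinel lets extractValue distinguish a key that is absent from a key explicitly set to null, which a plain two-argument map lookup cannot do. A minimal standalone sketch of the pattern (illustrative only, not the actual XContentMapValues implementation):

import java.util.HashMap;
import java.util.Map;

class ExplicitNullDemo {
    // Private sentinel object: reference equality is all that is needed.
    private static final Object EXPLICIT_NULL = new Object();

    // Returns the mapped value, EXPLICIT_NULL for an explicit null, or null when the key is absent.
    static Object extractValue(Map<String, Object> doc, String key) {
        if (doc.containsKey(key) == false) {
            return null; // field missing: prior inference results stay untouched
        }
        Object value = doc.get(key);
        return value == null ? EXPLICIT_NULL : value; // field intentionally cleared
    }

    public static void main(String[] args) {
        Map<String, Object> doc = new HashMap<>();
        doc.put("field1", null);
        System.out.println(extractValue(doc, "field1") == EXPLICIT_NULL); // true: explicit null
        System.out.println(extractValue(doc, "field2") == null);          // true: absent
    }
}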
@@ -552,4 +586,11 @@ static IndexRequest getIndexRequestOrNull(DocWriteRequest<?> docWriteRequest) {
return null;
}
}

private static class EmptyChunkedInference implements ChunkedInference {
@Override
public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) {
return Collections.emptyIterator();
}
}
}
----------------------------------------------------------------------
@@ -345,9 +345,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
);

INFERENCE_RESULT_PARSER.declareString(constructorArg(), new ParseField(INFERENCE_ID_FIELD));
INFERENCE_RESULT_PARSER.declareObject(
INFERENCE_RESULT_PARSER.declareObjectOrNull(
constructorArg(),
(p, c) -> MODEL_SETTINGS_PARSER.parse(p, null),
null,
new ParseField(MODEL_SETTINGS_FIELD)
);
INFERENCE_RESULT_PARSER.declareField(constructorArg(), (p, c) -> {
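
The switch from declareObject to declareObjectOrNull lets a document carry an explicit JSON null for model_settings (parsed into the supplied placeholder, here null) instead of failing the parse. A minimal standalone sketch of that object-or-null semantic, with hypothetical names rather than the real x-content API:

import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

class ObjectOrNullDemo {
    // Hypothetical stand-in for declareObjectOrNull: an explicit null maps to the
    // caller-supplied placeholder, while a missing required field is still an error.
    static <T> T parseObjectOrNull(Map<String, Object> doc, String field,
                                   Function<Object, T> objectParser, T nullValue) {
        if (doc.containsKey(field) == false) {
            throw new IllegalArgumentException("[" + field + "] is required");
        }
        Object raw = doc.get(field);
        return raw == null ? nullValue : objectParser.apply(raw);
    }

    public static void main(String[] args) {
        Map<String, Object> doc = new HashMap<>();
        doc.put("model_settings", null);
        // With plain declareObject semantics this would be a parse failure;
        // with object-or-null semantics it yields the placeholder (null here).
        Object settings = parseObjectOrNull(doc, "model_settings", raw -> raw, null);
        System.out.println(settings); // null
    }
}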
----------------------------------------------------------------------
@@ -895,7 +895,7 @@ private static boolean canMergeModelSettings(
if (Objects.equals(previous, current)) {
return true;
}
if (previous == null) {
if (previous == null ^ current == null) {
return true;
}
conflicts.addConflict("model_settings", "");
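
The new condition previous == null ^ current == null is true when exactly one side is null; the both-null case never reaches it because Objects.equals(previous, current) already returned true above. A quick standalone check of the expression:

class XorNullCheck {
    public static void main(String[] args) {
        Object set = new Object();
        Object[][] cases = { { set, set }, { set, null }, { null, set }, { null, null } };
        for (Object[] c : cases) {
            Object previous = c[0];
            Object current = c[1];
            // == binds tighter than ^, so this parses as (previous == null) ^ (current == null)
            boolean mergeable = previous == null ^ current == null;
            System.out.printf("previous=%s current=%s -> %s%n",
                previous == null ? "null" : "set",
                current == null ? "null" : "set",
                mergeable);
        }
        // Prints: set/set -> false, set/null -> true, null/set -> true, null/null -> false
    }
}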
Mikep86 marked this conversation as resolved.
----------------------------------------------------------------------
@@ -18,6 +18,7 @@
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.support.ActionFilterChain;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.action.update.UpdateRequest;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.IndexMetadata;
@@ -205,6 +206,11 @@ public void testItemFailures() throws Exception {
XContentMapValues.extractValue(useLegacyFormat ? "field1.text" : "field1", actualRequest.sourceAsMap()),
equalTo("I am a success")
);
if (useInferenceMetadataFieldsFormat) {
assertNotNull(
XContentMapValues.extractValue(InferenceMetadataFieldsMapper.NAME + ".field1", actualRequest.sourceAsMap())
);
}

// item 2 is a failure
assertNotNull(bulkShardRequest.items()[2].getPrimaryResponse());
@@ -232,6 +238,98 @@ public void testItemFailures() throws Exception {
awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
}

@SuppressWarnings({ "unchecked", "rawtypes" })
public void testExplicitNull() throws Exception {
StaticModel model = StaticModel.createRandomInstance();

ShardBulkInferenceActionFilter filter = createFilter(
threadPool,
Map.of(model.getInferenceEntityId(), model),
randomIntBetween(1, 10),
IndexVersion.current()
);
model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom")));
model.putResult("I am a success", randomChunkedInferenceEmbeddingSparse(List.of("I am a success")));
CountDownLatch chainExecuted = new CountDownLatch(1);
ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
try {
BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
assertNull(bulkShardRequest.getInferenceFieldMap());
assertThat(bulkShardRequest.items().length, equalTo(5));

Object explicitNull = new Object();
// item 0
assertNull(bulkShardRequest.items()[0].getPrimaryResponse());
IndexRequest actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[0].request());
assertTrue(XContentMapValues.extractValue("obj.field1", actualRequest.sourceAsMap(), explicitNull) == explicitNull);
assertNull(XContentMapValues.extractValue(InferenceMetadataFieldsMapper.NAME, actualRequest.sourceAsMap(), explicitNull));

// item 1 is a success
assertNull(bulkShardRequest.items()[1].getPrimaryResponse());
actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[1].request());
assertThat(XContentMapValues.extractValue("obj.field1", actualRequest.sourceAsMap()), equalTo("I am a success"));
assertNotNull(
XContentMapValues.extractValue(
InferenceMetadataFieldsMapper.NAME + ".obj.field1",
actualRequest.sourceAsMap(),
explicitNull
)
);

// item 2 is a failure
assertNotNull(bulkShardRequest.items()[2].getPrimaryResponse());
assertTrue(bulkShardRequest.items()[2].getPrimaryResponse().isFailed());
var failure = bulkShardRequest.items()[2].getPrimaryResponse().getFailure();
assertThat(failure.getCause().getCause().getMessage(), containsString("boom"));

// item 3
assertNull(bulkShardRequest.items()[3].getPrimaryResponse());
actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[3].request());
assertTrue(XContentMapValues.extractValue("obj.field1", actualRequest.sourceAsMap(), explicitNull) == explicitNull);
assertTrue(
XContentMapValues.extractValue(
InferenceMetadataFieldsMapper.NAME + ".obj.field1",
actualRequest.sourceAsMap(),
explicitNull
) == explicitNull
);

// item 4
assertNull(bulkShardRequest.items()[4].getPrimaryResponse());
actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[4].request());
assertNull(XContentMapValues.extractValue("obj.field1", actualRequest.sourceAsMap(), explicitNull));
assertNull(
XContentMapValues.extractValue(
InferenceMetadataFieldsMapper.NAME + ".obj.field1",
actualRequest.sourceAsMap(),
explicitNull
)
);
} finally {
chainExecuted.countDown();
}
};
ActionListener actionListener = mock(ActionListener.class);
Task task = mock(Task.class);

Map<String, InferenceFieldMetadata> inferenceFieldMap = Map.of(
"obj.field1",
new InferenceFieldMetadata("obj.field1", model.getInferenceEntityId(), new String[] { "obj.field1" })
);
BulkItemRequest[] items = new BulkItemRequest[5];
Map<String, Object> sourceWithNull = new HashMap<>();
sourceWithNull.put("field1", null);
items[0] = new BulkItemRequest(0, new IndexRequest("index").source(Map.of("obj", sourceWithNull)));
items[1] = new BulkItemRequest(1, new IndexRequest("index").source("obj.field1", "I am a success"));
items[2] = new BulkItemRequest(2, new IndexRequest("index").source("obj.field1", "I am a failure"));
items[3] = new BulkItemRequest(3, new UpdateRequest().doc(new IndexRequest("index").source(Map.of("obj", sourceWithNull))));
items[4] = new BulkItemRequest(4, new UpdateRequest().doc(new IndexRequest("index").source(Map.of("field2", "value"))));
BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items);
request.setInferenceFieldMap(inferenceFieldMap);
filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain);
awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
}

@SuppressWarnings({ "unchecked", "rawtypes" })
public void testManyRandomDocs() throws Exception {
Map<String, StaticModel> inferenceModelMap = new HashMap<>();