From 77c361cba83f767ee1c3a46ddf2310f72ac5e52c Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Wed, 18 Dec 2024 12:46:56 +0000 Subject: [PATCH 1/8] Fix handling of explicit null values for semantic text fields Previously, setting a field explicitly to null in an update request did not work correctly with semantic text fields. This change resolves the issue by adding an explicit null entry to the `_inference_fields` metadata when such cases occur. The explicit null value ensures that any prior inference results are overwritten during the merge of the partial update with the latest document version. --- .../ShardBulkInferenceActionFilter.java | 33 +++++-- .../ShardBulkInferenceActionFilterTests.java | 85 +++++++++++++++++++ 2 files changed, 113 insertions(+), 5 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index b76a39a0f2ac2..86bf4a925121a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -394,6 +394,16 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons for (var entry : response.responses.entrySet()) { var fieldName = entry.getKey(); var responses = entry.getValue(); + if (responses == null) { + if (item.request() instanceof UpdateRequest == false) { + // could be an assert + throw new IllegalArgumentException( + "Inference results can only be cleared for update requests where a field is explicitly set to null." + ); + } + inferenceFieldsMap.put(fieldName, null); + continue; + } var model = responses.get(0).model(); // ensure that the order in the original field is consistent in case of multiple inputs Collections.sort(responses, Comparator.comparingInt(FieldInferenceResponse::inputOrder)); @@ -480,6 +490,7 @@ private Map> createFieldInferenceRequests(Bu } final Map docMap = indexRequest.sourceAsMap(); + Object explicitNull = new Object(); for (var entry : fieldInferenceMap.values()) { String field = entry.getName(); String inferenceId = entry.getInferenceId(); @@ -487,10 +498,11 @@ private Map> createFieldInferenceRequests(Bu if (useInferenceMetadataFieldsFormat) { var inferenceMetadataFieldsValue = XContentMapValues.extractValue( InferenceMetadataFieldsMapper.NAME + "." + field, - docMap + docMap, + explicitNull ); if (inferenceMetadataFieldsValue != null) { - // Inference has already been computed + // Inference has already been computed for this source field continue; } } else { @@ -503,9 +515,20 @@ private Map> createFieldInferenceRequests(Bu int order = 0; for (var sourceField : entry.getSourceFields()) { - // TODO: Detect when the field is provided with an explicit null value - var valueObj = XContentMapValues.extractValue(sourceField, docMap); - if (valueObj == null) { + var valueObj = XContentMapValues.extractValue(sourceField, docMap, explicitNull); + if (useInferenceMetadataFieldsFormat && isUpdateRequest && valueObj == explicitNull) { + /** + * It's an update request, and the source field is explicitly set to null, + * so we need to propagate this information to the inference fields metadata + * to overwrite any inference previously computed on the field. 
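
Note on the mechanism above: the sentinel is a plain java.lang.Object whose identity distinguishes "key present but explicitly null" from "key absent". Below is a minimal plain-Java model of the three-argument extractValue contract the patch relies on; the class and method names are hypothetical, and the real XContentMapValues.extractValue additionally resolves dotted paths and lists.

    import java.util.HashMap;
    import java.util.Map;

    class ExplicitNullDemo {
        // Returns null when the key is absent, the caller's sentinel when the
        // key is present but mapped to null, and the stored value otherwise.
        static Object extract(Map<String, Object> doc, String key, Object nullValue) {
            if (doc.containsKey(key) == false) {
                return null;
            }
            Object value = doc.get(key);
            return value == null ? nullValue : value;
        }

        public static void main(String[] args) {
            Object explicitNull = new Object();
            Map<String, Object> doc = new HashMap<>();
            doc.put("field1", null);
            System.out.println(extract(doc, "field1", explicitNull) == explicitNull); // true: explicit null
            System.out.println(extract(doc, "field2", explicitNull) == null);         // true: never provided
        }
    }
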
+ * This ensures that the field is treated as intentionally cleared, + * preventing any unintended carryover of prior inference results. + */ + var slot = ensureResponseAccumulatorSlot(itemIndex); + slot.responses.put(sourceField, null); + continue; + } + if (valueObj == null || valueObj == explicitNull) { if (isUpdateRequest && (useInferenceMetadataFieldsFormat == false)) { addInferenceResponseFailure( item.id(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 73fe792664071..d0e5b428f3422 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.support.ActionFilterChain; import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexMetadata; @@ -212,6 +213,11 @@ public void testItemFailures() throws Exception { ), equalTo("I am a success") ); + if (useInferenceMetadataFieldsFormat) { + assertNotNull( + XContentMapValues.extractValue(InferenceMetadataFieldsMapper.NAME + ".field1", actualRequest.sourceAsMap()) + ); + } // item 2 is a failure assertNotNull(bulkShardRequest.items()[2].getPrimaryResponse()); @@ -239,6 +245,85 @@ public void testItemFailures() throws Exception { awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); } + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testExplicitNull() throws Exception { + StaticModel model = StaticModel.createRandomInstance(); + + ShardBulkInferenceActionFilter filter = createFilter( + threadPool, + Map.of(model.getInferenceEntityId(), model), + randomIntBetween(1, 10), + IndexVersion.current() + ); + model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom"))); + model.putResult("I am a success", randomChunkedInferenceEmbeddingSparse(List.of("I am a success"))); + CountDownLatch chainExecuted = new CountDownLatch(1); + ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + try { + BulkShardRequest bulkShardRequest = (BulkShardRequest) request; + assertNull(bulkShardRequest.getInferenceFieldMap()); + assertThat(bulkShardRequest.items().length, equalTo(4)); + + Object explicitNull = new Object(); + // item 0 + assertNull(bulkShardRequest.items()[0].getPrimaryResponse()); + IndexRequest actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[0].request()); + assertTrue(XContentMapValues.extractValue("field1", actualRequest.sourceAsMap(), explicitNull) == explicitNull); + assertNull(XContentMapValues.extractValue(InferenceMetadataFieldsMapper.NAME, actualRequest.sourceAsMap(), explicitNull)); + + // item 1 is a success + assertNull(bulkShardRequest.items()[1].getPrimaryResponse()); + actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[1].request()); + assertThat(XContentMapValues.extractValue("field1", actualRequest.sourceAsMap()), equalTo("I am a success")); + assertNotNull( + XContentMapValues.extractValue( + 
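
The null entry written into the inference fields map above is what makes the fix work end to end: a partial-update merge is keyed on "field present in the partial doc", so a present-but-null entry overwrites stale inference results, while an absent entry would silently keep them. A hedged illustration of that merge rule with hypothetical names, not the actual merge code that runs on the data node:

    import java.util.HashMap;
    import java.util.Map;

    class PartialUpdateMergeDemo {
        static Map<String, Object> merge(Map<String, Object> existing, Map<String, Object> partial) {
            Map<String, Object> merged = new HashMap<>(existing);
            merged.putAll(partial); // keys mapped to null overwrite; absent keys leave old values intact
            return merged;
        }

        public static void main(String[] args) {
            Map<String, Object> existing = new HashMap<>();
            existing.put("field1", "stale inference result");
            Map<String, Object> partial = new HashMap<>();
            partial.put("field1", null); // explicit null from the update request
            System.out.println(merge(existing, partial)); // {field1=null}: stale result cleared
        }
    }
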
InferenceMetadataFieldsMapper.NAME + ".field1", + actualRequest.sourceAsMap(), + explicitNull + ) + ); + + // item 2 is a failure + assertNotNull(bulkShardRequest.items()[2].getPrimaryResponse()); + assertTrue(bulkShardRequest.items()[2].getPrimaryResponse().isFailed()); + var failure = bulkShardRequest.items()[2].getPrimaryResponse().getFailure(); + assertThat(failure.getCause().getCause().getMessage(), containsString("boom")); + + // item 3 + assertNull(bulkShardRequest.items()[3].getPrimaryResponse()); + actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[3].request()); + assertTrue(XContentMapValues.extractValue("field1", actualRequest.sourceAsMap(), explicitNull) == explicitNull); + assertTrue( + XContentMapValues.extractValue( + InferenceMetadataFieldsMapper.NAME + ".field1", + actualRequest.sourceAsMap(), + explicitNull + ) == explicitNull + ); + } finally { + chainExecuted.countDown(); + } + }; + ActionListener actionListener = mock(ActionListener.class); + Task task = mock(Task.class); + + Map inferenceFieldMap = Map.of( + "field1", + new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" }) + ); + BulkItemRequest[] items = new BulkItemRequest[4]; + Map sourceWithNull = new HashMap<>(); + sourceWithNull.put("field1", null); + items[0] = new BulkItemRequest(0, new IndexRequest("index").source(sourceWithNull)); + items[1] = new BulkItemRequest(1, new IndexRequest("index").source("field1", "I am a success")); + items[2] = new BulkItemRequest(2, new IndexRequest("index").source("field1", "I am a failure")); + items[3] = new BulkItemRequest(3, new UpdateRequest().doc(new IndexRequest("index").source(sourceWithNull))); + BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items); + request.setInferenceFieldMap(inferenceFieldMap); + filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + } + @SuppressWarnings({ "unchecked", "rawtypes" }) public void testManyRandomDocs() throws Exception { IndexVersion indexVersion = getRandomIndexVersion(); From 60a4a1d3246e3a8c98a303102b29c31a95d485e7 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Wed, 18 Dec 2024 13:07:26 +0000 Subject: [PATCH 2/8] improve test --- .../ShardBulkInferenceActionFilterTests.java | 39 ++++++++++++------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index d0e5b428f3422..c69545ce096ed 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -262,22 +262,22 @@ public void testExplicitNull() throws Exception { try { BulkShardRequest bulkShardRequest = (BulkShardRequest) request; assertNull(bulkShardRequest.getInferenceFieldMap()); - assertThat(bulkShardRequest.items().length, equalTo(4)); + assertThat(bulkShardRequest.items().length, equalTo(5)); Object explicitNull = new Object(); // item 0 assertNull(bulkShardRequest.items()[0].getPrimaryResponse()); IndexRequest actualRequest = 
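
Patch 2 reworks the test around a nested obj.field1 path, so value extraction now has to walk dotted paths. A simplified model of that lookup with a hypothetical helper; the real XContentMapValues.extractValue also handles lists and the explicit-null sentinel shown earlier:

    import java.util.Map;

    class DottedPathDemo {
        @SuppressWarnings("unchecked")
        static Object path(Object current, String dottedPath) {
            for (String part : dottedPath.split("\\.")) {
                if ((current instanceof Map) == false) {
                    return null;
                }
                current = ((Map<String, Object>) current).get(part);
            }
            return current;
        }

        public static void main(String[] args) {
            Map<String, Object> doc = Map.of("obj", Map.of("field1", "value"));
            System.out.println(path(doc, "obj.field1")); // value
        }
    }
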
getIndexRequestOrNull(bulkShardRequest.items()[0].request()); - assertTrue(XContentMapValues.extractValue("field1", actualRequest.sourceAsMap(), explicitNull) == explicitNull); + assertTrue(XContentMapValues.extractValue("obj.field1", actualRequest.sourceAsMap(), explicitNull) == explicitNull); assertNull(XContentMapValues.extractValue(InferenceMetadataFieldsMapper.NAME, actualRequest.sourceAsMap(), explicitNull)); // item 1 is a success assertNull(bulkShardRequest.items()[1].getPrimaryResponse()); actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[1].request()); - assertThat(XContentMapValues.extractValue("field1", actualRequest.sourceAsMap()), equalTo("I am a success")); + assertThat(XContentMapValues.extractValue("obj.field1", actualRequest.sourceAsMap()), equalTo("I am a success")); assertNotNull( XContentMapValues.extractValue( - InferenceMetadataFieldsMapper.NAME + ".field1", + InferenceMetadataFieldsMapper.NAME + ".obj.field1", actualRequest.sourceAsMap(), explicitNull ) @@ -292,14 +292,26 @@ public void testExplicitNull() throws Exception { // item 3 assertNull(bulkShardRequest.items()[3].getPrimaryResponse()); actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[3].request()); - assertTrue(XContentMapValues.extractValue("field1", actualRequest.sourceAsMap(), explicitNull) == explicitNull); + assertTrue(XContentMapValues.extractValue("obj.field1", actualRequest.sourceAsMap(), explicitNull) == explicitNull); assertTrue( XContentMapValues.extractValue( - InferenceMetadataFieldsMapper.NAME + ".field1", + InferenceMetadataFieldsMapper.NAME + ".obj.field1", actualRequest.sourceAsMap(), explicitNull ) == explicitNull ); + + // item 4 + assertNull(bulkShardRequest.items()[4].getPrimaryResponse()); + actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[4].request()); + assertNull(XContentMapValues.extractValue("obj.field1", actualRequest.sourceAsMap(), explicitNull)); + assertNull( + XContentMapValues.extractValue( + InferenceMetadataFieldsMapper.NAME + ".obj.field1", + actualRequest.sourceAsMap(), + explicitNull + ) + ); } finally { chainExecuted.countDown(); } @@ -308,16 +320,17 @@ public void testExplicitNull() throws Exception { Task task = mock(Task.class); Map inferenceFieldMap = Map.of( - "field1", - new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" }) + "obj.field1", + new InferenceFieldMetadata("obj.field1", model.getInferenceEntityId(), new String[] { "obj.field1" }) ); - BulkItemRequest[] items = new BulkItemRequest[4]; + BulkItemRequest[] items = new BulkItemRequest[5]; Map sourceWithNull = new HashMap<>(); sourceWithNull.put("field1", null); - items[0] = new BulkItemRequest(0, new IndexRequest("index").source(sourceWithNull)); - items[1] = new BulkItemRequest(1, new IndexRequest("index").source("field1", "I am a success")); - items[2] = new BulkItemRequest(2, new IndexRequest("index").source("field1", "I am a failure")); - items[3] = new BulkItemRequest(3, new UpdateRequest().doc(new IndexRequest("index").source(sourceWithNull))); + items[0] = new BulkItemRequest(0, new IndexRequest("index").source(Map.of("obj", sourceWithNull))); + items[1] = new BulkItemRequest(1, new IndexRequest("index").source("obj.field1", "I am a success")); + items[2] = new BulkItemRequest(2, new IndexRequest("index").source("obj.field1", "I am a failure")); + items[3] = new BulkItemRequest(3, new UpdateRequest().doc(new IndexRequest("index").source(Map.of("obj", sourceWithNull)))); + items[4] = new BulkItemRequest(3, new 
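
One detail in the test setup is easy to miss: sourceWithNull is built with HashMap.put rather than Map.of because the immutable map factories reject null values outright, so they cannot represent an explicit null. A small demonstration, hypothetical class name:

    import java.util.HashMap;
    import java.util.Map;

    class SourceWithNullDemo {
        public static void main(String[] args) {
            Map<String, Object> sourceWithNull = new HashMap<>();
            sourceWithNull.put("field1", null); // serializes as {"field1": null}
            System.out.println(sourceWithNull);

            try {
                Map.of("field1", (Object) null); // immutable factories forbid null values
            } catch (NullPointerException e) {
                System.out.println("Map.of cannot hold an explicit null");
            }
        }
    }
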
UpdateRequest().doc(new IndexRequest("index").source(Map.of("field2", "value")))); BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items); request.setInferenceFieldMap(inferenceFieldMap); filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); From af5a348ba1ef6f0d33b9bf4638dee84d3e86e1aa Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Thu, 19 Dec 2024 10:47:08 -0500 Subject: [PATCH 3/8] Create a field inference response with empty chunks --- .../ShardBulkInferenceActionFilter.java | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 86bf4a925121a..fb90e6a2871c1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -40,6 +40,7 @@ import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; +import org.elasticsearch.xcontent.XContent; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; @@ -51,6 +52,7 @@ import java.util.Collections; import java.util.Comparator; import java.util.HashMap; +import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -68,6 +70,8 @@ */ public class ShardBulkInferenceActionFilter implements MappedActionFilter { protected static final int DEFAULT_BATCH_SIZE = 512; + private static final Object EXPLICIT_NULL = new Object(); + private static final ChunkedInference EMPTY_CHUNKED_INFERENCE = new EmptyChunkedInference(); private final ClusterService clusterService; private final InferenceServiceRegistry inferenceServiceRegistry; @@ -394,16 +398,6 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons for (var entry : response.responses.entrySet()) { var fieldName = entry.getKey(); var responses = entry.getValue(); - if (responses == null) { - if (item.request() instanceof UpdateRequest == false) { - // could be an assert - throw new IllegalArgumentException( - "Inference results can only be cleared for update requests where a field is explicitly set to null." - ); - } - inferenceFieldsMap.put(fieldName, null); - continue; - } var model = responses.get(0).model(); // ensure that the order in the original field is consistent in case of multiple inputs Collections.sort(responses, Comparator.comparingInt(FieldInferenceResponse::inputOrder)); @@ -490,7 +484,6 @@ private Map> createFieldInferenceRequests(Bu } final Map docMap = indexRequest.sourceAsMap(); - Object explicitNull = new Object(); for (var entry : fieldInferenceMap.values()) { String field = entry.getName(); String inferenceId = entry.getInferenceId(); @@ -499,10 +492,10 @@ private Map> createFieldInferenceRequests(Bu var inferenceMetadataFieldsValue = XContentMapValues.extractValue( InferenceMetadataFieldsMapper.NAME + "." 
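
Patch 3 below replaces the special-cased null response with a shared EMPTY_CHUNKED_INFERENCE constant, a classic null-object: "cleared" flows through the ordinary response path as an inference result with zero chunks instead of a null that every consumer must guard against. The idea distilled under hypothetical names; the real ChunkedInference interface has a different, richer signature:

    import java.util.Collections;
    import java.util.Iterator;

    class NullObjectDemo {
        interface ChunkSource {
            Iterator<String> chunks();
        }

        // One shared, stateless "no chunks" instance; callers iterate zero
        // times rather than null-checking.
        static final ChunkSource EMPTY_CHUNKS = Collections::emptyIterator;

        public static void main(String[] args) {
            System.out.println(EMPTY_CHUNKS.chunks().hasNext()); // false
        }
    }
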
+ field, docMap, - explicitNull + EXPLICIT_NULL ); if (inferenceMetadataFieldsValue != null) { - // Inference has already been computed for this source field + // Inference has already been computed continue; } } else { @@ -515,8 +508,8 @@ private Map> createFieldInferenceRequests(Bu int order = 0; for (var sourceField : entry.getSourceFields()) { - var valueObj = XContentMapValues.extractValue(sourceField, docMap, explicitNull); - if (useInferenceMetadataFieldsFormat && isUpdateRequest && valueObj == explicitNull) { + var valueObj = XContentMapValues.extractValue(sourceField, docMap, EXPLICIT_NULL); + if (useInferenceMetadataFieldsFormat && isUpdateRequest && valueObj == EXPLICIT_NULL) { /** * It's an update request, and the source field is explicitly set to null, * so we need to propagate this information to the inference fields metadata @@ -525,10 +518,10 @@ private Map> createFieldInferenceRequests(Bu * preventing any unintended carryover of prior inference results. */ var slot = ensureResponseAccumulatorSlot(itemIndex); - slot.responses.put(sourceField, null); + slot.addOrUpdateResponse(new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE)); continue; } - if (valueObj == null || valueObj == explicitNull) { + if (valueObj == null || valueObj == EXPLICIT_NULL) { if (isUpdateRequest && (useInferenceMetadataFieldsFormat == false)) { addInferenceResponseFailure( item.id(), @@ -578,4 +571,11 @@ static IndexRequest getIndexRequestOrNull(DocWriteRequest docWriteRequest) { return null; } } + + private static class EmptyChunkedInference implements ChunkedInference { + @Override + public Iterator chunksAsMatchedTextAndByteReference(XContent xcontent) { + return Collections.emptyIterator(); + } + } } From 8d0c1460c1f1c2397d50e2ab245bd45842c67710 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Thu, 19 Dec 2024 14:18:30 -0500 Subject: [PATCH 4/8] Allow model settings to be null --- .../ShardBulkInferenceActionFilter.java | 26 ++++++++++++++++--- .../inference/mapper/SemanticTextField.java | 3 ++- .../mapper/SemanticTextFieldMapper.java | 2 +- 3 files changed, 25 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index fb90e6a2871c1..320a5c8816cf8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -398,11 +398,22 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons for (var entry : response.responses.entrySet()) { var fieldName = entry.getKey(); var responses = entry.getValue(); - var model = responses.get(0).model(); + Model model = null; + + InferenceFieldMetadata inferenceFieldMetadata = fieldInferenceMap.get(fieldName); + if (inferenceFieldMetadata == null) { + throw new IllegalStateException("No inference field metadata for field [" + fieldName + "]"); + } + // ensure that the order in the original field is consistent in case of multiple inputs Collections.sort(responses, Comparator.comparingInt(FieldInferenceResponse::inputOrder)); Map> chunkMap = new LinkedHashMap<>(); for (var resp : responses) { + // Get the first non-null model from the response list + if (model == 
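
In applyInferenceResponses, the sorted responses are scanned for the first one that carries a model, because responses that merely clear inference results have none. A stream-based equivalent of that scan, assuming a hypothetical Response record:

    import java.util.List;
    import java.util.Objects;
    import java.util.Optional;

    class FirstNonNullModelDemo {
        record Response(String model) {}

        static Optional<String> firstModel(List<Response> responses) {
            return responses.stream().map(Response::model).filter(Objects::nonNull).findFirst();
        }

        public static void main(String[] args) {
            List<Response> responses = List.of(new Response(null), new Response("sparse-model"));
            System.out.println(firstModel(responses)); // Optional[sparse-model]
        }
    }
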
null) { + model = resp.model; + } + var lst = chunkMap.computeIfAbsent(resp.sourceField, k -> new ArrayList<>()); lst.addAll( SemanticTextField.toSemanticTextFieldChunks( @@ -414,21 +425,26 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons ) ); } + List inputs = responses.stream() .filter(r -> r.sourceField().equals(fieldName)) .map(r -> r.input) .collect(Collectors.toList()); + + // The model can be null if we are only processing update requests that clear inference results. This is ok because we will + // merge in the field's existing model settings on the data node. var result = new SemanticTextField( indexCreatedVersion, fieldName, addMetadataField ? null : inputs, new SemanticTextField.InferenceResult( - model.getInferenceEntityId(), - new SemanticTextField.ModelSettings(model), + inferenceFieldMetadata.getInferenceId(), + model != null ? new SemanticTextField.ModelSettings(model) : null, chunkMap ), indexRequest.getContentType() ); + if (addMetadataField) { inferenceFieldsMap.put(fieldName, result); } else { @@ -518,7 +534,9 @@ private Map> createFieldInferenceRequests(Bu * preventing any unintended carryover of prior inference results. */ var slot = ensureResponseAccumulatorSlot(itemIndex); - slot.addOrUpdateResponse(new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE)); + slot.addOrUpdateResponse( + new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE) + ); continue; } if (valueObj == null || valueObj == EXPLICIT_NULL) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java index 5f63d65ae5062..40902209dc9dc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java @@ -349,9 +349,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws ); INFERENCE_RESULT_PARSER.declareString(constructorArg(), new ParseField(INFERENCE_ID_FIELD)); - INFERENCE_RESULT_PARSER.declareObject( + INFERENCE_RESULT_PARSER.declareObjectOrNull( constructorArg(), (p, c) -> MODEL_SETTINGS_PARSER.parse(p, null), + null, new ParseField(MODEL_SETTINGS_FIELD) ); INFERENCE_RESULT_PARSER.declareField(constructorArg(), (p, c) -> { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index b695cf2ee2fb2..dc4e4adebefa4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -886,7 +886,7 @@ private static boolean canMergeModelSettings( if (Objects.equals(previous, current)) { return true; } - if (previous == null) { + if (previous == null ^ current == null) { return true; } conflicts.addConflict("model_settings", ""); From e4fe4945fa9dffaab56ce84862b668f57a271ab5 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Thu, 19 Dec 2024 14:52:19 -0500 Subject: [PATCH 5/8] Refactor "Bypass inference on bulk update operation" into two tests --- .../60_semantic_text_inference_update.yml | 191 
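
The canMergeModelSettings change in patch 4 above is easy to misread: Objects.equals already returns true when both sides are null, so the XOR admits exactly the one-side-null cases (new results arriving before settings exist, or cleared results leaving existing settings in place), while two different non-null settings still conflict. Isolated for clarity, hypothetical class name:

    import java.util.Objects;

    class ModelSettingsMergeDemo {
        static boolean canMerge(Object previous, Object current) {
            if (Objects.equals(previous, current)) {
                return true; // covers both-null and identical settings
            }
            return previous == null ^ current == null; // exactly one side null merges cleanly
        }

        public static void main(String[] args) {
            System.out.println(canMerge(null, "sparse"));    // true
            System.out.println(canMerge("sparse", null));    // true
            System.out.println(canMerge("sparse", "dense")); // false: genuine conflict
        }
    }
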
++++++++++-------- .../60_semantic_text_inference_update_bwc.yml | 25 +++ 2 files changed, 135 insertions(+), 81 deletions(-) diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/60_semantic_text_inference_update.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/60_semantic_text_inference_update.yml index 660d3e37f4242..c3d0a3c272a77 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/60_semantic_text_inference_update.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/60_semantic_text_inference_update.yml @@ -819,84 +819,113 @@ setup: - match: { hits.hits.0._source._inference_fields.dense_field.inference.chunks.dense_field.0.start_offset: 0 } - match: { hits.hits.0._source._inference_fields.dense_field.inference.chunks.dense_field.0.end_offset: 30 } -# TODO: Uncomment this test once we implement a fix -#--- -#"Bypass inference on bulk update operation": -# # Update as upsert -# - do: -# bulk: -# body: -# - '{"update": {"_index": "test-index", "_id": "doc_1"}}' -# - '{"doc": { "sparse_field": "inference test", "dense_field": "another inference test", "non_inference_field": "non inference test" }, "doc_as_upsert": true}' -# -# - match: { errors: false } -# - match: { items.0.update.result: "created" } -# -# - do: -# bulk: -# body: -# - '{"update": {"_index": "test-index", "_id": "doc_1"}}' -# - '{"doc": { "non_inference_field": "another value" }, "doc_as_upsert": true}' -# refresh: true -# -# - match: { errors: false } -# - match: { items.0.update.result: "updated" } -# -# - do: -# search: -# index: test-index -# body: -# fields: [ _inference_fields ] -# query: -# match_all: { } -# -# - match: { hits.total.value: 1 } -# - match: { hits.total.relation: eq } -# -# - length: { hits.hits.0._source._inference_fields.sparse_field.inference.chunks: 1 } -# - length: { hits.hits.0._source._inference_fields.sparse_field.inference.chunks.sparse_field: 1 } -# - exists: hits.hits.0._source._inference_fields.sparse_field.inference.chunks.sparse_field.0.embeddings -# - match: { hits.hits.0._source._inference_fields.sparse_field.inference.chunks.sparse_field.0.start_offset: 0 } -# - match: { hits.hits.0._source._inference_fields.sparse_field.inference.chunks.sparse_field.0.end_offset: 14 } -# -# - length: { hits.hits.0._source._inference_fields.dense_field.inference.chunks: 1 } -# - length: { hits.hits.0._source._inference_fields.dense_field.inference.chunks.dense_field: 1 } -# - exists: hits.hits.0._source._inference_fields.dense_field.inference.chunks.dense_field.0.embeddings -# - match: { hits.hits.0._source._inference_fields.dense_field.inference.chunks.dense_field.0.start_offset: 0 } -# - match: { hits.hits.0._source._inference_fields.dense_field.inference.chunks.dense_field.0.end_offset: 22 } -# -# - match: { hits.hits.0._source.sparse_field: "inference test" } -# - match: { hits.hits.0._source.dense_field: "another inference test" } -# - match: { hits.hits.0._source.non_inference_field: "another value" } -# -# - do: -# bulk: -# body: -# - '{"update": {"_index": "test-index", "_id": "doc_1"}}' -# - '{"doc": { "sparse_field": null, "dense_field": null, "non_inference_field": "updated value" }, "doc_as_upsert": true}' -# refresh: true -# -# - match: { errors: false } -# - match: { items.0.update.result: "updated" } -# -# - do: -# search: -# index: test-index -# body: -# fields: [ _inference_fields ] -# query: -# match_all: { } -# -# - match: { 
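
For readers following from the Java side, the YAML bulk bodies in these tests amount to an update-as-upsert whose partial doc explicitly nulls the semantic text fields. A hedged approximation with the UpdateRequest builder; the index, id, and field names come from the test, but treat the exact client API shape as an assumption:

    import java.util.HashMap;
    import java.util.Map;

    import org.elasticsearch.action.update.UpdateRequest;

    class ExplicitNullUpdateDemo {
        public static void main(String[] args) {
            Map<String, Object> doc = new HashMap<>();
            doc.put("sparse_field", null);  // clear inference results
            doc.put("dense_field", null);
            doc.put("non_inference_field", "updated value");

            UpdateRequest update = new UpdateRequest("test-index", "doc_1")
                .doc(doc)
                .docAsUpsert(true);
            System.out.println(update);
        }
    }
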
hits.total.value: 1 } -# - match: { hits.total.relation: eq } -# -# # TODO: BUG! Setting sparse_field & dense_field to null does not clear _inference_fields -# - length: { hits.hits.0._source._inference_fields.sparse_field.inference.chunks: 1 } -# - length: { hits.hits.0._source._inference_fields.sparse_field.inference.chunks.sparse_field: 0 } -# -# - length: { hits.hits.0._source._inference_fields.dense_field.inference.chunks: 1 } -# - length: { hits.hits.0._source._inference_fields.dense_field.inference.chunks.dense_field: 0 } -# -# - not_exists: hits.hits.0._source.sparse_field -# - not_exists: hits.hits.0._source.dense_field -# - match: { hits.hits.0._source.non_inference_field: "updated value" } +--- +"Bypass inference on bulk update operation": + # Update as upsert + - do: + bulk: + body: + - '{"update": {"_index": "test-index", "_id": "doc_1"}}' + - '{"doc": { "sparse_field": "inference test", "dense_field": "another inference test", "non_inference_field": "non inference test" }, "doc_as_upsert": true}' + + - match: { errors: false } + - match: { items.0.update.result: "created" } + + - do: + bulk: + body: + - '{"update": {"_index": "test-index", "_id": "doc_1"}}' + - '{"doc": { "non_inference_field": "another value" }, "doc_as_upsert": true}' + refresh: true + + - match: { errors: false } + - match: { items.0.update.result: "updated" } + + - do: + search: + index: test-index + body: + fields: [ _inference_fields ] + query: + match_all: { } + + - match: { hits.total.value: 1 } + - match: { hits.total.relation: eq } + + - length: { hits.hits.0._source._inference_fields.sparse_field.inference.chunks: 1 } + - length: { hits.hits.0._source._inference_fields.sparse_field.inference.chunks.sparse_field: 1 } + - exists: hits.hits.0._source._inference_fields.sparse_field.inference.chunks.sparse_field.0.embeddings + - match: { hits.hits.0._source._inference_fields.sparse_field.inference.chunks.sparse_field.0.start_offset: 0 } + - match: { hits.hits.0._source._inference_fields.sparse_field.inference.chunks.sparse_field.0.end_offset: 14 } + + - length: { hits.hits.0._source._inference_fields.dense_field.inference.chunks: 1 } + - length: { hits.hits.0._source._inference_fields.dense_field.inference.chunks.dense_field: 1 } + - exists: hits.hits.0._source._inference_fields.dense_field.inference.chunks.dense_field.0.embeddings + - match: { hits.hits.0._source._inference_fields.dense_field.inference.chunks.dense_field.0.start_offset: 0 } + - match: { hits.hits.0._source._inference_fields.dense_field.inference.chunks.dense_field.0.end_offset: 22 } + + - match: { hits.hits.0._source.sparse_field: "inference test" } + - match: { hits.hits.0._source.dense_field: "another inference test" } + - match: { hits.hits.0._source.non_inference_field: "another value" } + +--- +"Explicit nulls clear inference results on bulk update operation": + # Update as upsert + - do: + bulk: + body: + - '{"update": {"_index": "test-index", "_id": "doc_1"}}' + - '{"doc": { "sparse_field": "inference test", "dense_field": "another inference test", "non_inference_field": "non inference test" }, "doc_as_upsert": true}' + refresh: true + + - match: { errors: false } + - match: { items.0.update.result: "created" } + + - do: + search: + index: test-index + body: + fields: [ _inference_fields ] + query: + match_all: { } + + - match: { hits.total.value: 1 } + - match: { hits.total.relation: eq } + + - length: { hits.hits.0._source._inference_fields.sparse_field.inference.chunks: 1 } + - length: { 
hits.hits.0._source._inference_fields.sparse_field.inference.chunks.sparse_field: 1 } + - exists: hits.hits.0._source._inference_fields.sparse_field.inference.chunks.sparse_field.0.embeddings + - match: { hits.hits.0._source._inference_fields.sparse_field.inference.chunks.sparse_field.0.start_offset: 0 } + - match: { hits.hits.0._source._inference_fields.sparse_field.inference.chunks.sparse_field.0.end_offset: 14 } + + - length: { hits.hits.0._source._inference_fields.dense_field.inference.chunks: 1 } + - length: { hits.hits.0._source._inference_fields.dense_field.inference.chunks.dense_field: 1 } + - exists: hits.hits.0._source._inference_fields.dense_field.inference.chunks.dense_field.0.embeddings + - match: { hits.hits.0._source._inference_fields.dense_field.inference.chunks.dense_field.0.start_offset: 0 } + - match: { hits.hits.0._source._inference_fields.dense_field.inference.chunks.dense_field.0.end_offset: 22 } + + - do: + bulk: + body: + - '{"update": {"_index": "test-index", "_id": "doc_1"}}' + - '{"doc": { "sparse_field": null, "dense_field": null, "non_inference_field": "updated value" }, "doc_as_upsert": true}' + refresh: true + + - match: { errors: false } + - match: { items.0.update.result: "updated" } + + - do: + search: + index: test-index + body: + fields: [ _inference_fields ] + query: + match_all: { } + + - match: { hits.total.value: 1 } + - match: { hits.total.relation: eq } + + - not_exists: hits.hits.0._source._inference_fields + - not_exists: hits.hits.0._source.sparse_field + - not_exists: hits.hits.0._source.dense_field + - match: { hits.hits.0._source.non_inference_field: "updated value" } diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/60_semantic_text_inference_update_bwc.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/60_semantic_text_inference_update_bwc.yml index 6b494d531b2d1..912cdb5a85d35 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/60_semantic_text_inference_update_bwc.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/60_semantic_text_inference_update_bwc.yml @@ -632,6 +632,31 @@ setup: - match: { _source.dense_field.inference.chunks.0.text: "another inference test" } - match: { _source.non_inference_field: "another value" } +--- +"Explicit nulls clear inference results on bulk update operation": + # Update as upsert + - do: + bulk: + body: + - '{"update": {"_index": "test-index", "_id": "doc_1"}}' + - '{"doc": { "sparse_field": "inference test", "dense_field": "another inference test", "non_inference_field": "non inference test" }, "doc_as_upsert": true}' + + - match: { errors: false } + - match: { items.0.update.result: "created" } + + - do: + get: + index: test-index + id: doc_1 + + - match: { _source.sparse_field.text: "inference test" } + - exists: _source.sparse_field.inference.chunks.0.embeddings + - match: { _source.sparse_field.inference.chunks.0.text: "inference test" } + - match: { _source.dense_field.text: "another inference test" } + - exists: _source.dense_field.inference.chunks.0.embeddings + - match: { _source.dense_field.inference.chunks.0.text: "another inference test" } + - match: { _source.non_inference_field: "non inference test" } + - do: bulk: body: From cc1c5de9f2a145cdd92b5dc789f62b791b071a50 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Thu, 19 Dec 2024 15:48:47 -0500 Subject: [PATCH 6/8] Update "Explicit nulls clear inference results on bulk update 
operation" test to use multiple source fields --- .../60_semantic_text_inference_update.yml | 108 +++++++++++++++--- 1 file changed, 95 insertions(+), 13 deletions(-) diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/60_semantic_text_inference_update.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/60_semantic_text_inference_update.yml index c3d0a3c272a77..32530d3cf8757 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/60_semantic_text_inference_update.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/60_semantic_text_inference_update.yml @@ -870,20 +870,47 @@ setup: --- "Explicit nulls clear inference results on bulk update operation": - # Update as upsert - do: - bulk: + indices.create: + index: test-copy-to-index body: - - '{"update": {"_index": "test-index", "_id": "doc_1"}}' - - '{"doc": { "sparse_field": "inference test", "dense_field": "another inference test", "non_inference_field": "non inference test" }, "doc_as_upsert": true}' - refresh: true + settings: + index: + mapping: + semantic_text: + use_legacy_format: false + mappings: + properties: + sparse_field: + type: semantic_text + inference_id: sparse-inference-id + sparse_source_field: + type: text + copy_to: sparse_field + dense_field: + type: semantic_text + inference_id: dense-inference-id + dense_source_field: + type: text + copy_to: dense_field + non_inference_field: + type: text - - match: { errors: false } - - match: { items.0.update.result: "created" } + - do: + index: + index: test-copy-to-index + id: doc_1 + body: + sparse_field: "inference test" + sparse_source_field: "sparse source test" + dense_field: "another inference test" + dense_source_field: "dense source test" + non_inference_field: "non inference test" + refresh: true - do: search: - index: test-index + index: test-copy-to-index body: fields: [ _inference_fields ] query: @@ -892,22 +919,36 @@ setup: - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } - - length: { hits.hits.0._source._inference_fields.sparse_field.inference.chunks: 1 } + - length: { hits.hits.0._source._inference_fields.sparse_field.inference.chunks: 2 } - length: { hits.hits.0._source._inference_fields.sparse_field.inference.chunks.sparse_field: 1 } - exists: hits.hits.0._source._inference_fields.sparse_field.inference.chunks.sparse_field.0.embeddings - match: { hits.hits.0._source._inference_fields.sparse_field.inference.chunks.sparse_field.0.start_offset: 0 } - match: { hits.hits.0._source._inference_fields.sparse_field.inference.chunks.sparse_field.0.end_offset: 14 } + - length: { hits.hits.0._source._inference_fields.sparse_field.inference.chunks.sparse_source_field: 1 } + - exists: hits.hits.0._source._inference_fields.sparse_field.inference.chunks.sparse_source_field.0.embeddings + - match: { hits.hits.0._source._inference_fields.sparse_field.inference.chunks.sparse_source_field.0.start_offset: 0 } + - match: { hits.hits.0._source._inference_fields.sparse_field.inference.chunks.sparse_source_field.0.end_offset: 18 } - - length: { hits.hits.0._source._inference_fields.dense_field.inference.chunks: 1 } + - length: { hits.hits.0._source._inference_fields.dense_field.inference.chunks: 2 } - length: { hits.hits.0._source._inference_fields.dense_field.inference.chunks.dense_field: 1 } - exists: hits.hits.0._source._inference_fields.dense_field.inference.chunks.dense_field.0.embeddings - match: { 
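
The assertions in this test walk a structure where a single semantic_text field owns chunks contributed by several copy_to sources, keyed per source field. That keying is what lets an explicit null clear one contributor without touching the others, shown schematically below with hypothetical, flattened types; the real entries nest embeddings and offsets:

    import java.util.LinkedHashMap;
    import java.util.List;
    import java.util.Map;

    class ChunksBySourceDemo {
        public static void main(String[] args) {
            // Chunks for one semantic_text field, grouped by contributing source field.
            Map<String, List<String>> chunks = new LinkedHashMap<>();
            chunks.put("sparse_field", List.of("inference test"));
            chunks.put("sparse_source_field", List.of("sparse source test"));

            chunks.remove("sparse_field"); // explicit null on sparse_field drops only its entry
            System.out.println(chunks);    // {sparse_source_field=[sparse source test]}
        }
    }
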
hits.hits.0._source._inference_fields.dense_field.inference.chunks.dense_field.0.start_offset: 0 } - match: { hits.hits.0._source._inference_fields.dense_field.inference.chunks.dense_field.0.end_offset: 22 } + - length: { hits.hits.0._source._inference_fields.dense_field.inference.chunks.dense_source_field: 1 } + - exists: hits.hits.0._source._inference_fields.dense_field.inference.chunks.dense_source_field.0.embeddings + - match: { hits.hits.0._source._inference_fields.dense_field.inference.chunks.dense_source_field.0.start_offset: 0 } + - match: { hits.hits.0._source._inference_fields.dense_field.inference.chunks.dense_source_field.0.end_offset: 17 } + + - match: { hits.hits.0._source.sparse_field: "inference test" } + - match: { hits.hits.0._source.sparse_source_field: "sparse source test" } + - match: { hits.hits.0._source.dense_field: "another inference test" } + - match: { hits.hits.0._source.dense_source_field: "dense source test" } + - match: { hits.hits.0._source.non_inference_field: "non inference test" } - do: bulk: body: - - '{"update": {"_index": "test-index", "_id": "doc_1"}}' + - '{"update": {"_index": "test-copy-to-index", "_id": "doc_1"}}' - '{"doc": { "sparse_field": null, "dense_field": null, "non_inference_field": "updated value" }, "doc_as_upsert": true}' refresh: true @@ -916,7 +957,7 @@ setup: - do: search: - index: test-index + index: test-copy-to-index body: fields: [ _inference_fields ] query: @@ -925,7 +966,48 @@ setup: - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } - - not_exists: hits.hits.0._source._inference_fields + - length: { hits.hits.0._source._inference_fields.sparse_field.inference.chunks: 1 } + - length: { hits.hits.0._source._inference_fields.sparse_field.inference.chunks.sparse_source_field: 1 } + - exists: hits.hits.0._source._inference_fields.sparse_field.inference.chunks.sparse_source_field.0.embeddings + - match: { hits.hits.0._source._inference_fields.sparse_field.inference.chunks.sparse_source_field.0.start_offset: 0 } + - match: { hits.hits.0._source._inference_fields.sparse_field.inference.chunks.sparse_source_field.0.end_offset: 18 } + + - length: { hits.hits.0._source._inference_fields.dense_field.inference.chunks: 1 } + - length: { hits.hits.0._source._inference_fields.dense_field.inference.chunks.dense_source_field: 1 } + - exists: hits.hits.0._source._inference_fields.dense_field.inference.chunks.dense_source_field.0.embeddings + - match: { hits.hits.0._source._inference_fields.dense_field.inference.chunks.dense_source_field.0.start_offset: 0 } + - match: { hits.hits.0._source._inference_fields.dense_field.inference.chunks.dense_source_field.0.end_offset: 17 } + - not_exists: hits.hits.0._source.sparse_field + - match: { hits.hits.0._source.sparse_source_field: "sparse source test" } - not_exists: hits.hits.0._source.dense_field + - match: { hits.hits.0._source.dense_source_field: "dense source test" } - match: { hits.hits.0._source.non_inference_field: "updated value" } + + - do: + bulk: + body: + - '{"update": {"_index": "test-copy-to-index", "_id": "doc_1"}}' + - '{"doc": { "sparse_source_field": null, "dense_source_field": null, "non_inference_field": "another value" }, "doc_as_upsert": true}' + refresh: true + + - match: { errors: false } + - match: { items.0.update.result: "updated" } + + - do: + search: + index: test-copy-to-index + body: + fields: [ _inference_fields ] + query: + match_all: { } + + - match: { hits.total.value: 1 } + - match: { hits.total.relation: eq } + + - not_exists: 
hits.hits.0._source._inference_fields + - not_exists: hits.hits.0._source.sparse_field + - not_exists: hits.hits.0._source.sparse_source_field + - not_exists: hits.hits.0._source.dense_field + - not_exists: hits.hits.0._source.dense_source_field + - match: { hits.hits.0._source.non_inference_field: "another value" } From e69f15a86ad213deb7db7801e31e1c528968f912 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Mon, 23 Dec 2024 14:49:42 -0500 Subject: [PATCH 7/8] Fix compilation errors --- .../action/filter/ShardBulkInferenceActionFilterTests.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 6ffbe058ae399..eced5102dda55 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -206,7 +206,7 @@ public void testItemFailures() throws Exception { XContentMapValues.extractValue(useLegacyFormat ? "field1.text" : "field1", actualRequest.sourceAsMap()), equalTo("I am a success") ); - if (useInferenceMetadataFieldsFormat) { + if (useLegacyFormat == false) { assertNotNull( XContentMapValues.extractValue(InferenceMetadataFieldsMapper.NAME + ".field1", actualRequest.sourceAsMap()) ); @@ -246,7 +246,7 @@ public void testExplicitNull() throws Exception { threadPool, Map.of(model.getInferenceEntityId(), model), randomIntBetween(1, 10), - IndexVersion.current() + useLegacyFormat ); model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom"))); model.putResult("I am a success", randomChunkedInferenceEmbeddingSparse(List.of("I am a success"))); From 5cee11beea56463212e6f030237f8c1e64855d2e Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Mon, 23 Dec 2024 16:55:54 -0500 Subject: [PATCH 8/8] Fix unit test --- .../ShardBulkInferenceActionFilterTests.java | 95 +++++++++++++------ 1 file changed, 64 insertions(+), 31 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index eced5102dda55..0432a2ff3fc9e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -68,6 +68,8 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch; import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.DEFAULT_BATCH_SIZE; import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.getIndexRequestOrNull; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getChunksFieldName; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getOriginalTextFieldName; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomChunkedInferenceEmbeddingSparse; import static 
org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput; @@ -76,12 +78,15 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; import static org.mockito.Mockito.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; public class ShardBulkInferenceActionFilterTests extends ESTestCase { + private static final Object EXPLICIT_NULL = new Object(); + private final boolean useLegacyFormat; private ThreadPool threadPool; @@ -241,6 +246,8 @@ public void testItemFailures() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testExplicitNull() throws Exception { StaticModel model = StaticModel.createRandomInstance(); + model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom"))); + model.putResult("I am a success", randomChunkedInferenceEmbeddingSparse(List.of("I am a success"))); ShardBulkInferenceActionFilter filter = createFilter( threadPool, @@ -248,8 +255,7 @@ public void testExplicitNull() throws Exception { randomIntBetween(1, 10), useLegacyFormat ); - model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom"))); - model.putResult("I am a success", randomChunkedInferenceEmbeddingSparse(List.of("I am a success"))); + CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { @@ -257,24 +263,16 @@ public void testExplicitNull() throws Exception { assertNull(bulkShardRequest.getInferenceFieldMap()); assertThat(bulkShardRequest.items().length, equalTo(5)); - Object explicitNull = new Object(); // item 0 assertNull(bulkShardRequest.items()[0].getPrimaryResponse()); IndexRequest actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[0].request()); - assertTrue(XContentMapValues.extractValue("obj.field1", actualRequest.sourceAsMap(), explicitNull) == explicitNull); - assertNull(XContentMapValues.extractValue(InferenceMetadataFieldsMapper.NAME, actualRequest.sourceAsMap(), explicitNull)); + assertThat(XContentMapValues.extractValue("obj.field1", actualRequest.sourceAsMap(), EXPLICIT_NULL), is(EXPLICIT_NULL)); + assertNull(XContentMapValues.extractValue(InferenceMetadataFieldsMapper.NAME, actualRequest.sourceAsMap(), EXPLICIT_NULL)); // item 1 is a success assertNull(bulkShardRequest.items()[1].getPrimaryResponse()); actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[1].request()); - assertThat(XContentMapValues.extractValue("obj.field1", actualRequest.sourceAsMap()), equalTo("I am a success")); - assertNotNull( - XContentMapValues.extractValue( - InferenceMetadataFieldsMapper.NAME + ".obj.field1", - actualRequest.sourceAsMap(), - explicitNull - ) - ); + assertInferenceResults(useLegacyFormat, actualRequest, "obj.field1", "I am a success", 1); // item 2 is a failure assertNotNull(bulkShardRequest.items()[2].getPrimaryResponse()); @@ -285,26 +283,13 @@ public void testExplicitNull() throws Exception { // item 3 assertNull(bulkShardRequest.items()[3].getPrimaryResponse()); actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[3].request()); - assertTrue(XContentMapValues.extractValue("obj.field1", actualRequest.sourceAsMap(), explicitNull) == explicitNull); - 
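
Patch 8 swaps the raw == sentinel checks for Hamcrest's is()/equalTo(). That is safe only because EXPLICIT_NULL is a bare java.lang.Object, whose equals() is reference equality, so the equality-based matcher degenerates to the identity test the code used before. In miniature:

    import java.util.Objects;

    class SentinelEqualityDemo {
        public static void main(String[] args) {
            Object sentinel = new Object();
            // Object.equals is reference equality, so equals-based matchers
            // behave exactly like the previous == assertions.
            System.out.println(Objects.equals(sentinel, sentinel));     // true
            System.out.println(Objects.equals(new Object(), sentinel)); // false
        }
    }
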
assertTrue( - XContentMapValues.extractValue( - InferenceMetadataFieldsMapper.NAME + ".obj.field1", - actualRequest.sourceAsMap(), - explicitNull - ) == explicitNull - ); + assertInferenceResults(useLegacyFormat, actualRequest, "obj.field1", EXPLICIT_NULL, 0); // item 4 assertNull(bulkShardRequest.items()[4].getPrimaryResponse()); actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[4].request()); - assertNull(XContentMapValues.extractValue("obj.field1", actualRequest.sourceAsMap(), explicitNull)); - assertNull( - XContentMapValues.extractValue( - InferenceMetadataFieldsMapper.NAME + ".obj.field1", - actualRequest.sourceAsMap(), - explicitNull - ) - ); + assertNull(XContentMapValues.extractValue("obj.field1", actualRequest.sourceAsMap(), EXPLICIT_NULL)); + assertNull(XContentMapValues.extractValue(InferenceMetadataFieldsMapper.NAME, actualRequest.sourceAsMap(), EXPLICIT_NULL)); } finally { chainExecuted.countDown(); } @@ -316,14 +301,15 @@ public void testExplicitNull() throws Exception { "obj.field1", new InferenceFieldMetadata("obj.field1", model.getInferenceEntityId(), new String[] { "obj.field1" }) ); - BulkItemRequest[] items = new BulkItemRequest[5]; Map sourceWithNull = new HashMap<>(); sourceWithNull.put("field1", null); + + BulkItemRequest[] items = new BulkItemRequest[5]; items[0] = new BulkItemRequest(0, new IndexRequest("index").source(Map.of("obj", sourceWithNull))); items[1] = new BulkItemRequest(1, new IndexRequest("index").source("obj.field1", "I am a success")); items[2] = new BulkItemRequest(2, new IndexRequest("index").source("obj.field1", "I am a failure")); items[3] = new BulkItemRequest(3, new UpdateRequest().doc(new IndexRequest("index").source(Map.of("obj", sourceWithNull)))); - items[4] = new BulkItemRequest(3, new UpdateRequest().doc(new IndexRequest("index").source(Map.of("field2", "value")))); + items[4] = new BulkItemRequest(4, new UpdateRequest().doc(new IndexRequest("index").source(Map.of("field2", "value")))); BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items); request.setInferenceFieldMap(inferenceFieldMap); filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); @@ -533,6 +519,53 @@ private static BulkItemRequest[] randomBulkItemRequest( new BulkItemRequest(requestId, new IndexRequest("index").source(expectedDocMap, requestContentType)) }; } + @SuppressWarnings({ "unchecked" }) + private static void assertInferenceResults( + boolean useLegacyFormat, + IndexRequest request, + String fieldName, + Object expectedOriginalValue, + int expectedChunkCount + ) { + final Map requestMap = request.sourceAsMap(); + if (useLegacyFormat) { + assertThat( + XContentMapValues.extractValue(getOriginalTextFieldName(fieldName), requestMap, EXPLICIT_NULL), + equalTo(expectedOriginalValue) + ); + + List chunks = (List) XContentMapValues.extractValue(getChunksFieldName(fieldName), requestMap); + if (expectedChunkCount > 0) { + assertNotNull(chunks); + assertThat(chunks.size(), equalTo(expectedChunkCount)); + } else { + // If the expected chunk count is 0, we expect that no inference has been performed. In this case, the source should not be + // transformed, and thus the semantic text field structure should not be created. 
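
The two branches of assertInferenceResults encode different conventions: in the legacy format a field with no inference simply lacks the semantic-text structure, while in the new metadata format a cleared field is present with an empty chunk list, which is what keeps a clear from wiping sibling source fields. The distinction distilled, with hypothetical names:

    import java.util.List;
    import java.util.Map;

    class ClearedFieldCheckDemo {
        // New-format rule: cleared means present-but-empty, never missing.
        static boolean isCleared(Map<String, List<String>> chunksBySource, String sourceField) {
            List<String> chunks = chunksBySource.get(sourceField);
            return chunks != null && chunks.isEmpty();
        }

        public static void main(String[] args) {
            System.out.println(isCleared(Map.of("obj.field1", List.of()), "obj.field1")); // true: cleared
            System.out.println(isCleared(Map.of(), "obj.field1"));                        // false: never written
        }
    }
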
+ assertNull(chunks); + } + } else { + assertThat(XContentMapValues.extractValue(fieldName, requestMap, EXPLICIT_NULL), equalTo(expectedOriginalValue)); + + Map inferenceMetadataFields = (Map) XContentMapValues.extractValue( + InferenceMetadataFieldsMapper.NAME, + requestMap, + EXPLICIT_NULL + ); + assertNotNull(inferenceMetadataFields); + + // When using the inference metadata fields format, chunks are mapped by source field. We handle clearing inference results for + // a field by emitting an empty chunk list for it. This is done to prevent the clear operation from clearing inference results + // for other source fields. + List chunks = (List) XContentMapValues.extractValue( + getChunksFieldName(fieldName) + "." + fieldName, + inferenceMetadataFields, + EXPLICIT_NULL + ); + assertNotNull(chunks); + assertThat(chunks.size(), equalTo(expectedChunkCount)); + } + } + private static class StaticModel extends TestModel { private final Map resultMap;