-
Notifications
You must be signed in to change notification settings - Fork 24.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[inference_metadata_fields] Clear inference results on explicit nulls #119145
base: inference_metadata_fields
Are you sure you want to change the base?
Changes from 7 commits
77c361c
60a4a1d
af5a348
8d0c146
67fb4e1
e4fe494
cc1c5de
6677e8f
e69f15a
5cee11b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,6 +39,7 @@ | |
import org.elasticsearch.inference.UnparsedModel; | ||
import org.elasticsearch.rest.RestStatus; | ||
import org.elasticsearch.tasks.Task; | ||
import org.elasticsearch.xcontent.XContent; | ||
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; | ||
import org.elasticsearch.xpack.inference.mapper.SemanticTextField; | ||
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; | ||
|
@@ -50,6 +51,7 @@ | |
import java.util.Collections; | ||
import java.util.Comparator; | ||
import java.util.HashMap; | ||
import java.util.Iterator; | ||
import java.util.LinkedHashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
@@ -67,6 +69,8 @@ | |
*/ | ||
public class ShardBulkInferenceActionFilter implements MappedActionFilter { | ||
protected static final int DEFAULT_BATCH_SIZE = 512; | ||
private static final Object EXPLICIT_NULL = new Object(); | ||
private static final ChunkedInference EMPTY_CHUNKED_INFERENCE = new EmptyChunkedInference(); | ||
|
||
private final ClusterService clusterService; | ||
private final InferenceServiceRegistry inferenceServiceRegistry; | ||
|
@@ -393,11 +397,22 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons | |
for (var entry : response.responses.entrySet()) { | ||
var fieldName = entry.getKey(); | ||
var responses = entry.getValue(); | ||
var model = responses.get(0).model(); | ||
Model model = null; | ||
|
||
InferenceFieldMetadata inferenceFieldMetadata = fieldInferenceMap.get(fieldName); | ||
if (inferenceFieldMetadata == null) { | ||
throw new IllegalStateException("No inference field metadata for field [" + fieldName + "]"); | ||
} | ||
|
||
// ensure that the order in the original field is consistent in case of multiple inputs | ||
Collections.sort(responses, Comparator.comparingInt(FieldInferenceResponse::inputOrder)); | ||
Map<String, List<SemanticTextField.Chunk>> chunkMap = new LinkedHashMap<>(); | ||
for (var resp : responses) { | ||
// Get the first non-null model from the response list | ||
if (model == null) { | ||
model = resp.model; | ||
} | ||
|
||
var lst = chunkMap.computeIfAbsent(resp.sourceField, k -> new ArrayList<>()); | ||
lst.addAll( | ||
SemanticTextField.toSemanticTextFieldChunks( | ||
|
@@ -409,21 +424,26 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons | |
) | ||
); | ||
} | ||
|
||
List<String> inputs = responses.stream() | ||
.filter(r -> r.sourceField().equals(fieldName)) | ||
.map(r -> r.input) | ||
.collect(Collectors.toList()); | ||
|
||
// The model can be null if we are only processing update requests that clear inference results. This is ok because we will | ||
// merge in the field's existing model settings on the data node. | ||
var result = new SemanticTextField( | ||
useLegacyFormat, | ||
fieldName, | ||
useLegacyFormat ? inputs : null, | ||
new SemanticTextField.InferenceResult( | ||
model.getInferenceEntityId(), | ||
new SemanticTextField.ModelSettings(model), | ||
inferenceFieldMetadata.getInferenceId(), | ||
model != null ? new SemanticTextField.ModelSettings(model) : null, | ||
chunkMap | ||
), | ||
indexRequest.getContentType() | ||
); | ||
|
||
if (useLegacyFormat) { | ||
SemanticTextUtils.insertValue(fieldName, newDocMap, result); | ||
} else { | ||
|
@@ -490,7 +510,8 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu | |
} else { | ||
var inferenceMetadataFieldsValue = XContentMapValues.extractValue( | ||
InferenceMetadataFieldsMapper.NAME + "." + field, | ||
docMap | ||
docMap, | ||
EXPLICIT_NULL | ||
); | ||
if (inferenceMetadataFieldsValue != null) { | ||
// Inference has already been computed | ||
|
@@ -500,9 +521,22 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu | |
|
||
int order = 0; | ||
for (var sourceField : entry.getSourceFields()) { | ||
// TODO: Detect when the field is provided with an explicit null value | ||
var valueObj = XContentMapValues.extractValue(sourceField, docMap); | ||
if (valueObj == null) { | ||
var valueObj = XContentMapValues.extractValue(sourceField, docMap, EXPLICIT_NULL); | ||
if (useLegacyFormat == false && isUpdateRequest && valueObj == EXPLICIT_NULL) { | ||
/**
 * It's an update request, and the source field is explicitly set to null,
 * so we need to propagate this information to the inference fields metadata
 * to overwrite any inference previously computed on the field.
 * This ensures that the field is treated as intentionally cleared,
 * preventing any unintended carryover of prior inference results.
 */ | ||
[Review comment] Nice comment here!
var slot = ensureResponseAccumulatorSlot(itemIndex); | ||
slot.addOrUpdateResponse( | ||
new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We're incrementing |
||
); | ||
continue; | ||
} | ||
if (valueObj == null || valueObj == EXPLICIT_NULL) { | ||
if (isUpdateRequest && useLegacyFormat) { | ||
addInferenceResponseFailure( | ||
item.id(), | ||
|
@@ -552,4 +586,11 @@ static IndexRequest getIndexRequestOrNull(DocWriteRequest<?> docWriteRequest) { | |
return null; | ||
} | ||
} | ||
|
||
private static class EmptyChunkedInference implements ChunkedInference { | ||
@Override | ||
public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) { | ||
return Collections.emptyIterator(); | ||
} | ||
} | ||
} |
Mikep86 marked this conversation as resolved.
Show resolved
Hide resolved
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to do any additional validation here, to verify that `model` (if it exists) is compatible with `resp.model`?