diff --git a/docs/changelog/118999.yaml b/docs/changelog/118999.yaml new file mode 100644 index 0000000000000..0188cebbd7685 --- /dev/null +++ b/docs/changelog/118999.yaml @@ -0,0 +1,6 @@ +pr: 118999 +summary: Fix loss of context in the inference API for streaming APIs +area: Machine Learning +type: bug +issues: + - 119000 diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index 4e32ef99d06dd..a4b39e64c88fb 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -30,6 +30,7 @@ import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.equalTo; @@ -336,20 +337,34 @@ protected Map infer(String modelId, List input) throws I return inferInternal(endpoint, input, Map.of()); } - protected Deque streamInferOnMockService(String modelId, TaskType taskType, List input) throws Exception { + protected Deque streamInferOnMockService( + String modelId, + TaskType taskType, + List input, + @Nullable Consumer responseConsumerCallback + ) throws Exception { var endpoint = Strings.format("_inference/%s/%s/_stream", taskType, modelId); - return callAsync(endpoint, input); + return callAsync(endpoint, input, responseConsumerCallback); } - private Deque callAsync(String endpoint, List input) throws Exception { - var responseConsumer = new AsyncInferenceResponseConsumer(); + private Deque callAsync(String endpoint, List input, @Nullable Consumer responseConsumerCallback) + throws Exception { var request = new Request("POST", endpoint); request.setJsonEntity(jsonBody(input)); + + return execAsyncCall(request, responseConsumerCallback); + } + + private Deque execAsyncCall(Request request, @Nullable Consumer responseConsumerCallback) throws Exception { + var responseConsumer = new AsyncInferenceResponseConsumer(); request.setOptions(RequestOptions.DEFAULT.toBuilder().setHttpAsyncResponseConsumerFactory(() -> responseConsumer).build()); var latch = new CountDownLatch(1); client().performRequestAsync(request, new ResponseListener() { @Override public void onSuccess(Response response) { + if (responseConsumerCallback != null) { + responseConsumerCallback.accept(response); + } latch.countDown(); } diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index f1831acbcc40f..9fd817e3b184d 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference; import org.apache.http.util.EntityUtils; +import org.elasticsearch.client.Response; import org.elasticsearch.client.ResponseException; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.inference.TaskType; @@ -19,6 +20,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.Consumer; import java.util.function.Function; import java.util.stream.IntStream; import java.util.stream.Stream; @@ -28,9 +30,15 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalToIgnoringCase; import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; public class InferenceCrudIT extends InferenceBaseRestTest { + private static final Consumer VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER = (r) -> assertThat( + r.getHeader("X-elastic-product"), + is("Elasticsearch") + ); + @SuppressWarnings("unchecked") public void testCRUD() throws IOException { for (int i = 0; i < 5; i++) { @@ -282,7 +290,7 @@ public void testUnsupportedStream() throws Exception { assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get("task_type")); try { - var events = streamInferOnMockService(modelId, TaskType.SPARSE_EMBEDDING, List.of(randomAlphaOfLength(10))); + var events = streamInferOnMockService(modelId, TaskType.SPARSE_EMBEDDING, List.of(randomAlphaOfLength(10)), null); assertThat(events.size(), equalTo(2)); events.forEach(event -> { switch (event.name()) { @@ -309,7 +317,7 @@ public void testSupportedStream() throws Exception { var input = IntStream.range(1, 2 + randomInt(8)).mapToObj(i -> randomAlphaOfLength(10)).toList(); try { - var events = streamInferOnMockService(modelId, TaskType.COMPLETION, input); + var events = streamInferOnMockService(modelId, TaskType.COMPLETION, input, VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER); var expectedResponses = Stream.concat( input.stream().map(String::toUpperCase).map(str -> "{\"completion\":[{\"delta\":\"" + str + "\"}]}"), diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java index ab3f466f3c11f..b993cf36cb875 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java @@ -17,6 +17,7 @@ import org.apache.http.nio.util.SimpleInputBuffer; import org.apache.http.protocol.HttpContext; import org.apache.http.util.EntityUtils; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.client.Request; import org.elasticsearch.client.RequestOptions; import org.elasticsearch.client.Response; @@ -43,6 +44,7 @@ import org.elasticsearch.rest.RestHandler; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; @@ -52,6 +54,7 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.Collection; +import java.util.Collections; import java.util.Deque; import java.util.Iterator; import java.util.List; @@ -96,6 +99,14 @@ protected Collection> nodePlugins() { } public static class StreamingPlugin extends Plugin implements ActionPlugin { + private final SetOnce threadPool = new SetOnce<>(); + + @Override + public Collection createComponents(PluginServices services) { + threadPool.set(services.threadPool()); + return Collections.emptyList(); + } + @Override public Collection getRestHandlers( Settings settings, @@ -122,7 +133,7 @@ public void handleRequest(RestRequest request, RestChannel channel, NodeClient c var publisher = new RandomPublisher(requestCount, withError); var inferenceServiceResults = new StreamingInferenceServiceResults(publisher); var inferenceResponse = new InferenceAction.Response(inferenceServiceResults, inferenceServiceResults.publisher()); - new ServerSentEventsRestActionListener(channel).onResponse(inferenceResponse); + new ServerSentEventsRestActionListener(channel, threadPool).onResponse(inferenceResponse); } }, new RestHandler() { @Override @@ -132,7 +143,7 @@ public List routes() { @Override public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) { - new ServerSentEventsRestActionListener(channel).onFailure(expectedException); + new ServerSentEventsRestActionListener(channel, threadPool).onFailure(expectedException); } }, new RestHandler() { @Override @@ -143,7 +154,7 @@ public List routes() { @Override public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) { var inferenceResponse = new InferenceAction.Response(new SingleInferenceServiceResults()); - new ServerSentEventsRestActionListener(channel).onResponse(inferenceResponse); + new ServerSentEventsRestActionListener(channel, threadPool).onResponse(inferenceResponse); } }); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index b3ab421e71e9a..a3eb32b77060a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -39,6 +39,7 @@ import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.threadpool.ExecutorBuilder; import org.elasticsearch.threadpool.ScalingExecutorBuilder; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.action.XPackUsageFeatureAction; @@ -140,6 +141,9 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP private final SetOnce amazonBedrockFactory = new SetOnce<>(); private final SetOnce serviceComponents = new SetOnce<>(); private final SetOnce eisComponents = new SetOnce<>(); + // This is mainly so that the rest handlers can access the ThreadPool in a way that avoids potential null pointers from it + // not being initialized yet + private final SetOnce threadPoolSetOnce = new SetOnce<>(); private final SetOnce inferenceServiceRegistry = new SetOnce<>(); private final SetOnce shardBulkInferenceActionFilter = new SetOnce<>(); private List inferenceServiceExtensions; @@ -176,7 +180,7 @@ public List getRestHandlers( ) { return List.of( new RestInferenceAction(), - new RestStreamInferenceAction(), + new RestStreamInferenceAction(threadPoolSetOnce), new RestGetInferenceModelAction(), new RestPutInferenceModelAction(), new RestUpdateInferenceModelAction(), @@ -190,6 +194,7 @@ public Collection createComponents(PluginServices services) { var throttlerManager = new ThrottlerManager(settings, services.threadPool(), services.clusterService()); var truncator = new Truncator(settings, services.clusterService()); serviceComponents.set(new ServiceComponents(services.threadPool(), throttlerManager, settings, truncator)); + threadPoolSetOnce.set(services.threadPool()); var httpClientManager = HttpClientManager.create(settings, services.threadPool(), services.clusterService(), throttlerManager); var httpRequestSenderFactory = new HttpRequestSender.Factory(serviceComponents.get(), httpClientManager, services.clusterService()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java index 875c288da52bd..881af435b29b6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java @@ -7,13 +7,16 @@ package org.elasticsearch.xpack.inference.rest; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionListener; import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import java.util.List; +import java.util.Objects; import static org.elasticsearch.rest.RestRequest.Method.POST; import static org.elasticsearch.xpack.inference.rest.Paths.STREAM_INFERENCE_ID_PATH; @@ -21,6 +24,13 @@ @ServerlessScope(Scope.PUBLIC) public class RestStreamInferenceAction extends BaseInferenceAction { + private final SetOnce threadPool; + + public RestStreamInferenceAction(SetOnce threadPool) { + super(); + this.threadPool = Objects.requireNonNull(threadPool); + } + @Override public String getName() { return "stream_inference_action"; @@ -38,6 +48,6 @@ protected InferenceAction.Request prepareInferenceRequest(InferenceAction.Reques @Override protected ActionListener listener(RestChannel channel) { - return new ServerSentEventsRestActionListener(channel); + return new ServerSentEventsRestActionListener(channel, threadPool); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java index d5f82ec78ca49..b34221c48d4e4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java @@ -10,9 +10,11 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ContextPreservingActionListener; import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.common.collect.Iterators; import org.elasticsearch.common.io.stream.BytesStream; @@ -30,6 +32,7 @@ import org.elasticsearch.rest.RestResponse; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.TaskCancelledException; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.action.InferenceAction; @@ -39,6 +42,7 @@ import java.nio.charset.StandardCharsets; import java.util.Iterator; import java.util.Map; +import java.util.Objects; import java.util.concurrent.Flow; import java.util.concurrent.atomic.AtomicBoolean; @@ -55,6 +59,7 @@ public class ServerSentEventsRestActionListener implements ActionListener threadPool; /** * A listener for the first part of the next entry to become available for transmission. @@ -66,13 +71,14 @@ public class ServerSentEventsRestActionListener implements ActionListener nextBodyPartListener; - public ServerSentEventsRestActionListener(RestChannel channel) { - this(channel, channel.request()); + public ServerSentEventsRestActionListener(RestChannel channel, SetOnce threadPool) { + this(channel, channel.request(), threadPool); } - public ServerSentEventsRestActionListener(RestChannel channel, ToXContent.Params params) { + public ServerSentEventsRestActionListener(RestChannel channel, ToXContent.Params params, SetOnce threadPool) { this.channel = channel; this.params = params; + this.threadPool = Objects.requireNonNull(threadPool); } @Override @@ -99,7 +105,7 @@ protected void ensureOpen() { } private void initializeStream(InferenceAction.Response response) { - nextBodyPartListener = ActionListener.wrap(bodyPart -> { + ActionListener chunkedResponseBodyActionListener = ActionListener.wrap(bodyPart -> { // this is the first response, so we need to send the RestResponse to open the stream // all subsequent bytes will be delivered through the nextBodyPartListener channel.sendResponse(RestResponse.chunked(RestStatus.OK, bodyPart, this::release)); @@ -115,6 +121,12 @@ private void initializeStream(InferenceAction.Response response) { ) ); }); + + nextBodyPartListener = ContextPreservingActionListener.wrapPreservingContext( + chunkedResponseBodyActionListener, + threadPool.get().getThreadContext() + ); + // subscribe will call onSubscribe, which requests the first chunk response.publisher().subscribe(subscriber); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java index b999e2c9b72f0..f67680ef6b625 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java @@ -12,8 +12,11 @@ import org.elasticsearch.rest.RestRequest; import org.elasticsearch.test.rest.FakeRestRequest; import org.elasticsearch.test.rest.RestActionTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.junit.After; import org.junit.Before; import static org.elasticsearch.xpack.inference.rest.BaseInferenceActionTests.createResponse; @@ -22,10 +25,18 @@ import static org.hamcrest.Matchers.instanceOf; public class RestStreamInferenceActionTests extends RestActionTestCase { + private final SetOnce threadPool = new SetOnce<>(); @Before public void setUpAction() { - controller().registerHandler(new RestStreamInferenceAction()); + threadPool.set(new TestThreadPool(getTestName())); + controller().registerHandler(new RestStreamInferenceAction(threadPool)); + } + + @After + public void tearDownAction() { + terminate(threadPool.get()); + } public void testStreamIsTrue() {