From 12f43573a11e3eeb0b4ffb9ea80051768d6f45b4 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner <56361221+jonathan-buttner@users.noreply.github.com> Date: Tue, 7 Jan 2025 14:26:17 -0500 Subject: [PATCH] [8.x] [ML] Fix loss of context in the inference API for streaming APIs (#118999) (#119218) * [ML] Fix loss of context in the inference API for streaming APIs (#118999) * Adding context preserving fix * Update docs/changelog/118999.yaml * Update docs/changelog/118999.yaml * Using a setonce and adding a test * Updating the changelog (cherry picked from commit 7ba3cb9d0dc624f273f0cc8d58440992523de3cf) # Conflicts: # x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java * Removing assert --- docs/changelog/118999.yaml | 6 +++ .../inference/InferenceBaseRestTest.java | 39 ++++++++++++++----- .../xpack/inference/InferenceCrudIT.java | 14 +++++-- ...rverSentEventsRestActionListenerTests.java | 17 ++++++-- .../xpack/inference/InferencePlugin.java | 9 ++++- .../rest/RestStreamInferenceAction.java | 12 +++++- .../RestUnifiedCompletionInferenceAction.java | 16 +++++++- .../ServerSentEventsRestActionListener.java | 20 ++++++++-- .../rest/RestStreamInferenceActionTests.java | 13 ++++++- ...UnifiedCompletionInferenceActionTests.java | 12 +++++- 10 files changed, 132 insertions(+), 26 deletions(-) create mode 100644 docs/changelog/118999.yaml diff --git a/docs/changelog/118999.yaml b/docs/changelog/118999.yaml new file mode 100644 index 0000000000000..0188cebbd7685 --- /dev/null +++ b/docs/changelog/118999.yaml @@ -0,0 +1,6 @@ +pr: 118999 +summary: Fix loss of context in the inference API for streaming APIs +area: Machine Learning +type: bug +issues: + - 119000 diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index 5e6c4d53f4c58..cdc6d9b2dff5f 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -34,6 +34,7 @@ import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.equalTo; @@ -341,31 +342,44 @@ protected Map infer(String modelId, List input) throws I return inferInternal(endpoint, input, null, Map.of()); } - protected Deque streamInferOnMockService(String modelId, TaskType taskType, List input) throws Exception { + protected Deque streamInferOnMockService( + String modelId, + TaskType taskType, + List input, + @Nullable Consumer responseConsumerCallback + ) throws Exception { var endpoint = Strings.format("_inference/%s/%s/_stream", taskType, modelId); - return callAsync(endpoint, input); + return callAsync(endpoint, input, responseConsumerCallback); } - protected Deque unifiedCompletionInferOnMockService(String modelId, TaskType taskType, List input) - throws Exception { + protected Deque unifiedCompletionInferOnMockService( + String modelId, + TaskType taskType, + List input, + @Nullable Consumer responseConsumerCallback + ) throws Exception { var endpoint = Strings.format("_inference/%s/%s/_unified", taskType, modelId); - return callAsyncUnified(endpoint, input, "user"); + return callAsyncUnified(endpoint, input, "user", responseConsumerCallback); } - private Deque callAsync(String endpoint, List input) throws Exception { + private Deque callAsync(String endpoint, List input, @Nullable Consumer responseConsumerCallback) + throws Exception { var request = new Request("POST", endpoint); request.setJsonEntity(jsonBody(input, null)); - return execAsyncCall(request); + return execAsyncCall(request, responseConsumerCallback); } - private Deque execAsyncCall(Request request) throws Exception { + private Deque execAsyncCall(Request request, @Nullable Consumer responseConsumerCallback) throws Exception { var responseConsumer = new AsyncInferenceResponseConsumer(); request.setOptions(RequestOptions.DEFAULT.toBuilder().setHttpAsyncResponseConsumerFactory(() -> responseConsumer).build()); var latch = new CountDownLatch(1); client().performRequestAsync(request, new ResponseListener() { @Override public void onSuccess(Response response) { + if (responseConsumerCallback != null) { + responseConsumerCallback.accept(response); + } latch.countDown(); } @@ -378,11 +392,16 @@ public void onFailure(Exception exception) { return responseConsumer.events(); } - private Deque callAsyncUnified(String endpoint, List input, String role) throws Exception { + private Deque callAsyncUnified( + String endpoint, + List input, + String role, + @Nullable Consumer responseConsumerCallback + ) throws Exception { var request = new Request("POST", endpoint); request.setJsonEntity(createUnifiedJsonBody(input, role)); - return execAsyncCall(request); + return execAsyncCall(request, responseConsumerCallback); } private String createUnifiedJsonBody(List input, String role) throws IOException { diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index 2e280f644add5..1bf8af00adbb5 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference; import org.apache.http.util.EntityUtils; +import org.elasticsearch.client.Response; import org.elasticsearch.client.ResponseException; import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; @@ -25,6 +26,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.Consumer; import java.util.function.Function; import java.util.stream.IntStream; import java.util.stream.Stream; @@ -34,9 +36,15 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalToIgnoringCase; import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; public class InferenceCrudIT extends InferenceBaseRestTest { + private static final Consumer VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER = (r) -> assertThat( + r.getHeader("X-elastic-product"), + is("Elasticsearch") + ); + @SuppressWarnings("unchecked") public void testCRUD() throws IOException { for (int i = 0; i < 5; i++) { @@ -288,7 +296,7 @@ public void testUnsupportedStream() throws Exception { assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get("task_type")); try { - var events = streamInferOnMockService(modelId, TaskType.SPARSE_EMBEDDING, List.of(randomUUID())); + var events = streamInferOnMockService(modelId, TaskType.SPARSE_EMBEDDING, List.of(randomUUID()), null); assertThat(events.size(), equalTo(2)); events.forEach(event -> { switch (event.name()) { @@ -315,7 +323,7 @@ public void testSupportedStream() throws Exception { var input = IntStream.range(1, 2 + randomInt(8)).mapToObj(i -> randomAlphanumericOfLength(5)).toList(); try { - var events = streamInferOnMockService(modelId, TaskType.COMPLETION, input); + var events = streamInferOnMockService(modelId, TaskType.COMPLETION, input, VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER); var expectedResponses = Stream.concat( input.stream().map(s -> s.toUpperCase(Locale.ROOT)).map(str -> "{\"completion\":[{\"delta\":\"" + str + "\"}]}"), @@ -342,7 +350,7 @@ public void testUnifiedCompletionInference() throws Exception { var input = IntStream.range(1, 2 + randomInt(8)).mapToObj(i -> randomAlphanumericOfLength(5)).toList(); try { - var events = unifiedCompletionInferOnMockService(modelId, TaskType.COMPLETION, input); + var events = unifiedCompletionInferOnMockService(modelId, TaskType.COMPLETION, input, VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER); var expectedResponses = expectedResultsIterator(input); assertThat(events.size(), equalTo((input.size() + 1) * 2)); events.forEach(event -> { diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java index ab3f466f3c11f..b993cf36cb875 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java @@ -17,6 +17,7 @@ import org.apache.http.nio.util.SimpleInputBuffer; import org.apache.http.protocol.HttpContext; import org.apache.http.util.EntityUtils; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.client.Request; import org.elasticsearch.client.RequestOptions; import org.elasticsearch.client.Response; @@ -43,6 +44,7 @@ import org.elasticsearch.rest.RestHandler; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; @@ -52,6 +54,7 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.Collection; +import java.util.Collections; import java.util.Deque; import java.util.Iterator; import java.util.List; @@ -96,6 +99,14 @@ protected Collection> nodePlugins() { } public static class StreamingPlugin extends Plugin implements ActionPlugin { + private final SetOnce threadPool = new SetOnce<>(); + + @Override + public Collection createComponents(PluginServices services) { + threadPool.set(services.threadPool()); + return Collections.emptyList(); + } + @Override public Collection getRestHandlers( Settings settings, @@ -122,7 +133,7 @@ public void handleRequest(RestRequest request, RestChannel channel, NodeClient c var publisher = new RandomPublisher(requestCount, withError); var inferenceServiceResults = new StreamingInferenceServiceResults(publisher); var inferenceResponse = new InferenceAction.Response(inferenceServiceResults, inferenceServiceResults.publisher()); - new ServerSentEventsRestActionListener(channel).onResponse(inferenceResponse); + new ServerSentEventsRestActionListener(channel, threadPool).onResponse(inferenceResponse); } }, new RestHandler() { @Override @@ -132,7 +143,7 @@ public List routes() { @Override public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) { - new ServerSentEventsRestActionListener(channel).onFailure(expectedException); + new ServerSentEventsRestActionListener(channel, threadPool).onFailure(expectedException); } }, new RestHandler() { @Override @@ -143,7 +154,7 @@ public List routes() { @Override public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) { var inferenceResponse = new InferenceAction.Response(new SingleInferenceServiceResults()); - new ServerSentEventsRestActionListener(channel).onResponse(inferenceResponse); + new ServerSentEventsRestActionListener(channel, threadPool).onResponse(inferenceResponse); } }); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index ffaaddeb7bce8..3e3960a8475be 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -42,6 +42,7 @@ import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.threadpool.ExecutorBuilder; import org.elasticsearch.threadpool.ScalingExecutorBuilder; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.action.XPackUsageFeatureAction; @@ -151,6 +152,9 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP private final SetOnce amazonBedrockFactory = new SetOnce<>(); private final SetOnce serviceComponents = new SetOnce<>(); private final SetOnce eisComponents = new SetOnce<>(); + // This is mainly so that the rest handlers can access the ThreadPool in a way that avoids potential null pointers from it + // not being initialized yet + private final SetOnce threadPoolSetOnce = new SetOnce<>(); private final SetOnce inferenceServiceRegistry = new SetOnce<>(); private final SetOnce shardBulkInferenceActionFilter = new SetOnce<>(); private List inferenceServiceExtensions; @@ -195,7 +199,7 @@ public List getRestHandlers( ) { var availableRestActions = List.of( new RestInferenceAction(), - new RestStreamInferenceAction(), + new RestStreamInferenceAction(threadPoolSetOnce), new RestGetInferenceModelAction(), new RestPutInferenceModelAction(), new RestUpdateInferenceModelAction(), @@ -203,7 +207,7 @@ public List getRestHandlers( new RestGetInferenceDiagnosticsAction() ); List conditionalRestActions = UnifiedCompletionFeature.UNIFIED_COMPLETION_FEATURE_FLAG.isEnabled() - ? List.of(new RestUnifiedCompletionInferenceAction()) + ? List.of(new RestUnifiedCompletionInferenceAction(threadPoolSetOnce)) : List.of(); return Stream.concat(availableRestActions.stream(), conditionalRestActions.stream()).toList(); @@ -214,6 +218,7 @@ public Collection createComponents(PluginServices services) { var throttlerManager = new ThrottlerManager(settings, services.threadPool(), services.clusterService()); var truncator = new Truncator(settings, services.clusterService()); serviceComponents.set(new ServiceComponents(services.threadPool(), throttlerManager, settings, truncator)); + threadPoolSetOnce.set(services.threadPool()); var httpClientManager = HttpClientManager.create(settings, services.threadPool(), services.clusterService(), throttlerManager); var httpRequestSenderFactory = new HttpRequestSender.Factory(serviceComponents.get(), httpClientManager, services.clusterService()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java index 875c288da52bd..881af435b29b6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceAction.java @@ -7,13 +7,16 @@ package org.elasticsearch.xpack.inference.rest; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionListener; import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import java.util.List; +import java.util.Objects; import static org.elasticsearch.rest.RestRequest.Method.POST; import static org.elasticsearch.xpack.inference.rest.Paths.STREAM_INFERENCE_ID_PATH; @@ -21,6 +24,13 @@ @ServerlessScope(Scope.PUBLIC) public class RestStreamInferenceAction extends BaseInferenceAction { + private final SetOnce threadPool; + + public RestStreamInferenceAction(SetOnce threadPool) { + super(); + this.threadPool = Objects.requireNonNull(threadPool); + } + @Override public String getName() { return "stream_inference_action"; @@ -38,6 +48,6 @@ protected InferenceAction.Request prepareInferenceRequest(InferenceAction.Reques @Override protected ActionListener listener(RestChannel channel) { - return new ServerSentEventsRestActionListener(channel); + return new ServerSentEventsRestActionListener(channel, threadPool); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java index 5c71b560a6b9d..51f1bc48c8306 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java @@ -7,15 +7,18 @@ package org.elasticsearch.xpack.inference.rest; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; import java.io.IOException; import java.util.List; +import java.util.Objects; import static org.elasticsearch.rest.RestRequest.Method.POST; import static org.elasticsearch.xpack.inference.rest.Paths.UNIFIED_INFERENCE_ID_PATH; @@ -23,6 +26,13 @@ @ServerlessScope(Scope.PUBLIC) public class RestUnifiedCompletionInferenceAction extends BaseRestHandler { + private final SetOnce threadPool; + + public RestUnifiedCompletionInferenceAction(SetOnce threadPool) { + super(); + this.threadPool = Objects.requireNonNull(threadPool); + } + @Override public String getName() { return "unified_inference_action"; @@ -44,6 +54,10 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient request = UnifiedCompletionAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), inferTimeout, parser); } - return channel -> client.execute(UnifiedCompletionAction.INSTANCE, request, new ServerSentEventsRestActionListener(channel)); + return channel -> client.execute( + UnifiedCompletionAction.INSTANCE, + request, + new ServerSentEventsRestActionListener(channel, threadPool) + ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java index 72a0d17da89bc..4e9f207d46372 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java @@ -10,9 +10,11 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ContextPreservingActionListener; import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.common.collect.Iterators; import org.elasticsearch.common.io.stream.BytesStream; @@ -30,6 +32,7 @@ import org.elasticsearch.rest.RestResponse; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.TaskCancelledException; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.action.InferenceAction; @@ -39,6 +42,7 @@ import java.nio.charset.StandardCharsets; import java.util.Iterator; import java.util.Map; +import java.util.Objects; import java.util.concurrent.Flow; import java.util.concurrent.atomic.AtomicBoolean; @@ -56,6 +60,7 @@ public class ServerSentEventsRestActionListener implements ActionListener threadPool; /** * A listener for the first part of the next entry to become available for transmission. @@ -67,13 +72,14 @@ public class ServerSentEventsRestActionListener implements ActionListener nextBodyPartListener; - public ServerSentEventsRestActionListener(RestChannel channel) { - this(channel, channel.request()); + public ServerSentEventsRestActionListener(RestChannel channel, SetOnce threadPool) { + this(channel, channel.request(), threadPool); } - public ServerSentEventsRestActionListener(RestChannel channel, ToXContent.Params params) { + public ServerSentEventsRestActionListener(RestChannel channel, ToXContent.Params params, SetOnce threadPool) { this.channel = channel; this.params = params; + this.threadPool = Objects.requireNonNull(threadPool); } @Override @@ -100,7 +106,7 @@ protected void ensureOpen() { } private void initializeStream(InferenceAction.Response response) { - nextBodyPartListener = ActionListener.wrap(bodyPart -> { + ActionListener chunkedResponseBodyActionListener = ActionListener.wrap(bodyPart -> { // this is the first response, so we need to send the RestResponse to open the stream // all subsequent bytes will be delivered through the nextBodyPartListener channel.sendResponse(RestResponse.chunked(RestStatus.OK, bodyPart, this::release)); @@ -116,6 +122,12 @@ private void initializeStream(InferenceAction.Response response) { ) ); }); + + nextBodyPartListener = ContextPreservingActionListener.wrapPreservingContext( + chunkedResponseBodyActionListener, + threadPool.get().getThreadContext() + ); + // subscribe will call onSubscribe, which requests the first chunk response.publisher().subscribe(subscriber); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java index b999e2c9b72f0..f67680ef6b625 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestStreamInferenceActionTests.java @@ -12,8 +12,11 @@ import org.elasticsearch.rest.RestRequest; import org.elasticsearch.test.rest.FakeRestRequest; import org.elasticsearch.test.rest.RestActionTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.junit.After; import org.junit.Before; import static org.elasticsearch.xpack.inference.rest.BaseInferenceActionTests.createResponse; @@ -22,10 +25,18 @@ import static org.hamcrest.Matchers.instanceOf; public class RestStreamInferenceActionTests extends RestActionTestCase { + private final SetOnce threadPool = new SetOnce<>(); @Before public void setUpAction() { - controller().registerHandler(new RestStreamInferenceAction()); + threadPool.set(new TestThreadPool(getTestName())); + controller().registerHandler(new RestStreamInferenceAction(threadPool)); + } + + @After + public void tearDownAction() { + terminate(threadPool.get()); + } public void testStreamIsTrue() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java index 5acfe67b175df..9dc23c890c14d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java @@ -17,8 +17,11 @@ import org.elasticsearch.rest.RestResponse; import org.elasticsearch.test.rest.FakeRestRequest; import org.elasticsearch.test.rest.RestActionTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; +import org.junit.After; import org.junit.Before; import static org.elasticsearch.xpack.inference.rest.BaseInferenceActionTests.createResponse; @@ -27,10 +30,17 @@ import static org.hamcrest.Matchers.instanceOf; public class RestUnifiedCompletionInferenceActionTests extends RestActionTestCase { + private final SetOnce threadPool = new SetOnce<>(); @Before public void setUpAction() { - controller().registerHandler(new RestUnifiedCompletionInferenceAction()); + threadPool.set(new TestThreadPool(getTestName())); + controller().registerHandler(new RestUnifiedCompletionInferenceAction(threadPool)); + } + + @After + public void tearDownAction() { + terminate(threadPool.get()); } public void testStreamIsTrue() {