[ML] Fix loss of context in the inference API for streaming APIs (#118999) (#119222)

* Adding context preserving fix

* Update docs/changelog/118999.yaml

* Update docs/changelog/118999.yaml

* Using a SetOnce and adding a test

* Updating the changelog

(cherry picked from commit 7ba3cb9)

# Conflicts:
#	x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java
#	x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java
#	x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
#	x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java
#	x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java
jonathan-buttner authored Jan 7, 2025
1 parent c65e727 commit 69bbd54
Showing 8 changed files with 94 additions and 16 deletions.
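
The change itself is small: ServerSentEventsRestActionListener now wraps the listener that delivers each chunked body part in a ContextPreservingActionListener built from the ThreadPool's ThreadContext, so the context captured when the request arrived is restored when chunks are flushed from another thread; the new InferenceCrudIT test checks that the X-elastic-product response header still reaches the client on a streamed completion. A minimal sketch of that wrapping pattern follows; it reuses only the Elasticsearch utilities visible in this diff, and the helper class and method names are illustrative rather than taken from the commit.

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.threadpool.ThreadPool;

// Illustrative helper (names not from the commit): capture the caller's thread context so a
// listener that is completed later, on another thread, still observes that context.
final class ContextPreservingSketch {

    static <T> ActionListener<T> preserveContext(ThreadPool threadPool, ActionListener<T> delegate) {
        ThreadContext threadContext = threadPool.getThreadContext();
        // wrapPreservingContext captures the current context and restores it around every
        // onResponse/onFailure call made on the returned listener.
        return ContextPreservingActionListener.wrapPreservingContext(delegate, threadContext);
    }
}
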
6 changes: 6 additions & 0 deletions docs/changelog/118999.yaml
@@ -0,0 +1,6 @@
pr: 118999
summary: Fix loss of context in the inference API for streaming APIs
area: Machine Learning
type: bug
issues:
- 119000

InferenceBaseRestTest.java
@@ -30,6 +30,7 @@
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.equalTo;
@@ -336,20 +337,34 @@ protected Map<String, Object> infer(String modelId, List<String> input) throws I
return inferInternal(endpoint, input, Map.of());
}

protected Deque<ServerSentEvent> streamInferOnMockService(String modelId, TaskType taskType, List<String> input) throws Exception {
protected Deque<ServerSentEvent> streamInferOnMockService(
String modelId,
TaskType taskType,
List<String> input,
@Nullable Consumer<Response> responseConsumerCallback
) throws Exception {
var endpoint = Strings.format("_inference/%s/%s/_stream", taskType, modelId);
return callAsync(endpoint, input);
return callAsync(endpoint, input, responseConsumerCallback);
}

private Deque<ServerSentEvent> callAsync(String endpoint, List<String> input) throws Exception {
var responseConsumer = new AsyncInferenceResponseConsumer();
private Deque<ServerSentEvent> callAsync(String endpoint, List<String> input, @Nullable Consumer<Response> responseConsumerCallback)
throws Exception {
var request = new Request("POST", endpoint);
request.setJsonEntity(jsonBody(input));

return execAsyncCall(request, responseConsumerCallback);
}

private Deque<ServerSentEvent> execAsyncCall(Request request, @Nullable Consumer<Response> responseConsumerCallback) throws Exception {
var responseConsumer = new AsyncInferenceResponseConsumer();
request.setOptions(RequestOptions.DEFAULT.toBuilder().setHttpAsyncResponseConsumerFactory(() -> responseConsumer).build());
var latch = new CountDownLatch(1);
client().performRequestAsync(request, new ResponseListener() {
@Override
public void onSuccess(Response response) {
if (responseConsumerCallback != null) {
responseConsumerCallback.accept(response);
}
latch.countDown();
}

InferenceCrudIT.java
@@ -10,6 +10,7 @@
package org.elasticsearch.xpack.inference;

import org.apache.http.util.EntityUtils;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseException;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.inference.TaskType;
@@ -19,6 +20,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.IntStream;
import java.util.stream.Stream;
@@ -28,9 +30,15 @@
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.equalToIgnoringCase;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;

public class InferenceCrudIT extends InferenceBaseRestTest {

private static final Consumer<Response> VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER = (r) -> assertThat(
r.getHeader("X-elastic-product"),
is("Elasticsearch")
);

@SuppressWarnings("unchecked")
public void testCRUD() throws IOException {
for (int i = 0; i < 5; i++) {
@@ -282,7 +290,7 @@ public void testUnsupportedStream() throws Exception {
assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get("task_type"));

try {
var events = streamInferOnMockService(modelId, TaskType.SPARSE_EMBEDDING, List.of(randomAlphaOfLength(10)));
var events = streamInferOnMockService(modelId, TaskType.SPARSE_EMBEDDING, List.of(randomAlphaOfLength(10)), null);
assertThat(events.size(), equalTo(2));
events.forEach(event -> {
switch (event.name()) {
@@ -309,7 +317,7 @@ public void testSupportedStream() throws Exception {

var input = IntStream.range(1, 2 + randomInt(8)).mapToObj(i -> randomAlphaOfLength(10)).toList();
try {
var events = streamInferOnMockService(modelId, TaskType.COMPLETION, input);
var events = streamInferOnMockService(modelId, TaskType.COMPLETION, input, VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER);

var expectedResponses = Stream.concat(
input.stream().map(String::toUpperCase).map(str -> "{\"completion\":[{\"delta\":\"" + str + "\"}]}"),

ServerSentEventsRestActionListenerTests.java
@@ -17,6 +17,7 @@
import org.apache.http.nio.util.SimpleInputBuffer;
import org.apache.http.protocol.HttpContext;
import org.apache.http.util.EntityUtils;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.Response;
@@ -43,6 +44,7 @@
import org.elasticsearch.rest.RestHandler;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
@@ -52,6 +54,7 @@
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.Iterator;
import java.util.List;
@@ -96,6 +99,14 @@ protected Collection<Class<? extends Plugin>> nodePlugins() {
}

public static class StreamingPlugin extends Plugin implements ActionPlugin {
private final SetOnce<ThreadPool> threadPool = new SetOnce<>();

@Override
public Collection<?> createComponents(PluginServices services) {
threadPool.set(services.threadPool());
return Collections.emptyList();
}

@Override
public Collection<RestHandler> getRestHandlers(
Settings settings,
@@ -122,7 +133,7 @@ public void handleRequest(RestRequest request, RestChannel channel, NodeClient c
var publisher = new RandomPublisher(requestCount, withError);
var inferenceServiceResults = new StreamingInferenceServiceResults(publisher);
var inferenceResponse = new InferenceAction.Response(inferenceServiceResults, inferenceServiceResults.publisher());
new ServerSentEventsRestActionListener(channel).onResponse(inferenceResponse);
new ServerSentEventsRestActionListener(channel, threadPool).onResponse(inferenceResponse);
}
}, new RestHandler() {
@Override
@@ -132,7 +143,7 @@ public List<Route> routes() {

@Override
public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) {
new ServerSentEventsRestActionListener(channel).onFailure(expectedException);
new ServerSentEventsRestActionListener(channel, threadPool).onFailure(expectedException);
}
}, new RestHandler() {
@Override
@@ -143,7 +154,7 @@ public List<Route> routes() {
@Override
public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) {
var inferenceResponse = new InferenceAction.Response(new SingleInferenceServiceResults());
new ServerSentEventsRestActionListener(channel).onResponse(inferenceResponse);
new ServerSentEventsRestActionListener(channel, threadPool).onResponse(inferenceResponse);
}
});
}

InferencePlugin.java
@@ -39,6 +39,7 @@
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.threadpool.ExecutorBuilder;
import org.elasticsearch.threadpool.ScalingExecutorBuilder;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.action.XPackUsageFeatureAction;
@@ -140,6 +141,9 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP
private final SetOnce<AmazonBedrockRequestSender.Factory> amazonBedrockFactory = new SetOnce<>();
private final SetOnce<ServiceComponents> serviceComponents = new SetOnce<>();
private final SetOnce<ElasticInferenceServiceComponents> eisComponents = new SetOnce<>();
// This is mainly so that the rest handlers can access the ThreadPool in a way that avoids potential null pointers from it
// not being initialized yet
private final SetOnce<ThreadPool> threadPoolSetOnce = new SetOnce<>();
private final SetOnce<InferenceServiceRegistry> inferenceServiceRegistry = new SetOnce<>();
private final SetOnce<ShardBulkInferenceActionFilter> shardBulkInferenceActionFilter = new SetOnce<>();
private List<InferenceServiceExtension> inferenceServiceExtensions;
@@ -176,7 +180,7 @@ public List<RestHandler> getRestHandlers(
) {
return List.of(
new RestInferenceAction(),
new RestStreamInferenceAction(),
new RestStreamInferenceAction(threadPoolSetOnce),
new RestGetInferenceModelAction(),
new RestPutInferenceModelAction(),
new RestUpdateInferenceModelAction(),
@@ -190,6 +194,7 @@ public Collection<?> createComponents(PluginServices services) {
var throttlerManager = new ThrottlerManager(settings, services.threadPool(), services.clusterService());
var truncator = new Truncator(settings, services.clusterService());
serviceComponents.set(new ServiceComponents(services.threadPool(), throttlerManager, settings, truncator));
threadPoolSetOnce.set(services.threadPool());

var httpClientManager = HttpClientManager.create(settings, services.threadPool(), services.clusterService(), throttlerManager);
var httpRequestSenderFactory = new HttpRequestSender.Factory(serviceComponents.get(), httpClientManager, services.clusterService());
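
As the new comment in InferencePlugin notes, the REST handlers need a way to reach the ThreadPool without risking a null reference before it is initialized, so they receive a SetOnce holder that createComponents fills in and that is only dereferenced at request time. A minimal sketch of that late-binding pattern follows; the class and method names are illustrative, not taken from the commit.

import org.apache.lucene.util.SetOnce;
import org.elasticsearch.threadpool.ThreadPool;

// Illustrative sketch (not from the commit): hand a not-yet-populated holder to an object
// created early, and fill it in once createComponents has run.
final class LateBoundThreadPoolSketch {

    private final SetOnce<ThreadPool> threadPoolHolder = new SetOnce<>();

    // Called from createComponents(...) once the node's ThreadPool exists.
    void initialize(ThreadPool threadPool) {
        threadPoolHolder.set(threadPool); // SetOnce rejects a second set() call
    }

    // Called per request, long after startup, so the holder is populated by then.
    void handleRequest() {
        ThreadPool threadPool = threadPoolHolder.get();
        if (threadPool == null) {
            throw new IllegalStateException("ThreadPool not initialized yet");
        }
        // ... use threadPool.getThreadContext() here ...
    }
}
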
RestStreamInferenceAction.java
@@ -7,20 +7,30 @@

package org.elasticsearch.xpack.inference.rest;

import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.Scope;
import org.elasticsearch.rest.ServerlessScope;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;

import java.util.List;
import java.util.Objects;

import static org.elasticsearch.rest.RestRequest.Method.POST;
import static org.elasticsearch.xpack.inference.rest.Paths.STREAM_INFERENCE_ID_PATH;
import static org.elasticsearch.xpack.inference.rest.Paths.STREAM_TASK_TYPE_INFERENCE_ID_PATH;

@ServerlessScope(Scope.PUBLIC)
public class RestStreamInferenceAction extends BaseInferenceAction {
private final SetOnce<ThreadPool> threadPool;

public RestStreamInferenceAction(SetOnce<ThreadPool> threadPool) {
super();
this.threadPool = Objects.requireNonNull(threadPool);
}

@Override
public String getName() {
return "stream_inference_action";
@@ -38,6 +48,6 @@ protected InferenceAction.Request prepareInferenceRequest(InferenceAction.Reques

@Override
protected ActionListener<InferenceAction.Response> listener(RestChannel channel) {
return new ServerSentEventsRestActionListener(channel);
return new ServerSentEventsRestActionListener(channel, threadPool);
}
}

ServerSentEventsRestActionListener.java
@@ -10,9 +10,11 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.collect.Iterators;
import org.elasticsearch.common.io.stream.BytesStream;
@@ -30,6 +32,7 @@
import org.elasticsearch.rest.RestResponse;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
@@ -39,6 +42,7 @@
import java.nio.charset.StandardCharsets;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicBoolean;

@@ -55,6 +59,7 @@ public class ServerSentEventsRestActionListener implements ActionListener<Infere
private final AtomicBoolean isLastPart = new AtomicBoolean(false);
private final RestChannel channel;
private final ToXContent.Params params;
private final SetOnce<ThreadPool> threadPool;

/**
* A listener for the first part of the next entry to become available for transmission.
@@ -66,13 +71,14 @@
*/
private ActionListener<ChunkedRestResponseBodyPart> nextBodyPartListener;

public ServerSentEventsRestActionListener(RestChannel channel) {
this(channel, channel.request());
public ServerSentEventsRestActionListener(RestChannel channel, SetOnce<ThreadPool> threadPool) {
this(channel, channel.request(), threadPool);
}

public ServerSentEventsRestActionListener(RestChannel channel, ToXContent.Params params) {
public ServerSentEventsRestActionListener(RestChannel channel, ToXContent.Params params, SetOnce<ThreadPool> threadPool) {
this.channel = channel;
this.params = params;
this.threadPool = Objects.requireNonNull(threadPool);
}

@Override
@@ -99,7 +105,7 @@ protected void ensureOpen() {
}

private void initializeStream(InferenceAction.Response response) {
nextBodyPartListener = ActionListener.wrap(bodyPart -> {
ActionListener<ChunkedRestResponseBodyPart> chunkedResponseBodyActionListener = ActionListener.wrap(bodyPart -> {
// this is the first response, so we need to send the RestResponse to open the stream
// all subsequent bytes will be delivered through the nextBodyPartListener
channel.sendResponse(RestResponse.chunked(RestStatus.OK, bodyPart, this::release));
@@ -115,6 +121,12 @@ private void initializeStream(InferenceAction.Response response) {
)
);
});

nextBodyPartListener = ContextPreservingActionListener.wrapPreservingContext(
chunkedResponseBodyActionListener,
threadPool.get().getThreadContext()
);

// subscribe will call onSubscribe, which requests the first chunk
response.publisher().subscribe(subscriber);
}
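
For background, the sketch below illustrates the ThreadContext capture-and-restore behaviour that ContextPreservingActionListener builds on: a response header is added, the context is captured, then stashed away (as if work continued on a fresh thread), and finally restored. It is a standalone example rather than code from the commit; the header value mirrors the one the new InferenceCrudIT assertion checks.

import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;

// Illustrative sketch of ThreadContext capture/restore semantics; not code from the commit.
final class ThreadContextRestoreSketch {

    public static void main(String[] args) {
        ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
        threadContext.addResponseHeader("X-elastic-product", "Elasticsearch");

        // Capture the current context so it can be re-applied later.
        var restorableContext = threadContext.newRestorableContext(true);

        // Stashing empties the context, like starting work on a fresh response thread.
        try (ThreadContext.StoredContext ignoredStash = threadContext.stashContext()) {
            // Here the response header is gone ...
            try (ThreadContext.StoredContext ignoredRestore = restorableContext.get()) {
                // ... and here the captured context, response header included, is back.
                assert threadContext.getResponseHeaders().containsKey("X-elastic-product");
            }
        }
    }
}
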
RestStreamInferenceActionTests.java
@@ -12,8 +12,11 @@
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.test.rest.FakeRestRequest;
import org.elasticsearch.test.rest.RestActionTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.junit.After;
import org.junit.Before;

import static org.elasticsearch.xpack.inference.rest.BaseInferenceActionTests.createResponse;
@@ -22,10 +25,18 @@
import static org.hamcrest.Matchers.instanceOf;

public class RestStreamInferenceActionTests extends RestActionTestCase {
private final SetOnce<ThreadPool> threadPool = new SetOnce<>();

@Before
public void setUpAction() {
controller().registerHandler(new RestStreamInferenceAction());
threadPool.set(new TestThreadPool(getTestName()));
controller().registerHandler(new RestStreamInferenceAction(threadPool));
}

@After
public void tearDownAction() {
terminate(threadPool.get());

}

public void testStreamIsTrue() {