Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[8.17] [ML] Fix loss of context in the inference API for streaming APIs (#118999) #119222

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/118999.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 118999
summary: Fix loss of context in the inference API for streaming APIs
area: Machine Learning
type: bug
issues:
- 119000
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.equalTo;
Expand Down Expand Up @@ -336,20 +337,34 @@ protected Map<String, Object> infer(String modelId, List<String> input) throws I
return inferInternal(endpoint, input, Map.of());
}

protected Deque<ServerSentEvent> streamInferOnMockService(String modelId, TaskType taskType, List<String> input) throws Exception {
protected Deque<ServerSentEvent> streamInferOnMockService(
String modelId,
TaskType taskType,
List<String> input,
@Nullable Consumer<Response> responseConsumerCallback
) throws Exception {
var endpoint = Strings.format("_inference/%s/%s/_stream", taskType, modelId);
return callAsync(endpoint, input);
return callAsync(endpoint, input, responseConsumerCallback);
}

private Deque<ServerSentEvent> callAsync(String endpoint, List<String> input) throws Exception {
var responseConsumer = new AsyncInferenceResponseConsumer();
private Deque<ServerSentEvent> callAsync(String endpoint, List<String> input, @Nullable Consumer<Response> responseConsumerCallback)
throws Exception {
var request = new Request("POST", endpoint);
request.setJsonEntity(jsonBody(input));

return execAsyncCall(request, responseConsumerCallback);
}

private Deque<ServerSentEvent> execAsyncCall(Request request, @Nullable Consumer<Response> responseConsumerCallback) throws Exception {
var responseConsumer = new AsyncInferenceResponseConsumer();
request.setOptions(RequestOptions.DEFAULT.toBuilder().setHttpAsyncResponseConsumerFactory(() -> responseConsumer).build());
var latch = new CountDownLatch(1);
client().performRequestAsync(request, new ResponseListener() {
@Override
public void onSuccess(Response response) {
if (responseConsumerCallback != null) {
responseConsumerCallback.accept(response);
}
latch.countDown();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
package org.elasticsearch.xpack.inference;

import org.apache.http.util.EntityUtils;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseException;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.inference.TaskType;
Expand All @@ -19,6 +20,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.IntStream;
import java.util.stream.Stream;
Expand All @@ -28,9 +30,15 @@
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.equalToIgnoringCase;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;

public class InferenceCrudIT extends InferenceBaseRestTest {

private static final Consumer<Response> VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER = (r) -> assertThat(
r.getHeader("X-elastic-product"),
is("Elasticsearch")
);

@SuppressWarnings("unchecked")
public void testCRUD() throws IOException {
for (int i = 0; i < 5; i++) {
Expand Down Expand Up @@ -282,7 +290,7 @@ public void testUnsupportedStream() throws Exception {
assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get("task_type"));

try {
var events = streamInferOnMockService(modelId, TaskType.SPARSE_EMBEDDING, List.of(randomAlphaOfLength(10)));
var events = streamInferOnMockService(modelId, TaskType.SPARSE_EMBEDDING, List.of(randomAlphaOfLength(10)), null);
assertThat(events.size(), equalTo(2));
events.forEach(event -> {
switch (event.name()) {
Expand All @@ -309,7 +317,7 @@ public void testSupportedStream() throws Exception {

var input = IntStream.range(1, 2 + randomInt(8)).mapToObj(i -> randomAlphaOfLength(10)).toList();
try {
var events = streamInferOnMockService(modelId, TaskType.COMPLETION, input);
var events = streamInferOnMockService(modelId, TaskType.COMPLETION, input, VALIDATE_ELASTIC_PRODUCT_HEADER_CONSUMER);

var expectedResponses = Stream.concat(
input.stream().map(String::toUpperCase).map(str -> "{\"completion\":[{\"delta\":\"" + str + "\"}]}"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.apache.http.nio.util.SimpleInputBuffer;
import org.apache.http.protocol.HttpContext;
import org.apache.http.util.EntityUtils;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.Response;
Expand All @@ -43,6 +44,7 @@
import org.elasticsearch.rest.RestHandler;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
Expand All @@ -52,6 +54,7 @@
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.Iterator;
import java.util.List;
Expand Down Expand Up @@ -96,6 +99,14 @@ protected Collection<Class<? extends Plugin>> nodePlugins() {
}

public static class StreamingPlugin extends Plugin implements ActionPlugin {
private final SetOnce<ThreadPool> threadPool = new SetOnce<>();

@Override
public Collection<?> createComponents(PluginServices services) {
threadPool.set(services.threadPool());
return Collections.emptyList();
}

@Override
public Collection<RestHandler> getRestHandlers(
Settings settings,
Expand All @@ -122,7 +133,7 @@ public void handleRequest(RestRequest request, RestChannel channel, NodeClient c
var publisher = new RandomPublisher(requestCount, withError);
var inferenceServiceResults = new StreamingInferenceServiceResults(publisher);
var inferenceResponse = new InferenceAction.Response(inferenceServiceResults, inferenceServiceResults.publisher());
new ServerSentEventsRestActionListener(channel).onResponse(inferenceResponse);
new ServerSentEventsRestActionListener(channel, threadPool).onResponse(inferenceResponse);
}
}, new RestHandler() {
@Override
Expand All @@ -132,7 +143,7 @@ public List<Route> routes() {

@Override
public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) {
new ServerSentEventsRestActionListener(channel).onFailure(expectedException);
new ServerSentEventsRestActionListener(channel, threadPool).onFailure(expectedException);
}
}, new RestHandler() {
@Override
Expand All @@ -143,7 +154,7 @@ public List<Route> routes() {
@Override
public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) {
var inferenceResponse = new InferenceAction.Response(new SingleInferenceServiceResults());
new ServerSentEventsRestActionListener(channel).onResponse(inferenceResponse);
new ServerSentEventsRestActionListener(channel, threadPool).onResponse(inferenceResponse);
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.threadpool.ExecutorBuilder;
import org.elasticsearch.threadpool.ScalingExecutorBuilder;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.action.XPackUsageFeatureAction;
Expand Down Expand Up @@ -140,6 +141,9 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP
private final SetOnce<AmazonBedrockRequestSender.Factory> amazonBedrockFactory = new SetOnce<>();
private final SetOnce<ServiceComponents> serviceComponents = new SetOnce<>();
private final SetOnce<ElasticInferenceServiceComponents> eisComponents = new SetOnce<>();
// This is mainly so that the rest handlers can access the ThreadPool in a way that avoids potential null pointers from it
// not being initialized yet
private final SetOnce<ThreadPool> threadPoolSetOnce = new SetOnce<>();
private final SetOnce<InferenceServiceRegistry> inferenceServiceRegistry = new SetOnce<>();
private final SetOnce<ShardBulkInferenceActionFilter> shardBulkInferenceActionFilter = new SetOnce<>();
private List<InferenceServiceExtension> inferenceServiceExtensions;
Expand Down Expand Up @@ -176,7 +180,7 @@ public List<RestHandler> getRestHandlers(
) {
return List.of(
new RestInferenceAction(),
new RestStreamInferenceAction(),
new RestStreamInferenceAction(threadPoolSetOnce),
new RestGetInferenceModelAction(),
new RestPutInferenceModelAction(),
new RestUpdateInferenceModelAction(),
Expand All @@ -190,6 +194,7 @@ public Collection<?> createComponents(PluginServices services) {
var throttlerManager = new ThrottlerManager(settings, services.threadPool(), services.clusterService());
var truncator = new Truncator(settings, services.clusterService());
serviceComponents.set(new ServiceComponents(services.threadPool(), throttlerManager, settings, truncator));
threadPoolSetOnce.set(services.threadPool());

var httpClientManager = HttpClientManager.create(settings, services.threadPool(), services.clusterService(), throttlerManager);
var httpRequestSenderFactory = new HttpRequestSender.Factory(serviceComponents.get(), httpClientManager, services.clusterService());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,30 @@

package org.elasticsearch.xpack.inference.rest;

import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.Scope;
import org.elasticsearch.rest.ServerlessScope;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;

import java.util.List;
import java.util.Objects;

import static org.elasticsearch.rest.RestRequest.Method.POST;
import static org.elasticsearch.xpack.inference.rest.Paths.STREAM_INFERENCE_ID_PATH;
import static org.elasticsearch.xpack.inference.rest.Paths.STREAM_TASK_TYPE_INFERENCE_ID_PATH;

@ServerlessScope(Scope.PUBLIC)
public class RestStreamInferenceAction extends BaseInferenceAction {
private final SetOnce<ThreadPool> threadPool;

public RestStreamInferenceAction(SetOnce<ThreadPool> threadPool) {
super();
this.threadPool = Objects.requireNonNull(threadPool);
}

@Override
public String getName() {
return "stream_inference_action";
Expand All @@ -38,6 +48,6 @@ protected InferenceAction.Request prepareInferenceRequest(InferenceAction.Reques

@Override
protected ActionListener<InferenceAction.Response> listener(RestChannel channel) {
return new ServerSentEventsRestActionListener(channel);
return new ServerSentEventsRestActionListener(channel, threadPool);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.collect.Iterators;
import org.elasticsearch.common.io.stream.BytesStream;
Expand All @@ -30,6 +32,7 @@
import org.elasticsearch.rest.RestResponse;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
Expand All @@ -39,6 +42,7 @@
import java.nio.charset.StandardCharsets;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicBoolean;

Expand All @@ -55,6 +59,7 @@ public class ServerSentEventsRestActionListener implements ActionListener<Infere
private final AtomicBoolean isLastPart = new AtomicBoolean(false);
private final RestChannel channel;
private final ToXContent.Params params;
private final SetOnce<ThreadPool> threadPool;

/**
* A listener for the first part of the next entry to become available for transmission.
Expand All @@ -66,13 +71,14 @@ public class ServerSentEventsRestActionListener implements ActionListener<Infere
*/
private ActionListener<ChunkedRestResponseBodyPart> nextBodyPartListener;

public ServerSentEventsRestActionListener(RestChannel channel) {
this(channel, channel.request());
public ServerSentEventsRestActionListener(RestChannel channel, SetOnce<ThreadPool> threadPool) {
this(channel, channel.request(), threadPool);
}

public ServerSentEventsRestActionListener(RestChannel channel, ToXContent.Params params) {
public ServerSentEventsRestActionListener(RestChannel channel, ToXContent.Params params, SetOnce<ThreadPool> threadPool) {
this.channel = channel;
this.params = params;
this.threadPool = Objects.requireNonNull(threadPool);
}

@Override
Expand All @@ -99,7 +105,7 @@ protected void ensureOpen() {
}

private void initializeStream(InferenceAction.Response response) {
nextBodyPartListener = ActionListener.wrap(bodyPart -> {
ActionListener<ChunkedRestResponseBodyPart> chunkedResponseBodyActionListener = ActionListener.wrap(bodyPart -> {
// this is the first response, so we need to send the RestResponse to open the stream
// all subsequent bytes will be delivered through the nextBodyPartListener
channel.sendResponse(RestResponse.chunked(RestStatus.OK, bodyPart, this::release));
Expand All @@ -115,6 +121,12 @@ private void initializeStream(InferenceAction.Response response) {
)
);
});

nextBodyPartListener = ContextPreservingActionListener.wrapPreservingContext(
chunkedResponseBodyActionListener,
threadPool.get().getThreadContext()
);

// subscribe will call onSubscribe, which requests the first chunk
response.publisher().subscribe(subscriber);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.test.rest.FakeRestRequest;
import org.elasticsearch.test.rest.RestActionTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.junit.After;
import org.junit.Before;

import static org.elasticsearch.xpack.inference.rest.BaseInferenceActionTests.createResponse;
Expand All @@ -22,10 +25,18 @@
import static org.hamcrest.Matchers.instanceOf;

public class RestStreamInferenceActionTests extends RestActionTestCase {
private final SetOnce<ThreadPool> threadPool = new SetOnce<>();

@Before
public void setUpAction() {
controller().registerHandler(new RestStreamInferenceAction());
threadPool.set(new TestThreadPool(getTestName()));
controller().registerHandler(new RestStreamInferenceAction(threadPool));
}

@After
public void tearDownAction() {
terminate(threadPool.get());

}

public void testStreamIsTrue() {
Expand Down