Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jaybcee committed Dec 23, 2024
1 parent 0feba86 commit 93cc995
Show file tree
Hide file tree
Showing 9 changed files with 63 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,19 @@
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
import org.elasticsearch.xpack.inference.telemetry.TraceContextAware;
import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler;

import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Objects;

public class ElasticInferenceServiceSparseEmbeddingsRequest implements ElasticInferenceServiceRequest, TraceContextAware {
public class ElasticInferenceServiceSparseEmbeddingsRequest implements ElasticInferenceServiceRequest {

private final URI uri;

private final ElasticInferenceServiceSparseEmbeddingsModel model;

private final Truncator.TruncationResult truncationResult;
private final Truncator truncator;

private final TraceContext traceContext;
private final TraceContextHandler traceContextHandler;

public ElasticInferenceServiceSparseEmbeddingsRequest(
Truncator truncator,
Expand All @@ -45,7 +42,7 @@ public ElasticInferenceServiceSparseEmbeddingsRequest(
this.truncationResult = truncationResult;
this.model = Objects.requireNonNull(model);
this.uri = model.uri();
this.traceContext = traceContext;
this.traceContextHandler = new TraceContextHandler(traceContext);
}

@Override
Expand All @@ -56,13 +53,16 @@ public HttpRequest createHttpRequest() {
ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
httpPost.setEntity(byteEntity);

propagateTraceContext(httpPost);

traceContextHandler.propagateTraceContext(httpPost);
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));

return new HttpRequest(httpPost, getInferenceEntityId());
}

/**
 * Exposes the trace context wrapped by this request's {@code TraceContextHandler}.
 *
 * @return the {@link TraceContext} handed to the handler at construction time; the handler's
 *         own propagation logic checks it for {@code null}, so callers should presumably be
 *         prepared for a {@code null} result as well
 */
public TraceContext getTraceContext() {
    final TraceContext currentContext = traceContextHandler.traceContext();
    return currentContext;
}

@Override
public String getInferenceEntityId() {
return model.getInferenceEntityId();
Expand All @@ -73,20 +73,15 @@ public URI getURI() {
return this.uri;
}

@Override
public TraceContext getTraceContext() {
return traceContext;
}

@Override
public Request truncate() {
var truncatedInput = truncator.truncate(truncationResult.input());

return new ElasticInferenceServiceSparseEmbeddingsRequest(truncator, truncatedInput, model, traceContext);
return new ElasticInferenceServiceSparseEmbeddingsRequest(truncator, truncatedInput, model, traceContextHandler.traceContext());
}

/**
 * Reports the truncation flags recorded in this request's truncation result.
 *
 * @return a fresh copy of the {@code truncated()} array, so that callers cannot modify the
 *         array held by the truncation result
 */
@Override
public boolean[] getTruncationInfo() {
    final boolean[] truncatedFlags = truncationResult.truncated();
    return truncatedFlags.clone();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,17 @@
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
import org.elasticsearch.xpack.inference.telemetry.TraceContextAware;
import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler;

import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Objects;

public class ElasticInferenceServiceUnifiedChatCompletionRequest implements TraceContextAware, Request {
public class ElasticInferenceServiceUnifiedChatCompletionRequest implements Request {

private final ElasticInferenceServiceCompletionModel model;
private final UnifiedChatInput unifiedChatInput;
private final URI uri;
private final TraceContext traceContext;
private final TraceContextHandler traceContextHandler;

public ElasticInferenceServiceUnifiedChatCompletionRequest(
UnifiedChatInput unifiedChatInput,
Expand All @@ -38,33 +37,28 @@ public ElasticInferenceServiceUnifiedChatCompletionRequest(
) {
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
this.model = Objects.requireNonNull(model);
this.uri = model.uri();
this.traceContext = traceContext;

this.traceContextHandler = new TraceContextHandler(traceContext);
}

@Override
public HttpRequest createHttpRequest() {
var httpPost = new HttpPost(uri);
var httpPost = new HttpPost(model.uri());
var requestEntity = Strings.toString(
new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId())
);

ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
httpPost.setEntity(byteEntity);

if (traceContext != null) {
propagateTraceContext(httpPost);
}

traceContextHandler.propagateTraceContext(httpPost);
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));

return new HttpRequest(httpPost, getInferenceEntityId());
}

@Override
public URI getURI() {
return uri;
return model.uri();
}

@Override
Expand All @@ -88,9 +82,4 @@ public String getInferenceEntityId() {
public boolean isStreaming() {
return true;
}

@Override
public TraceContext getTraceContext() {
return traceContext;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedRequest.maxCompletionTokens());
}

// Underlying providers except OpenAI only return 1 possible choice.
// Underlying providers expect OpenAI to only return 1 possible choice.
builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1);

if (unifiedRequest.stop() != null && unifiedRequest.stop().isEmpty() == false) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import org.elasticsearch.xpack.inference.external.action.elastic.ElasticInferenceServiceActionVisitor;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.rest.RestStatus;

import java.net.URI;
import java.net.URISyntaxException;
Expand Down Expand Up @@ -57,12 +59,7 @@ public ElasticInferenceServiceSparseEmbeddingsModel(
ElasticInferenceServiceSparseEmbeddingsServiceSettings serviceSettings
) {
super(model, serviceSettings);

try {
this.uri = createUri();
} catch (URISyntaxException e) {
throw new RuntimeException(e);
}
this.uri = createUri();
}

ElasticInferenceServiceSparseEmbeddingsModel(
Expand All @@ -80,12 +77,7 @@ public ElasticInferenceServiceSparseEmbeddingsModel(
serviceSettings,
elasticInferenceServiceComponents
);

try {
this.uri = createUri();
} catch (URISyntaxException e) {
throw new RuntimeException(e);
}
this.uri = createUri();
}

@Override
Expand All @@ -102,19 +94,29 @@ public URI uri() {
return uri;
}

private URI createUri() throws URISyntaxException {
private URI createUri() throws ElasticsearchStatusException {
String modelId = getServiceSettings().modelId();
String modelIdUriPath;

switch (modelId) {
case ElserModels.ELSER_V2_MODEL -> modelIdUriPath = "ELSERv2";
default -> throw new IllegalArgumentException(
String.format(Locale.ROOT, "Unsupported model for %s [%s]", ELASTIC_INFERENCE_SERVICE_IDENTIFIER, modelId)
default -> throw new ElasticsearchStatusException(
String.format(Locale.ROOT, "Unsupported model for %s [%s]", ELASTIC_INFERENCE_SERVICE_IDENTIFIER, modelId),
RestStatus.BAD_REQUEST
);
}

return new URI(
elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/sparse-text-embeddings/" + modelIdUriPath
);
try {
// TODO, consider transforming the base URL into a URI for better error handling.
return new URI(
elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/sparse-text-embeddings/" + modelIdUriPath
);
} catch (URISyntaxException e) {
throw new ElasticsearchStatusException(
"Failed to create URI for sparse embeddings service: " + e.getMessage(),
RestStatus.BAD_REQUEST,
e
);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.inference.services.elastic.completion;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.EmptySecretSettings;
import org.elasticsearch.inference.EmptyTaskSettings;
Expand All @@ -16,6 +17,7 @@
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel;
Expand Down Expand Up @@ -68,12 +70,8 @@ public ElasticInferenceServiceCompletionModel(
ElasticInferenceServiceCompletionServiceSettings serviceSettings
) {
super(model, serviceSettings);
this.uri = createUri();

try {
this.uri = createUri();
} catch (URISyntaxException e) {
throw new RuntimeException(e);
}
}

ElasticInferenceServiceCompletionModel(
Expand All @@ -92,11 +90,8 @@ public ElasticInferenceServiceCompletionModel(
elasticInferenceServiceComponents
);

try {
this.uri = createUri();
} catch (URISyntaxException e) {
throw new RuntimeException(e);
}
this.uri = createUri();

}

@Override
Expand All @@ -108,9 +103,18 @@ public URI uri() {
return uri;
}

private URI createUri() throws URISyntaxException {
return new URI(elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/chat/completions");
private URI createUri() throws ElasticsearchStatusException {
try {
// TODO, consider transforming the base URL into a URI for better error handling.
return new URI(elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/chat/completions");
} catch (URISyntaxException e) {
throw new ElasticsearchStatusException(
"Failed to create URI for completion service: " + e.getMessage(),
RestStatus.BAD_REQUEST,
e
);
}
}

// TODO create the Configuration class?
// TODO create/refactor the Configuration class to be extensible for different task types (e.g. completion, sparse embeddings).
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public class ElasticInferenceServiceCompletionServiceSettings extends FilteredXC
public static final String NAME = "elastic_inference_service_completion_service_settings";

// TODO what value do we put here?
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(1_000);
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(240L);

public static ElasticInferenceServiceCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
ValidationException validationException = new ValidationException();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.telemetry;

import org.apache.http.client.methods.HttpPost;
import org.elasticsearch.tasks.Task;

public interface TraceContextAware {
TraceContext getTraceContext();
public record TraceContextHandler(TraceContext traceContext) {

default void propagateTraceContext(HttpPost httpPost) {
TraceContext traceContext = this.getTraceContext();
public void propagateTraceContext(HttpPost httpPost) {
if (traceContext == null) {
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@
public class ElasticInferenceServiceUnifiedChatCompletionRequestEntityTests extends ESTestCase {

private static final String ROLE = "user";
private static final String USER = "a_user";

// TODO remove if EIS doesn't use the model and user fields
public void testModelUserFieldsSerialization() throws IOException {
UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message(
new UnifiedCompletionRequest.ContentString("Hello, world!"),
Expand All @@ -43,7 +41,7 @@ public void testModelUserFieldsSerialization() throws IOException {
var unifiedRequest = UnifiedCompletionRequest.of(messageList);

UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true);
OpenAiChatCompletionModel model = createChatCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", USER);
OpenAiChatCompletionModel model = createChatCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", null);

OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model);

Expand All @@ -64,8 +62,7 @@ public void testModelUserFieldsSerialization() throws IOException {
"stream": true,
"stream_options": {
"include_usage": true
},
"user": "a_user"
}
}
""";
assertJsonEquals(jsonString, expectedJson);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public void testFromMap() {
ConfigurationParseContext.REQUEST
);

assertThat(serviceSettings, is(new ElasticInferenceServiceCompletionServiceSettings(modelId, new RateLimitSettings(1000))));
assertThat(serviceSettings, is(new ElasticInferenceServiceCompletionServiceSettings(modelId, new RateLimitSettings(240L))));
}

public void testFromMap_MissingModelId_ThrowsException() {
Expand Down

0 comments on commit 93cc995

Please sign in to comment.