Skip to content

Commit

Permalink
[Inference API] Add unified api for chat completions (elastic#117589)
Browse files Browse the repository at this point in the history
* Adding some shell classes

* modeling the request objects

* Writeable changes to schema

* Working parsing tests

* Creating a new action

* Add outbound request writing (WIP)

* Improvements to request serialization

* Adding separate transport classes

* separate out unified request and combine inputs

* Reworking unified inputs

* Adding unsupported operation calls

* Fixing parsing logic

* get the build working

* Update docs/changelog/117589.yaml

* Fixing injection issue

* Allowing model to be overridden but not working yet

* Fixing issues

* Switch field name for tool

* Add suport for toolCalls and refusal in streaming completion

* Working tool call response

* Separate unified and legacy code paths

* Updated the parser, but there are some class cast exceptions to fix

* Refactoring tests and request entities

* Parse response from OpenAI

* Removing unused request classes

* precommit

* Adding tests for UnifiedCompletionAction Request

* Refactoring stop to be a list of strings

* Testing for OpenAI response parsing

* Refactoring transport action tests to test unified validation code

* Fixing various tests

* Fixing license header

* Reformat streaming results

* Finalize response format

* remove debug logs

* remove changes for debugging

* Task type and base inference action tests

* Adding openai service tests

* Adding model tests

* tests for StreamingUnifiedChatCompletionResultsTests toXContentChunked

* Fixing change log and removing commented out code

* Switch usage to accept null

* Adding test for TestStreamingCompletionServiceExtension

* Avoid serializing empty lists + request entity tests

* Register named writeables from UnifiedCompletionRequest

* Removing commented code

* Clean up and add more of an explination

* remove duplicate test

* remove old todos

* Refactoring some duplication

* Adding javadoc

* Addressing feedback

---------

Co-authored-by: Jonathan Buttner <[email protected]>
Co-authored-by: Jonathan Buttner <[email protected]>
(cherry picked from commit 467fdb8)

# Conflicts:
#	x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java
#	x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java
  • Loading branch information
maxhniebergall committed Dec 16, 2024
1 parent 2d10bfb commit f6af65f
Show file tree
Hide file tree
Showing 105 changed files with 5,644 additions and 871 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/117589.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 117589
summary: "Add Inference Unified API for chat completions for OpenAI"
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.common.collect.Iterators;
import org.elasticsearch.xcontent.ToXContent;

import java.util.Collections;
import java.util.Iterator;

public enum ChunkedToXContentHelper {
Expand Down Expand Up @@ -53,6 +54,14 @@ public static Iterator<ToXContent> field(String name, String value) {
return Iterators.single(((builder, params) -> builder.field(name, value)));
}

public static Iterator<ToXContent> optionalField(String name, String value) {
if (value == null) {
return Collections.emptyIterator();
} else {
return field(name, value);
}
}

/**
* Creates an Iterator of a single ToXContent object that serializes the given object as a single chunk. Just wraps {@link
* Iterators#single}, but still useful because it avoids any type ambiguity.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,23 @@ void infer(
);

/**
* Perform completion inference on the model using the unified schema.
*
* @param model The model
* @param request Parameters for the request
* @param timeout The timeout for the request
* @param listener Inference result listener
*/
void unifiedCompletionInfer(
Model model,
UnifiedCompletionRequest request,
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
);

/**
* Chunk long text.
*
* @param model The model
* @param query Inference query, mainly for re-ranking
* @param input Inference input
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ public static TaskType fromString(String name) {
}

public static TaskType fromStringOrStatusException(String name) {
if (name == null) {
throw new ElasticsearchStatusException("Task type must not be null", RestStatus.BAD_REQUEST);
}

try {
TaskType taskType = TaskType.fromString(name);
return Objects.requireNonNull(taskType);
Expand Down
Loading

0 comments on commit f6af65f

Please sign in to comment.