Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[8.x] [Inference API] Make message content optional in unified API (#118998) #119226

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,12 @@ public record Message(Content content, String role, @Nullable String name, @Null
);

static {
PARSER.declareField(constructorArg(), (p, c) -> parseContent(p), new ParseField("content"), ObjectParser.ValueType.VALUE_ARRAY);
PARSER.declareField(
optionalConstructorArg(),
(p, c) -> parseContent(p),
new ParseField("content"),
ObjectParser.ValueType.VALUE_ARRAY
);
PARSER.declareString(constructorArg(), new ParseField("role"));
PARSER.declareString(optionalConstructorArg(), new ParseField("name"));
PARSER.declareString(optionalConstructorArg(), new ParseField("tool_call_id"));
Expand All @@ -143,7 +148,7 @@ private static Content parseContent(XContentParser parser) throws IOException {

public Message(StreamInput in) throws IOException {
this(
in.readNamedWriteable(Content.class),
in.readOptionalNamedWriteable(Content.class),
in.readString(),
in.readOptionalString(),
in.readOptionalString(),
Expand All @@ -153,7 +158,7 @@ public Message(StreamInput in) throws IOException {

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeNamedWriteable(content);
out.writeOptionalNamedWriteable(content);
out.writeString(role);
out.writeOptionalString(name);
out.writeOptionalString(toolCallId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
for (UnifiedCompletionRequest.Message message : unifiedRequest.messages()) {
builder.startObject();
{
if (message.content() instanceof UnifiedCompletionRequest.ContentString contentString) {
if (message.content() == null) {
// content is optional
} else if (message.content() instanceof UnifiedCompletionRequest.ContentString contentString) {
builder.field(CONTENT_FIELD, contentString.content());
} else if (message.content() instanceof UnifiedCompletionRequest.ContentObjects contentObjects) {
builder.startArray(CONTENT_FIELD);
Expand All @@ -77,10 +79,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.endObject();
}
builder.endArray();
} else {
throw new IllegalArgumentException(
Strings.format("Unsupported message.content class received: %s", message.content().getClass().getSimpleName())
);
}

builder.field(ROLE_FIELD, message.role());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -702,122 +702,62 @@ public void testSerializationWithBooleanFields() throws IOException {
assertJsonEquals(expectedJsonFalse, jsonStringFalse);
}

// 9. Serialization with Missing Required Fields
// Test with missing required fields to ensure appropriate exceptions are thrown.
public void testSerializationWithMissingRequiredFields() {
// Create a message with missing content (required field)
// 9. a test without the content field to show that the content field is optional
public void testSerializationWithoutContentField() throws IOException {
UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message(
null, // missing content
OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD,
null,
null,
null
);
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
messageList.add(message);
// Create the unified request
UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest(
messageList,
null, // model
null, // maxCompletionTokens
null, // stop
null, // temperature
null, // toolChoice
null, // tools
null // topP
);

// Create the unified chat input
UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true);

OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null);

// Create the entity
OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model);

// Attempt to serialize to XContent and expect an exception
try {
XContentBuilder builder = JsonXContent.contentBuilder();
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
fail("Expected an exception due to missing required fields");
} catch (NullPointerException | IOException e) {
// Expected exception
}
}

// 10. Serialization with Mixed Valid and Invalid Data
// Test with a mix of valid and invalid data to ensure the serializer handles it gracefully.
public void testSerializationWithMixedValidAndInvalidData() throws IOException {
// Create a valid message
UnifiedCompletionRequest.Message validMessage = new UnifiedCompletionRequest.Message(
new UnifiedCompletionRequest.ContentString("Valid content"),
OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD,
"validName",
"validToolCallId",
Collections.singletonList(
new UnifiedCompletionRequest.ToolCall(
"validId",
new UnifiedCompletionRequest.ToolCall.FunctionField("validArguments", "validFunctionName"),
"validType"
)
)
);

// Create an invalid message with null content
UnifiedCompletionRequest.Message invalidMessage = new UnifiedCompletionRequest.Message(
null, // invalid content
OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD,
"invalidName",
"invalidToolCallId",
"assistant",
"name\nwith\nnewlines",
"tool_call_id\twith\ttabs",
Collections.singletonList(
new UnifiedCompletionRequest.ToolCall(
"invalidId",
new UnifiedCompletionRequest.ToolCall.FunctionField("invalidArguments", "invalidFunctionName"),
"invalidType"
"id\\with\\backslashes",
new UnifiedCompletionRequest.ToolCall.FunctionField("arguments\"with\"quotes", "function_name/with/slashes"),
"type"
)
)
);
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
messageList.add(validMessage);
messageList.add(invalidMessage);
// Create the unified request with both valid and invalid messages
UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest(
messageList,
"model-name",
100L, // maxCompletionTokens
Collections.singletonList("stop"),
0.9f, // temperature
new UnifiedCompletionRequest.ToolChoiceString("tool_choice"),
Collections.singletonList(
new UnifiedCompletionRequest.Tool(
"type",
new UnifiedCompletionRequest.Tool.FunctionField(
"Fetches the weather in the given location",
"get_weather",
createParameters(),
true
)
)
),
0.8f // topP
);
messageList.add(message);
UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest(messageList, null, null, null, null, null, null, null);

// Create the unified chat input
UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true);
OpenAiChatCompletionModel model = createChatCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", null);

OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null);

// Create the entity
OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model);

// Serialize to XContent and verify
try {
XContentBuilder builder = JsonXContent.contentBuilder();
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
fail("Expected an exception due to invalid data");
} catch (NullPointerException | IOException e) {
// Expected exception
}
XContentBuilder builder = JsonXContent.contentBuilder();
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);

String jsonString = Strings.toString(builder);
String expectedJson = """
{
"messages": [
{
"role": "assistant",
"name": "name\\nwith\\nnewlines",
"tool_call_id": "tool_call_id\\twith\\ttabs",
"tool_calls": [
{
"id": "id\\\\with\\\\backslashes",
"function": {
"arguments": "arguments\\"with\\"quotes",
"name": "function_name/with/slashes"
},
"type": "type"
}
]
}
],
"model": "test-endpoint",
"n": 1,
"stream": true,
"stream_options": {
"include_usage": true
}
}
""";
assertJsonEquals(jsonString, expectedJson);
}

public static Map<String, Object> createParameters() {
Expand Down