diff --git a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java index e596be626b518..32ed68953041a 100644 --- a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java +++ b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java @@ -122,7 +122,12 @@ public record Message(Content content, String role, @Nullable String name, @Null ); static { - PARSER.declareField(constructorArg(), (p, c) -> parseContent(p), new ParseField("content"), ObjectParser.ValueType.VALUE_ARRAY); + PARSER.declareField( + optionalConstructorArg(), + (p, c) -> parseContent(p), + new ParseField("content"), + ObjectParser.ValueType.VALUE_ARRAY + ); PARSER.declareString(constructorArg(), new ParseField("role")); PARSER.declareString(optionalConstructorArg(), new ParseField("name")); PARSER.declareString(optionalConstructorArg(), new ParseField("tool_call_id")); @@ -143,7 +148,7 @@ private static Content parseContent(XContentParser parser) throws IOException { public Message(StreamInput in) throws IOException { this( - in.readNamedWriteable(Content.class), + in.readOptionalNamedWriteable(Content.class), in.readString(), in.readOptionalString(), in.readOptionalString(), @@ -153,7 +158,7 @@ public Message(StreamInput in) throws IOException { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeNamedWriteable(content); + out.writeOptionalNamedWriteable(content); out.writeString(role); out.writeOptionalString(name); out.writeOptionalString(toolCallId); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java index 5b7b274f2351b..67818f0958a12 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java @@ -66,7 +66,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws for (UnifiedCompletionRequest.Message message : unifiedRequest.messages()) { builder.startObject(); { - if (message.content() instanceof UnifiedCompletionRequest.ContentString contentString) { + if (message.content() == null) { + // content is optional + } else if (message.content() instanceof UnifiedCompletionRequest.ContentString contentString) { builder.field(CONTENT_FIELD, contentString.content()); } else if (message.content() instanceof UnifiedCompletionRequest.ContentObjects contentObjects) { builder.startArray(CONTENT_FIELD); @@ -77,10 +79,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.endObject(); } builder.endArray(); - } else { - throw new IllegalArgumentException( - Strings.format("Unsupported message.content class received: %s", message.content().getClass().getSimpleName()) - ); } builder.field(ROLE_FIELD, message.role()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java index f945c154ea234..2037c77a3cf2a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java @@ -702,122 +702,62 @@ public void testSerializationWithBooleanFields() throws IOException { assertJsonEquals(expectedJsonFalse, jsonStringFalse); } - // 9. Serialization with Missing Required Fields - // Test with missing required fields to ensure appropriate exceptions are thrown. - public void testSerializationWithMissingRequiredFields() { - // Create a message with missing content (required field) + // 9. a test without the content field to show that the content field is optional + public void testSerializationWithoutContentField() throws IOException { UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( - null, // missing content - OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, - null, null, - null - ); - var messageList = new ArrayList(); - messageList.add(message); - // Create the unified request - UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( - messageList, - null, // model - null, // maxCompletionTokens - null, // stop - null, // temperature - null, // toolChoice - null, // tools - null // topP - ); - - // Create the unified chat input - UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - - OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); - - // Create the entity - OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); - - // Attempt to serialize to XContent and expect an exception - try { - XContentBuilder builder = JsonXContent.contentBuilder(); - entity.toXContent(builder, ToXContent.EMPTY_PARAMS); - fail("Expected an exception due to missing required fields"); - } catch (NullPointerException | IOException e) { - // Expected exception - } - } - - // 10. Serialization with Mixed Valid and Invalid Data - // Test with a mix of valid and invalid data to ensure the serializer handles it gracefully. - public void testSerializationWithMixedValidAndInvalidData() throws IOException { - // Create a valid message - UnifiedCompletionRequest.Message validMessage = new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString("Valid content"), - OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, - "validName", - "validToolCallId", - Collections.singletonList( - new UnifiedCompletionRequest.ToolCall( - "validId", - new UnifiedCompletionRequest.ToolCall.FunctionField("validArguments", "validFunctionName"), - "validType" - ) - ) - ); - - // Create an invalid message with null content - UnifiedCompletionRequest.Message invalidMessage = new UnifiedCompletionRequest.Message( - null, // invalid content - OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, - "invalidName", - "invalidToolCallId", + "assistant", + "name\nwith\nnewlines", + "tool_call_id\twith\ttabs", Collections.singletonList( new UnifiedCompletionRequest.ToolCall( - "invalidId", - new UnifiedCompletionRequest.ToolCall.FunctionField("invalidArguments", "invalidFunctionName"), - "invalidType" + "id\\with\\backslashes", + new UnifiedCompletionRequest.ToolCall.FunctionField("arguments\"with\"quotes", "function_name/with/slashes"), + "type" ) ) ); var messageList = new ArrayList(); - messageList.add(validMessage); - messageList.add(invalidMessage); - // Create the unified request with both valid and invalid messages - UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( - messageList, - "model-name", - 100L, // maxCompletionTokens - Collections.singletonList("stop"), - 0.9f, // temperature - new UnifiedCompletionRequest.ToolChoiceString("tool_choice"), - Collections.singletonList( - new UnifiedCompletionRequest.Tool( - "type", - new UnifiedCompletionRequest.Tool.FunctionField( - "Fetches the weather in the given location", - "get_weather", - createParameters(), - true - ) - ) - ), - 0.8f // topP - ); + messageList.add(message); + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest(messageList, null, null, null, null, null, null, null); - // Create the unified chat input UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + OpenAiChatCompletionModel model = createChatCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", null); - OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); - - // Create the entity OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); - // Serialize to XContent and verify - try { - XContentBuilder builder = JsonXContent.contentBuilder(); - entity.toXContent(builder, ToXContent.EMPTY_PARAMS); - fail("Expected an exception due to invalid data"); - } catch (NullPointerException | IOException e) { - // Expected exception - } + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "role": "assistant", + "name": "name\\nwith\\nnewlines", + "tool_call_id": "tool_call_id\\twith\\ttabs", + "tool_calls": [ + { + "id": "id\\\\with\\\\backslashes", + "function": { + "arguments": "arguments\\"with\\"quotes", + "name": "function_name/with/slashes" + }, + "type": "type" + } + ] + } + ], + "model": "test-endpoint", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); } public static Map createParameters() {