From a1ed57c95b9f12e3a5727d378a05c7281368cb75 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Wed, 18 Dec 2024 10:06:12 -0500 Subject: [PATCH 1/4] Allow for null/empty content field --- .../org/elasticsearch/inference/UnifiedCompletionRequest.java | 2 +- .../openai/OpenAiUnifiedChatCompletionRequestEntity.java | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java index e596be626b518..71881b105c1d8 100644 --- a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java +++ b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java @@ -122,7 +122,7 @@ public record Message(Content content, String role, @Nullable String name, @Null ); static { - PARSER.declareField(constructorArg(), (p, c) -> parseContent(p), new ParseField("content"), ObjectParser.ValueType.VALUE_ARRAY); + PARSER.declareField(optionalConstructorArg(), (p, c) -> parseContent(p), new ParseField("content"), ObjectParser.ValueType.VALUE_ARRAY); PARSER.declareString(constructorArg(), new ParseField("role")); PARSER.declareString(optionalConstructorArg(), new ParseField("name")); PARSER.declareString(optionalConstructorArg(), new ParseField("tool_call_id")); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java index 50339bf851f7d..f28c1b3fe8a55 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java @@ -78,6 +78,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } builder.endArray(); } + case null -> { + // do nothing + } } builder.field(ROLE_FIELD, message.role()); From 44bb102e401daa53275ba794e3bf220086d6f723 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Wed, 18 Dec 2024 12:50:29 -0500 Subject: [PATCH 2/4] remove tests which checked for null content --- ...ifiedChatCompletionRequestEntityTests.java | 118 ------------------ 1 file changed, 118 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java index f945c154ea234..3fbe0986a46f1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java @@ -702,124 +702,6 @@ public void testSerializationWithBooleanFields() throws IOException { assertJsonEquals(expectedJsonFalse, jsonStringFalse); } - // 9. Serialization with Missing Required Fields - // Test with missing required fields to ensure appropriate exceptions are thrown. - public void testSerializationWithMissingRequiredFields() { - // Create a message with missing content (required field) - UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( - null, // missing content - OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, - null, - null, - null - ); - var messageList = new ArrayList(); - messageList.add(message); - // Create the unified request - UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( - messageList, - null, // model - null, // maxCompletionTokens - null, // stop - null, // temperature - null, // toolChoice - null, // tools - null // topP - ); - - // Create the unified chat input - UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - - OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); - - // Create the entity - OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); - - // Attempt to serialize to XContent and expect an exception - try { - XContentBuilder builder = JsonXContent.contentBuilder(); - entity.toXContent(builder, ToXContent.EMPTY_PARAMS); - fail("Expected an exception due to missing required fields"); - } catch (NullPointerException | IOException e) { - // Expected exception - } - } - - // 10. Serialization with Mixed Valid and Invalid Data - // Test with a mix of valid and invalid data to ensure the serializer handles it gracefully. - public void testSerializationWithMixedValidAndInvalidData() throws IOException { - // Create a valid message - UnifiedCompletionRequest.Message validMessage = new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString("Valid content"), - OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, - "validName", - "validToolCallId", - Collections.singletonList( - new UnifiedCompletionRequest.ToolCall( - "validId", - new UnifiedCompletionRequest.ToolCall.FunctionField("validArguments", "validFunctionName"), - "validType" - ) - ) - ); - - // Create an invalid message with null content - UnifiedCompletionRequest.Message invalidMessage = new UnifiedCompletionRequest.Message( - null, // invalid content - OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, - "invalidName", - "invalidToolCallId", - Collections.singletonList( - new UnifiedCompletionRequest.ToolCall( - "invalidId", - new UnifiedCompletionRequest.ToolCall.FunctionField("invalidArguments", "invalidFunctionName"), - "invalidType" - ) - ) - ); - var messageList = new ArrayList(); - messageList.add(validMessage); - messageList.add(invalidMessage); - // Create the unified request with both valid and invalid messages - UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( - messageList, - "model-name", - 100L, // maxCompletionTokens - Collections.singletonList("stop"), - 0.9f, // temperature - new UnifiedCompletionRequest.ToolChoiceString("tool_choice"), - Collections.singletonList( - new UnifiedCompletionRequest.Tool( - "type", - new UnifiedCompletionRequest.Tool.FunctionField( - "Fetches the weather in the given location", - "get_weather", - createParameters(), - true - ) - ) - ), - 0.8f // topP - ); - - // Create the unified chat input - UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - - OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); - - // Create the entity - OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); - - // Serialize to XContent and verify - try { - XContentBuilder builder = JsonXContent.contentBuilder(); - entity.toXContent(builder, ToXContent.EMPTY_PARAMS); - fail("Expected an exception due to invalid data"); - } catch (NullPointerException | IOException e) { - // Expected exception - } - } - public static Map createParameters() { Map parameters = new LinkedHashMap<>(); parameters.put("type", "object"); From 902352ac82dfa1d0b0a54ef7f0d86d6733691ba5 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Wed, 18 Dec 2024 17:59:04 +0000 Subject: [PATCH 3/4] [CI] Auto commit changes from spotless --- .../elasticsearch/inference/UnifiedCompletionRequest.java | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java index 71881b105c1d8..fd94a47edccd7 100644 --- a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java +++ b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java @@ -122,7 +122,12 @@ public record Message(Content content, String role, @Nullable String name, @Null ); static { - PARSER.declareField(optionalConstructorArg(), (p, c) -> parseContent(p), new ParseField("content"), ObjectParser.ValueType.VALUE_ARRAY); + PARSER.declareField( + optionalConstructorArg(), + (p, c) -> parseContent(p), + new ParseField("content"), + ObjectParser.ValueType.VALUE_ARRAY + ); PARSER.declareString(constructorArg(), new ParseField("role")); PARSER.declareString(optionalConstructorArg(), new ParseField("name")); PARSER.declareString(optionalConstructorArg(), new ParseField("tool_call_id")); From 6ce7ac6c7c4695e105a9d6266ddc34e2f2bf9739 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Thu, 19 Dec 2024 12:50:24 -0500 Subject: [PATCH 4/4] Improvements from review --- .../inference/UnifiedCompletionRequest.java | 4 +- ...ifiedChatCompletionRequestEntityTests.java | 58 +++++++++++++++++++ 2 files changed, 60 insertions(+), 2 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java index fd94a47edccd7..32ed68953041a 100644 --- a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java +++ b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java @@ -148,7 +148,7 @@ private static Content parseContent(XContentParser parser) throws IOException { public Message(StreamInput in) throws IOException { this( - in.readNamedWriteable(Content.class), + in.readOptionalNamedWriteable(Content.class), in.readString(), in.readOptionalString(), in.readOptionalString(), @@ -158,7 +158,7 @@ public Message(StreamInput in) throws IOException { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeNamedWriteable(content); + out.writeOptionalNamedWriteable(content); out.writeString(role); out.writeOptionalString(name); out.writeOptionalString(toolCallId); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java index 3fbe0986a46f1..2037c77a3cf2a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java @@ -702,6 +702,64 @@ public void testSerializationWithBooleanFields() throws IOException { assertJsonEquals(expectedJsonFalse, jsonStringFalse); } + // 9. a test without the content field to show that the content field is optional + public void testSerializationWithoutContentField() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + null, + "assistant", + "name\nwith\nnewlines", + "tool_call_id\twith\ttabs", + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "id\\with\\backslashes", + new UnifiedCompletionRequest.ToolCall.FunctionField("arguments\"with\"quotes", "function_name/with/slashes"), + "type" + ) + ) + ); + var messageList = new ArrayList(); + messageList.add(message); + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest(messageList, null, null, null, null, null, null, null); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + OpenAiChatCompletionModel model = createChatCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", null); + + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "role": "assistant", + "name": "name\\nwith\\nnewlines", + "tool_call_id": "tool_call_id\\twith\\ttabs", + "tool_calls": [ + { + "id": "id\\\\with\\\\backslashes", + "function": { + "arguments": "arguments\\"with\\"quotes", + "name": "function_name/with/slashes" + }, + "type": "type" + } + ] + } + ], + "model": "test-endpoint", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + public static Map createParameters() { Map parameters = new LinkedHashMap<>(); parameters.put("type", "object");