diff --git a/api_integration_test.go b/api_integration_test.go index e0dea497..8968fd75 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -250,7 +250,17 @@ func TestChatCompletionResponseFormat_JSONSchemaRaw(t *testing.T) { c := openai.NewClient(apiToken) ctx := context.Background() - schema := []byte(`{"type":"object","properties":{"CamelCase":{"type":"string"},"KebabCase":{"type":"string"},"PascalCase":{"type":"string"},"SnakeCase":{"type":"string"}},"required":["PascalCase","CamelCase","KebabCase","SnakeCase"],"additionalProperties":false}`) + schema := []byte(`{ + "type": "object", + "properties": { + "CamelCase": {"type": "string"}, + "KebabCase": {"type": "string"}, + "PascalCase": {"type": "string"}, + "SnakeCase": {"type": "string"} + }, + "required": ["PascalCase", "CamelCase", "KebabCase", "SnakeCase"], + "additionalProperties": false + }`) resp, err := c.CreateChatCompletion( ctx, diff --git a/chat.go b/chat.go index 5c5a6a72..fe14958f 100644 --- a/chat.go +++ b/chat.go @@ -195,7 +195,7 @@ type ChatCompletionResponseFormatJSONSchema struct { } func (c *ChatCompletionResponseFormatJSONSchema) MarshalJSON() ([]byte, error) { - type Alias ChatCompletionResponseFormatJSONSchema + type Alias ChatCompletionResponseFormatJSONSchema // prevent recursive marshalling var data struct { *Alias Schema interface{} `json:"schema,omitempty"` @@ -203,13 +203,14 @@ func (c *ChatCompletionResponseFormatJSONSchema) MarshalJSON() ([]byte, error) { data.Alias = (*Alias)(c) - data.Schema = c.Schema if c.SchemaRaw != nil { var rawSchema interface{} if err := json.Unmarshal(*c.SchemaRaw, &rawSchema); err != nil { return nil, err } data.Schema = rawSchema + } else { + data.Schema = c.Schema } return json.Marshal(data) diff --git a/chat_test.go b/chat_test.go index e73418f3..fc1ee568 100644 --- a/chat_test.go +++ b/chat_test.go @@ -528,7 +528,7 @@ func TestFinishReason(t *testing.T) { } } -func TestChatCompletionResponseFormat_JSONSchema_MarshalJSON(t *testing.T) { +func TestChatCompletionResponseFormatJSONSchemaMarshalJSON(t *testing.T) { tests := []struct { name string input openai.ChatCompletionResponseFormatJSONSchema @@ -542,7 +542,7 @@ func TestChatCompletionResponseFormat_JSONSchema_MarshalJSON(t *testing.T) { SchemaRaw: nil, Strict: false, }, - expected: `{"name":"TestName","strict":false}`, + expected: `{"name":"TestName","strict":false,"schema":{}}`, wantErr: false, }, { @@ -552,7 +552,7 @@ func TestChatCompletionResponseFormat_JSONSchema_MarshalJSON(t *testing.T) { SchemaRaw: func() *[]byte { b := []byte(`{"key":"value"}`); return &b }(), Strict: true, }, - expected: `{"name":"TestName","schema":{"key":"value"},"strict":true}`, + expected: `{"name":"TestName","strict":true,"schema":{"key":"value"}}`, wantErr: false, }, { @@ -570,12 +570,21 @@ func TestChatCompletionResponseFormat_JSONSchema_MarshalJSON(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.input.MarshalJSON() + if (err != nil) != tt.wantErr { t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) return } + + if tt.wantErr { + if len(got) != 0 { + t.Errorf("Expected empty output on error, got: %s", string(got)) + } + return + } + if string(got) != tt.expected { - t.Errorf("MarshalJSON() got = %v, expected %v", string(got), tt.expected) + t.Errorf("MarshalJSON() got = %s, expected %s", string(got), tt.expected) } }) }