From 76e5279a81e8db29491d157374ce4ad1fdacfd2c Mon Sep 17 00:00:00 2001 From: Grey Baker Date: Fri, 16 Aug 2024 11:44:17 -0400 Subject: [PATCH] Allow structured outputs via function calling --- api_integration_test.go | 76 +++++++++++++++++++++++++++++++++++++++++ chat.go | 1 + chat_test.go | 26 ++++++++++++++ 3 files changed, 103 insertions(+) diff --git a/api_integration_test.go b/api_integration_test.go index 3084268e..57f7c40f 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -239,3 +239,79 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { } } } + +func TestChatCompletionStructuredOutputsFunctionCalling(t *testing.T) { + apiToken := os.Getenv("OPENAI_TOKEN") + if apiToken == "" { + t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") + } + + var err error + c := openai.NewClient(apiToken) + ctx := context.Background() + + resp, err := c.CreateChatCompletion( + ctx, + openai.ChatCompletionRequest{ + Model: openai.GPT4oMini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: "Please enter a string, and we will convert it into the following naming conventions:" + + "1. PascalCase: Each word starts with an uppercase letter, with no spaces or separators." + + "2. CamelCase: The first word starts with a lowercase letter, " + + "and subsequent words start with an uppercase letter, with no spaces or separators." + + "3. KebabCase: All letters are lowercase, with words separated by hyphens `-`." + + "4. SnakeCase: All letters are lowercase, with words separated by underscores `_`.", + }, + { + Role: openai.ChatMessageRoleUser, + Content: "Hello World", + }, + }, + Tools: []openai.Tool{ + { + Type: openai.ToolTypeFunction, + Function: &openai.FunctionDefinition{ + Name: "display_cases", + Strict: true, + Parameters: &jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "PascalCase": { + Type: jsonschema.String, + }, + "CamelCase": { + Type: jsonschema.String, + }, + "KebabCase": { + Type: jsonschema.String, + }, + "SnakeCase": { + Type: jsonschema.String, + }, + }, + Required: []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"}, + AdditionalProperties: false, + }, + }, + }, + }, + ToolChoice: openai.ToolChoice{ + Type: openai.ToolTypeFunction, + Function: openai.ToolFunction{ + Name: "display_cases", + }, + }, + }, + ) + checks.NoError(t, err, "CreateChatCompletion (use structured outputs response) returned error") + var result = make(map[string]string) + err = json.Unmarshal([]byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments), &result) + checks.NoError(t, err, "CreateChatCompletion (use structured outputs response) unmarshal error") + for _, key := range []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"} { + if _, ok := result[key]; !ok { + t.Errorf("key:%s does not exist.", key) + } + } +} diff --git a/chat.go b/chat.go index 826fd3bd..97c89a49 100644 --- a/chat.go +++ b/chat.go @@ -264,6 +264,7 @@ type ToolFunction struct { type FunctionDefinition struct { Name string `json:"name"` Description string `json:"description,omitempty"` + Strict bool `json:"strict,omitempty"` // Parameters is an object describing the function. // You can pass json.RawMessage to describe the schema, // or you can pass in a struct which serializes to the proper JSON schema. diff --git a/chat_test.go b/chat_test.go index 520bf5ca..37dc09d4 100644 --- a/chat_test.go +++ b/chat_test.go @@ -277,6 +277,32 @@ func TestChatCompletionsFunctions(t *testing.T) { }) checks.NoError(t, err, "CreateChatCompletion with functions error") }) + t.Run("StructuredOutputs", func(t *testing.T) { + type testMessage struct { + Count int `json:"count"` + Words []string `json:"words"` + } + msg := testMessage{ + Count: 2, + Words: []string{"hello", "world"}, + } + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo0613, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Functions: []openai.FunctionDefinition{{ + Name: "test", + Strict: true, + Parameters: &msg, + }}, + }) + checks.NoError(t, err, "CreateChatCompletion with functions error") + }) } func TestAzureChatCompletions(t *testing.T) {