From 22ce52226344cfe06f8596214860b69e503a647d Mon Sep 17 00:00:00 2001 From: Robby Date: Tue, 27 Aug 2024 12:19:51 -0400 Subject: [PATCH] add openai strict (structured outputs) mode --- examples/auto_ticketer/main.go | 2 +- go.mod | 4 +- go.sum | 2 + pkg/instructor/mode_enum.go | 13 +-- pkg/instructor/openai_chat.go | 153 ++++++--------------------- pkg/instructor/openai_chat_stream.go | 25 +---- 6 files changed, 48 insertions(+), 151 deletions(-) diff --git a/examples/auto_ticketer/main.go b/examples/auto_ticketer/main.go index 0a1803a..50fb20c 100644 --- a/examples/auto_ticketer/main.go +++ b/examples/auto_ticketer/main.go @@ -47,7 +47,7 @@ func main() { client := instructor.FromOpenAI( openai.NewClient(os.Getenv("OPENAI_API_KEY")), - instructor.WithMode(instructor.ModeStructuredOutputs), + instructor.WithMode(instructor.ModeJSONStrict), instructor.WithMaxRetries(0), ) diff --git a/go.mod b/go.mod index 0c66d81..45c16c5 100644 --- a/go.mod +++ b/go.mod @@ -7,11 +7,9 @@ require ( github.com/go-playground/validator/v10 v10.21.0 github.com/invopop/jsonschema v0.12.0 github.com/liushuangls/go-anthropic/v2 v2.1.0 - github.com/sashabaranov/go-openai v1.28.1 + github.com/sashabaranov/go-openai v1.29.0 ) -replace github.com/sashabaranov/go-openai => ../go-openai - require ( github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect diff --git a/go.sum b/go.sum index cc5deb4..4c8050f 100644 --- a/go.sum +++ b/go.sum @@ -29,6 +29,8 @@ github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0 github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/sashabaranov/go-openai v1.29.0 h1:eBH6LSjtX4md5ImDCX8hNhHQvaRf22zujiERoQpsvLo= +github.com/sashabaranov/go-openai v1.29.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= diff --git a/pkg/instructor/mode_enum.go b/pkg/instructor/mode_enum.go index ebd9df5..1a776ea 100644 --- a/pkg/instructor/mode_enum.go +++ b/pkg/instructor/mode_enum.go @@ -3,10 +3,11 @@ package instructor type Mode = string const ( - ModeToolCall Mode = "tool_call" - ModeJSON Mode = "json_mode" - ModeStructuredOutputs Mode = "structured_outputs_mode" - ModeJSONSchema Mode = "json_schema_mode" - ModeMarkdownJSON Mode = "markdown_json_mode" - ModeDefault Mode = ModeJSONSchema + ModeToolCall Mode = "tool_call_mode" + ModeToolCallStrict Mode = "tool_call_strict_mode" + ModeJSON Mode = "json_mode" + ModeJSONStrict Mode = "json_strict_mode" + ModeJSONSchema Mode = "json_schema_mode" + ModeMarkdownJSON Mode = "markdown_json_mode" + ModeDefault Mode = ModeJSONSchema ) diff --git a/pkg/instructor/openai_chat.go b/pkg/instructor/openai_chat.go index f7c018d..0c32071 100644 --- a/pkg/instructor/openai_chat.go +++ b/pkg/instructor/openai_chat.go @@ -8,7 +8,6 @@ import ( "github.com/invopop/jsonschema" openai "github.com/sashabaranov/go-openai" - openaiJSONSchema "github.com/sashabaranov/go-openai/jsonschema" ) func (i *InstructorOpenAI) CreateChatCompletion( @@ -43,10 +42,12 @@ func (i *InstructorOpenAI) chat(ctx context.Context, request interface{}, schema switch i.Mode() { case ModeToolCall: - return i.chatToolCall(ctx, &req, schema) + return i.chatToolCall(ctx, &req, schema, false) + case ModeToolCallStrict: + return i.chatToolCall(ctx, &req, schema, true) case ModeJSON: return i.chatJSON(ctx, &req, schema) - case ModeStructuredOutputs: + case ModeJSONStrict: return i.chatJSONStrict(ctx, &req, schema) case ModeJSONSchema: return i.chatJSONSchema(ctx, &req, schema) @@ -55,9 +56,9 @@ func (i *InstructorOpenAI) chat(ctx context.Context, request interface{}, schema } } -func (i *InstructorOpenAI) chatToolCall(ctx context.Context, request *openai.ChatCompletionRequest, schema *Schema) (string, *openai.ChatCompletionResponse, error) { +func (i *InstructorOpenAI) chatToolCall(ctx context.Context, request *openai.ChatCompletionRequest, schema *Schema, strict bool) (string, *openai.ChatCompletionResponse, error) { - request.Tools = createOpenAITools(schema) + request.Tools = createOpenAITools(schema, strict) resp, err := i.Client.CreateChatCompletion(ctx, *request) if err != nil { @@ -123,31 +124,6 @@ func (i *InstructorOpenAI) chatJSON(ctx context.Context, request *openai.ChatCom func (i *InstructorOpenAI) chatJSONStrict(ctx context.Context, request *openai.ChatCompletionRequest, schema *Schema) (string, *openai.ChatCompletionResponse, error) { - // var s []byte - // s, _ = json.MarshalIndent(schema, "", " ") - // fmt.Println(string(s)) - - // oaiSchema := convertToOpenAIJSONSchema(schema.Schema) - // fmt.Printf(` - // Type %v - // Description %v - // Enum %v - // Properties %v - // Required %v - // Items %v - // AdditionalProperties %v - // `, - // oaiSchema.Type, - // oaiSchema.Description, - // oaiSchema.Enum, - // oaiSchema.Properties, - // oaiSchema.Required, - // oaiSchema.Items, - // oaiSchema.AdditionalProperties, - // ) - // s, _ = json.MarshalIndent(oaiSchema, "", " ") - // fmt.Println(string(s)) - structName := schema.NameFromRef() type SchemaWrapper struct { @@ -158,21 +134,14 @@ func (i *InstructorOpenAI) chatJSONStrict(ctx context.Context, request *openai.C Definitions *jsonschema.Definitions `json:"$defs"` } - required := []string{structName} - // // for p := schema.Definitions.; p != nil; p.Next() { - // for k := range schema.Definitions { - // required = append(required, k) - // } - - properties := make(jsonschema.Definitions) - properties[structName] = schema.Definitions[structName] - schemaWrapper := SchemaWrapper{ - Type: "object", - Required: required, + Type: "object", + Required: []string{structName}, + Definitions: &schema.Schema.Definitions, + Properties: &jsonschema.Definitions{ + structName: schema.Definitions[structName], + }, AdditionalProperties: false, - Properties: &properties, - Definitions: &schema.Schema.Definitions, } rawSchema, _ := json.Marshal(schemaWrapper) @@ -180,10 +149,10 @@ func (i *InstructorOpenAI) chatJSONStrict(ctx context.Context, request *openai.C request.ResponseFormat = &openai.ChatCompletionResponseFormat{ Type: openai.ChatCompletionResponseFormatTypeJSONSchema, JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{ - Name: schema.NameFromRef(), + Name: structName, Description: schema.Description, + Schema: json.RawMessage(rawSchema), Strict: true, - SchemaRaw: toPtr(rawSchema), }, } @@ -194,15 +163,6 @@ func (i *InstructorOpenAI) chatJSONStrict(ctx context.Context, request *openai.C text := resp.Choices[0].Message.Content - // TODO: - /* - Get struct contents inside: - { - "MyStructName": { - ... // what we want to marshall into struct - } - } - */ resMap := make(map[string]any) _ = json.Unmarshal([]byte(text), &resMap) @@ -226,73 +186,6 @@ func (i *InstructorOpenAI) chatJSONSchema(ctx context.Context, request *openai.C return text, &resp, nil } -func convertToOpenAIJSONSchema(schema *jsonschema.Schema) *openaiJSONSchema.Definition { - - oaiSchema := openaiJSONSchema.Definition{} - - // Initialize properties map - oaiSchema.Properties = make(map[string]openaiJSONSchema.Definition) - - // Convert type; default to object - if schema.Type != "" { - oaiSchema.Type = openaiJSONSchema.DataType(schema.Type) - } else { - oaiSchema.Type = openaiJSONSchema.Object - } - - // Convert description - oaiSchema.Description = schema.Description - - // Convert enum - if schema.Enum != nil { - oaiSchema.Enum = make([]string, len(schema.Enum)) - for i, v := range schema.Enum { - oaiSchema.Enum[i] = fmt.Sprintf("%v", v) - } - } - - // Convert properties - if schema.Properties != nil { - for p := schema.Properties.Oldest(); p != nil; p = p.Next() { - key, value := p.Key, p.Value - propertySchema := convertToOpenAIJSONSchema(value) - oaiSchema.Properties[key] = *propertySchema - } - } - - // Convert items - if schema.Items != nil { - itemsSchema := convertToOpenAIJSONSchema(schema.Items) - oaiSchema.Items = itemsSchema - } - - // Convert additional properties - if schema.AdditionalProperties != nil { - additionalPropertiesSchema := convertToOpenAIJSONSchema(schema.AdditionalProperties) - oaiSchema.AdditionalProperties = additionalPropertiesSchema - } - - // Convert defintions - if schema.Definitions != nil { - for key, value := range schema.Definitions { - oaiSchema.Required = append(oaiSchema.Required, key) - fmt.Printf("%+v\n", oaiSchema.Required) - - definitionSchema := convertToOpenAIJSONSchema(value) - oaiSchema.Properties[key] = *definitionSchema - } - } - - if len(oaiSchema.Properties) > 0 { - oaiSchema.Required = []string{} - for key := range oaiSchema.Properties { - oaiSchema.Required = append(oaiSchema.Required, key) - } - } - - return &oaiSchema -} - func (i *InstructorOpenAI) emptyResponseWithUsageSum(usage *UsageSum) interface{} { return &openai.ChatCompletionResponse{ Usage: openai.Usage{ @@ -357,6 +250,24 @@ Make sure to return an instance of the JSON, not the schema itself return msg } +func createOpenAITools(schema *Schema, strict bool) []openai.Tool { + tools := make([]openai.Tool, 0, len(schema.Functions)) + for _, function := range schema.Functions { + f := openai.FunctionDefinition{ + Name: function.Name, + Description: function.Description, + Parameters: function.Parameters, + Strict: strict, + } + t := openai.Tool{ + Type: "function", + Function: &f, + } + tools = append(tools, t) + } + return tools +} + func nilOpenaiRespWithUsage(resp *openai.ChatCompletionResponse) *openai.ChatCompletionResponse { if resp == nil { return nil diff --git a/pkg/instructor/openai_chat_stream.go b/pkg/instructor/openai_chat_stream.go index 8a46c0e..72120cf 100644 --- a/pkg/instructor/openai_chat_stream.go +++ b/pkg/instructor/openai_chat_stream.go @@ -36,7 +36,9 @@ func (i *InstructorOpenAI) chatStream(ctx context.Context, request interface{}, switch i.Mode() { case ModeToolCall: - return i.chatToolCallStream(ctx, &req, schema) + return i.chatToolCallStream(ctx, &req, schema, false) + case ModeToolCallStrict: + return i.chatToolCallStream(ctx, &req, schema, true) case ModeJSON: return i.chatJSONStream(ctx, &req, schema) case ModeJSONSchema: @@ -46,8 +48,8 @@ func (i *InstructorOpenAI) chatStream(ctx context.Context, request interface{}, } } -func (i *InstructorOpenAI) chatToolCallStream(ctx context.Context, request *openai.ChatCompletionRequest, schema *Schema) (<-chan string, error) { - request.Tools = createOpenAITools(schema) +func (i *InstructorOpenAI) chatToolCallStream(ctx context.Context, request *openai.ChatCompletionRequest, schema *Schema, strict bool) (<-chan string, error) { + request.Tools = createOpenAITools(schema, strict) return i.createStream(ctx, request) } @@ -63,23 +65,6 @@ func (i *InstructorOpenAI) chatJSONSchemaStream(ctx context.Context, request *op return i.createStream(ctx, request) } -func createOpenAITools(schema *Schema) []openai.Tool { - tools := make([]openai.Tool, 0, len(schema.Functions)) - for _, function := range schema.Functions { - f := openai.FunctionDefinition{ - Name: function.Name, - Description: function.Description, - Parameters: function.Parameters, - } - t := openai.Tool{ - Type: "function", - Function: &f, - } - tools = append(tools, t) - } - return tools -} - func createJSONMessageStream(schema *Schema) *openai.ChatCompletionMessage { message := fmt.Sprintf(` Please respond with a JSON array where the elements following JSON schema: