From 6d248d5d70bdcc81c2cb67e47532af94793e9008 Mon Sep 17 00:00:00 2001 From: Robby <45851384+h0rv@users.noreply.github.com> Date: Tue, 27 Aug 2024 14:13:01 -0400 Subject: [PATCH] add openai strict (structured outputs) mode (#38) * init * update from fork * add openai strict (structured outputs) mode * revert reciept * strict mode working * fix strict mode example * add example to readme * update readme dropdowns * add openai api key to running readme * g co vision/receipt/main.go --------- Co-authored-by: Robby --- README.md | 245 ++++++++++++++++++++++----- examples/auto_ticketer/main.go | 154 +++++++++++++++++ examples/vision/receipt/main.go | 2 +- go.mod | 2 +- go.sum | 4 +- pkg/instructor/chat.go | 2 +- pkg/instructor/mode_enum.go | 12 +- pkg/instructor/openai_chat.go | 81 ++++++++- pkg/instructor/openai_chat_stream.go | 25 +-- pkg/instructor/schema.go | 5 + 10 files changed, 451 insertions(+), 81 deletions(-) create mode 100644 examples/auto_ticketer/main.go diff --git a/README.md b/README.md index 0d840b6..e6dd418 100644 --- a/README.md +++ b/README.md @@ -34,16 +34,13 @@ As shown in the example below, by adding extra metadata to each struct field (vi > For more information on the `jsonschema` tags available, see the [`jsonschema` godoc](https://pkg.go.dev/github.com/invopop/jsonschema?utm_source=godoc). -
-Running +Running ```bash export OPENAI_API_KEY= go run examples/user/main.go ``` -
- ```go package main @@ -105,16 +102,13 @@ Age: %d
Function Calling with OpenAI -
-Running +Running ```bash export OPENAI_API_KEY= go run examples/function_calling/main.go ``` -
- ```go package main @@ -195,16 +189,13 @@ func main() {
Text Classification with Anthropic -
-Running +Running ```bash export ANTHROPIC_API_KEY= go run examples/classification/main.go ``` -
- ```go package main @@ -298,18 +289,15 @@ func assert(condition bool, message string) {
Images with OpenAI -![List of books](https://raw.githubusercontent.com/instructor-ai/instructor-go/main/examples/images/openai/books.png) +![List of books](https://raw.githubusercontent.com/instructor-ai/instructor-go/main/examples/vision/openai/books.png) -
-Running +Running ```bash export OPENAI_API_KEY= -go run examples/images/openai/main.go +go run examples/vision/openai/main.go ``` -
- ```go package main @@ -437,18 +425,15 @@ func main() {
Images with Anthropic -![List of movies](https://raw.githubusercontent.com/instructor-ai/instructor-go/main/examples/images/anthropic/movies.png) +![List of movies](https://raw.githubusercontent.com/instructor-ai/instructor-go/main/examples/vision/anthropic/movies.png) -
-Running +Running ```bash export ANTHROPIC_API_KEY= -go run examples/images/anthropic/main.go +go run examples/vision/anthropic/main.go ``` -
- ```go package main @@ -608,16 +593,13 @@ func urlToBase64(url string) (string, error) {
Streaming with OpenAI -
-Running +Running ```bash export OPENAI_API_KEY= go run examples/streaming/openai/main.go ``` -
- ```go package main @@ -754,16 +736,13 @@ Product list:
Document Segmentation with Cohere -
-Running +Running ```bash export COHERE_API_KEY= go run examples/document_segmentation/main.go ``` -
- ```go package main @@ -915,16 +894,13 @@ func getSectionsText(structuredDoc *StructuredDocument, line2text map[int]string
Streaming with Cohere -
-Running +Running ```bash export COHERE_API_KEY= go run examples/streaming/cohere/main.go ``` -
- ```go package main @@ -1008,15 +984,12 @@ func toPtr[T any](val T) *T {
Local, Self-Hosted Models with Ollama (via OpenAI API Support) -
-Running +Running ```bash go run examples/ollama/main.go ``` -
- ```go package main @@ -1091,6 +1064,7 @@ func main() {
+ Receipt Item Extraction from Image (using OpenAI GPT-4o)

@@ -1098,15 +1072,12 @@ func main() { Receipt 2

-
-Running +Running ```bash go run examples/vision/receipt/main.go ``` -
- ```go package main @@ -1267,6 +1238,192 @@ Total: $98.21
+
+ +Task Ticket Creator from Transcript - OpenAI Structured Outputs (Strict JSON Mode) + +
+ +Running + +```bash +export OPENAI_API_KEY= +go run examples/auto_ticketer/main.go +``` + +```go +``` + +
+ +Task Ticket Creator from Transcript - OpenAI Structured Outputs (Strict JSON Mode) + +Running + +```bash +export OPENAI_API_KEY= +go run examples/auto_ticketer/main.go +``` + +```go +/* + * Original example in Python: https://github.com/jxnl/instructor/blob/11125a7c831a26e2a4deaef4129f2b4845a7e079/examples/auto-ticketer/run.py + */ + +package main + +import ( + "context" + "fmt" + "os" + "strings" + + "github.com/instructor-ai/instructor-go/pkg/instructor" + openai "github.com/sashabaranov/go-openai" +) + +type PriorityEnum string + +const ( + High PriorityEnum = "High" + Medium PriorityEnum = "Medium" + Low PriorityEnum = "Low" +) + +type Subtask struct { + ID int `json:"id" jsonschema:"title=unique identifier for the subtask,description=Unique identifier for the subtask"` + Name string `json:"name" jsonschema:"title=name of the subtask,description=Informative title of the subtask"` +} + +type Ticket struct { + ID int `json:"id" jsonschema:"title=unique identifier for the ticket,description=Unique identifier for the ticket"` + Name string `json:"name" jsonschema:"title=name of the task,description=Title of the task"` + Description string `json:"description" jsonschema:"title=description of the task,description=Detailed description of the task"` + Priority PriorityEnum `json:"priority" jsonschema:"title=priority level,description=Priority level"` + Assignees []string `json:"assignees" jsonschema:"title=list of users assigned to the task,description=List of users assigned to the task"` + Subtasks []Subtask `json:"subtasks" jsonschema:"title=list of subtasks associated with the main task,description=List of subtasks associated with the main task"` + Dependencies []int `json:"dependencies" jsonschema:"title=list of ticket IDs that this ticket depends on,description=List of ticket IDs that this ticket depends on"` +} + +type ActionItems struct { + Tickets []Ticket `json:"tickets"` +} + +func (ai ActionItems) String() string { + var sb strings.Builder + + for _, ticket := range ai.Tickets { + sb.WriteString(fmt.Sprintf("Ticket ID: %d\n", ticket.ID)) + sb.WriteString(fmt.Sprintf(" Name: %s\n", ticket.Name)) + sb.WriteString(fmt.Sprintf(" Description: %s\n", ticket.Description)) + sb.WriteString(fmt.Sprintf(" Priority: %s\n", ticket.Priority)) + sb.WriteString(fmt.Sprintf(" Assignees: %s\n", strings.Join(ticket.Assignees, ", "))) + + if len(ticket.Subtasks) > 0 { + sb.WriteString(" Subtasks:\n") + for _, subtask := range ticket.Subtasks { + sb.WriteString(fmt.Sprintf(" - Subtask ID: %d, Name: %s\n", subtask.ID, subtask.Name)) + } + } + + if len(ticket.Dependencies) > 0 { + sb.WriteString(fmt.Sprintf(" Dependencies: %v\n", ticket.Dependencies)) + } + + sb.WriteString("\n") + } + + return sb.String() +} + +func main() { + ctx := context.Background() + + client := instructor.FromOpenAI( + openai.NewClient(os.Getenv("OPENAI_API_KEY")), + instructor.WithMode(instructor.ModeJSONStrict), + instructor.WithMaxRetries(0), + ) + + transcript := ` +Alice: Hey team, we have several critical tasks we need to tackle for the upcoming release. First, we need to work on improving the authentication system. It's a top priority. + +Bob: Got it, Alice. I can take the lead on the authentication improvements. Are there any specific areas you want me to focus on? + +Alice: Good question, Bob. We need both a front-end revamp and back-end optimization. So basically, two sub-tasks. + +Carol: I can help with the front-end part of the authentication system. + +Bob: Great, Carol. I'll handle the back-end optimization then. + +Alice: Perfect. Now, after the authentication system is improved, we have to integrate it with our new billing system. That's a medium priority task. + +Carol: Is the new billing system already in place? + +Alice: No, it's actually another task. So it's a dependency for the integration task. Bob, can you also handle the billing system? + +Bob: Sure, but I'll need to complete the back-end optimization of the authentication system first, so it's dependent on that. + +Alice: Understood. Lastly, we also need to update our user documentation to reflect all these changes. It's a low-priority task but still important. + +Carol: I can take that on once the front-end changes for the authentication system are done. So, it would be dependent on that. + +Alice: Sounds like a plan. Let's get these tasks modeled out and get started. +` + + var actionItems ActionItems + _, err := client.CreateChatCompletion( + ctx, + openai.ChatCompletionRequest{ + Model: openai.GPT4oMini20240718, + Temperature: .2, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: "The following is a transcript of a meeting between a manager and their team. The manager is assigning tasks to their team members and creating action items for them to complete.", + }, + { + Role: openai.ChatMessageRoleUser, + Content: fmt.Sprintf("Create the action items for the following transcript: %s", transcript), + }, + }, + }, + &actionItems, + ) + if err != nil { + panic(err) + } + + println(actionItems.String()) + /* + Ticket ID: 1 + Name: Improve Authentication System + Description: Revamp the front-end and optimize the back-end of the authentication system. + Priority: high + Assignees: Bob, Carol + Subtasks: + - Subtask ID: 1, Name: Front-end Revamp + - Subtask ID: 2, Name: Back-end Optimization + + Ticket ID: 2 + Name: Integrate Authentication with New Billing System + Description: Integrate the improved authentication system with the new billing system. + Priority: medium + Assignees: Bob + Dependencies: [1] + + Ticket ID: 3 + Name: Update User Documentation + Description: Update the user documentation to reflect changes made to the authentication system. + Priority: low + Assignees: Carol + Dependencies: [1] + */ +} +``` + +
+ ## Providers Instructor Go supports the following LLM provider APIs: diff --git a/examples/auto_ticketer/main.go b/examples/auto_ticketer/main.go new file mode 100644 index 0000000..4393e38 --- /dev/null +++ b/examples/auto_ticketer/main.go @@ -0,0 +1,154 @@ +/* + * Original example in Python: https://github.com/jxnl/instructor/blob/11125a7c831a26e2a4deaef4129f2b4845a7e079/examples/auto-ticketer/run.py + */ + +package main + +import ( + "context" + "fmt" + "os" + "strings" + + "github.com/instructor-ai/instructor-go/pkg/instructor" + openai "github.com/sashabaranov/go-openai" +) + +type PriorityEnum string + +const ( + High PriorityEnum = "High" + Medium PriorityEnum = "Medium" + Low PriorityEnum = "Low" +) + +type Subtask struct { + ID int `json:"id" jsonschema:"title=unique identifier for the subtask,description=Unique identifier for the subtask"` + Name string `json:"name" jsonschema:"title=name of the subtask,description=Informative title of the subtask"` +} + +type Ticket struct { + ID int `json:"id" jsonschema:"title=unique identifier for the ticket,description=Unique identifier for the ticket"` + Name string `json:"name" jsonschema:"title=name of the task,description=Title of the task"` + Description string `json:"description" jsonschema:"title=description of the task,description=Detailed description of the task"` + Priority PriorityEnum `json:"priority" jsonschema:"title=priority level,description=Priority level"` + Assignees []string `json:"assignees" jsonschema:"title=list of users assigned to the task,description=List of users assigned to the task"` + Subtasks []Subtask `json:"subtasks" jsonschema:"title=list of subtasks associated with the main task,description=List of subtasks associated with the main task"` + Dependencies []int `json:"dependencies" jsonschema:"title=list of ticket IDs that this ticket depends on,description=List of ticket IDs that this ticket depends on"` +} + +type ActionItems struct { + Tickets []Ticket `json:"tickets"` +} + +func (ai ActionItems) String() string { + var sb strings.Builder + + for _, ticket := range ai.Tickets { + sb.WriteString(fmt.Sprintf("Ticket ID: %d\n", ticket.ID)) + sb.WriteString(fmt.Sprintf(" Name: %s\n", ticket.Name)) + sb.WriteString(fmt.Sprintf(" Description: %s\n", ticket.Description)) + sb.WriteString(fmt.Sprintf(" Priority: %s\n", ticket.Priority)) + sb.WriteString(fmt.Sprintf(" Assignees: %s\n", strings.Join(ticket.Assignees, ", "))) + + if len(ticket.Subtasks) > 0 { + sb.WriteString(" Subtasks:\n") + for _, subtask := range ticket.Subtasks { + sb.WriteString(fmt.Sprintf(" - Subtask ID: %d, Name: %s\n", subtask.ID, subtask.Name)) + } + } + + if len(ticket.Dependencies) > 0 { + sb.WriteString(fmt.Sprintf(" Dependencies: %v\n", ticket.Dependencies)) + } + + sb.WriteString("\n") + } + + return sb.String() +} + +func main() { + ctx := context.Background() + + client := instructor.FromOpenAI( + openai.NewClient(os.Getenv("OPENAI_API_KEY")), + instructor.WithMode(instructor.ModeJSONStrict), + instructor.WithMaxRetries(0), + ) + + transcript := ` +Alice: Hey team, we have several critical tasks we need to tackle for the upcoming release. First, we need to work on improving the authentication system. It's a top priority. + +Bob: Got it, Alice. I can take the lead on the authentication improvements. Are there any specific areas you want me to focus on? + +Alice: Good question, Bob. We need both a front-end revamp and back-end optimization. So basically, two sub-tasks. + +Carol: I can help with the front-end part of the authentication system. + +Bob: Great, Carol. I'll handle the back-end optimization then. + +Alice: Perfect. Now, after the authentication system is improved, we have to integrate it with our new billing system. That's a medium priority task. + +Carol: Is the new billing system already in place? + +Alice: No, it's actually another task. So it's a dependency for the integration task. Bob, can you also handle the billing system? + +Bob: Sure, but I'll need to complete the back-end optimization of the authentication system first, so it's dependent on that. + +Alice: Understood. Lastly, we also need to update our user documentation to reflect all these changes. It's a low-priority task but still important. + +Carol: I can take that on once the front-end changes for the authentication system are done. So, it would be dependent on that. + +Alice: Sounds like a plan. Let's get these tasks modeled out and get started. +` + + var actionItems ActionItems + _, err := client.CreateChatCompletion( + ctx, + openai.ChatCompletionRequest{ + Model: openai.GPT4oMini20240718, + Temperature: .2, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: "The following is a transcript of a meeting between a manager and their team. The manager is assigning tasks to their team members and creating action items for them to complete.", + }, + { + Role: openai.ChatMessageRoleUser, + Content: fmt.Sprintf("Create the action items for the following transcript: %s", transcript), + }, + }, + }, + &actionItems, + ) + if err != nil { + panic(err) + } + + println(actionItems.String()) + /* + Ticket ID: 1 + Name: Improve Authentication System + Description: Revamp the front-end and optimize the back-end of the authentication system. + Priority: high + Assignees: Bob, Carol + Subtasks: + - Subtask ID: 1, Name: Front-end Revamp + - Subtask ID: 2, Name: Back-end Optimization + + Ticket ID: 2 + Name: Integrate Authentication with New Billing System + Description: Integrate the improved authentication system with the new billing system. + Priority: medium + Assignees: Bob + Dependencies: [1] + + Ticket ID: 3 + Name: Update User Documentation + Description: Update the user documentation to reflect changes made to the authentication system. + Priority: low + Assignees: Carol + Dependencies: [1] + */ +} diff --git a/examples/vision/receipt/main.go b/examples/vision/receipt/main.go index daa79e7..2413a22 100644 --- a/examples/vision/receipt/main.go +++ b/examples/vision/receipt/main.go @@ -97,7 +97,7 @@ func main() { urls := []string{ // source: https://templates.mediamodifier.com/645124ff36ed2f5227cbf871/supermarket-receipt-template.jpg - "https://raw.githubusercontent.com/instructor-ai/instructor-go/main/examples/vision/receipt/supermarket-receipt-template.jpg", + "https://raw.githubusercontent.com/instructor-ai/instructor-go/main/examples/vision/receipt/supermarket-receipt-template.jpg", // source: https://ocr.space/Content/Images/receipt-ocr-original.jpg "https://raw.githubusercontent.com/instructor-ai/instructor-go/main/examples/vision/receipt/receipt-ocr-original.jpg", } diff --git a/go.mod b/go.mod index a3fb0c9..45c16c5 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/go-playground/validator/v10 v10.21.0 github.com/invopop/jsonschema v0.12.0 github.com/liushuangls/go-anthropic/v2 v2.1.0 - github.com/sashabaranov/go-openai v1.24.1 + github.com/sashabaranov/go-openai v1.29.0 ) require ( diff --git a/go.sum b/go.sum index 818bbbb..4c8050f 100644 --- a/go.sum +++ b/go.sum @@ -29,8 +29,8 @@ github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0 github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/sashabaranov/go-openai v1.24.1 h1:DWK95XViNb+agQtuzsn+FyHhn3HQJ7Va8z04DQDJ1MI= -github.com/sashabaranov/go-openai v1.24.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= +github.com/sashabaranov/go-openai v1.29.0 h1:eBH6LSjtX4md5ImDCX8hNhHQvaRf22zujiERoQpsvLo= +github.com/sashabaranov/go-openai v1.29.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= diff --git a/pkg/instructor/chat.go b/pkg/instructor/chat.go index 4a6ec41..c7cc5d3 100644 --- a/pkg/instructor/chat.go +++ b/pkg/instructor/chat.go @@ -29,7 +29,7 @@ func chatHandler(i Instructor, ctx context.Context, request interface{}, respons // keep a running total of usage usage := &UsageSum{} - for attempt := 0; attempt < i.MaxRetries(); attempt++ { + for attempt := 0; attempt <= i.MaxRetries(); attempt++ { text, resp, err := i.chat(ctx, request, schema) if err != nil { diff --git a/pkg/instructor/mode_enum.go b/pkg/instructor/mode_enum.go index ad55484..1a776ea 100644 --- a/pkg/instructor/mode_enum.go +++ b/pkg/instructor/mode_enum.go @@ -3,9 +3,11 @@ package instructor type Mode = string const ( - ModeToolCall Mode = "tool_call" - ModeJSON Mode = "json_mode" - ModeJSONSchema Mode = "json_schema_mode" - ModeMarkdownJSON Mode = "markdown_json_mode" - ModeDefault Mode = ModeJSONSchema + ModeToolCall Mode = "tool_call_mode" + ModeToolCallStrict Mode = "tool_call_strict_mode" + ModeJSON Mode = "json_mode" + ModeJSONStrict Mode = "json_strict_mode" + ModeJSONSchema Mode = "json_schema_mode" + ModeMarkdownJSON Mode = "markdown_json_mode" + ModeDefault Mode = ModeJSONSchema ) diff --git a/pkg/instructor/openai_chat.go b/pkg/instructor/openai_chat.go index b6f73a7..ab9b515 100644 --- a/pkg/instructor/openai_chat.go +++ b/pkg/instructor/openai_chat.go @@ -6,9 +6,18 @@ import ( "errors" "fmt" + "github.com/invopop/jsonschema" openai "github.com/sashabaranov/go-openai" ) +type ResponseFormatSchemaWrapper struct { + Type string `json:"type"` + Required []string `json:"required"` + AdditionalProperties bool `json:"additionalProperties"` + Properties *jsonschema.Definitions `json:"properties"` + Definitions *jsonschema.Definitions `json:"$defs"` +} + func (i *InstructorOpenAI) CreateChatCompletion( ctx context.Context, request openai.ChatCompletionRequest, @@ -41,9 +50,13 @@ func (i *InstructorOpenAI) chat(ctx context.Context, request interface{}, schema switch i.Mode() { case ModeToolCall: - return i.chatToolCall(ctx, &req, schema) + return i.chatToolCall(ctx, &req, schema, false) + case ModeToolCallStrict: + return i.chatToolCall(ctx, &req, schema, true) case ModeJSON: - return i.chatJSON(ctx, &req, schema) + return i.chatJSON(ctx, &req, schema, false) + case ModeJSONStrict: + return i.chatJSON(ctx, &req, schema, true) case ModeJSONSchema: return i.chatJSONSchema(ctx, &req, schema) default: @@ -51,9 +64,9 @@ func (i *InstructorOpenAI) chat(ctx context.Context, request interface{}, schema } } -func (i *InstructorOpenAI) chatToolCall(ctx context.Context, request *openai.ChatCompletionRequest, schema *Schema) (string, *openai.ChatCompletionResponse, error) { +func (i *InstructorOpenAI) chatToolCall(ctx context.Context, request *openai.ChatCompletionRequest, schema *Schema, strict bool) (string, *openai.ChatCompletionResponse, error) { - request.Tools = createOpenAITools(schema) + request.Tools = createOpenAITools(schema, strict) resp, err := i.Client.CreateChatCompletion(ctx, *request) if err != nil { @@ -100,12 +113,40 @@ func (i *InstructorOpenAI) chatToolCall(ctx context.Context, request *openai.Cha return string(resultJSON), &resp, nil } -func (i *InstructorOpenAI) chatJSON(ctx context.Context, request *openai.ChatCompletionRequest, schema *Schema) (string, *openai.ChatCompletionResponse, error) { +func (i *InstructorOpenAI) chatJSON(ctx context.Context, request *openai.ChatCompletionRequest, schema *Schema, strict bool) (string, *openai.ChatCompletionResponse, error) { + + structName := schema.NameFromRef() request.Messages = prepend(request.Messages, *createJSONMessage(schema)) - // Set JSON mode - request.ResponseFormat = &openai.ChatCompletionResponseFormat{Type: openai.ChatCompletionResponseFormatTypeJSONObject} + if strict { + schemaWrapper := ResponseFormatSchemaWrapper{ + Type: "object", + Required: []string{structName}, + Definitions: &schema.Schema.Definitions, + Properties: &jsonschema.Definitions{ + structName: schema.Definitions[structName], + }, + AdditionalProperties: false, + } + + schemaJSON, _ := json.Marshal(schemaWrapper) + schemaRaw := json.RawMessage(schemaJSON) + + request.ResponseFormat = &openai.ChatCompletionResponseFormat{ + Type: openai.ChatCompletionResponseFormatTypeJSONSchema, + JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{ + Name: structName, + Description: schema.Description, + Schema: schemaRaw, + Strict: true, + }, + } + } else { + request.ResponseFormat = &openai.ChatCompletionResponseFormat{ + Type: openai.ChatCompletionResponseFormatTypeJSONObject, + } + } resp, err := i.Client.CreateChatCompletion(ctx, *request) if err != nil { @@ -114,6 +155,14 @@ func (i *InstructorOpenAI) chatJSON(ctx context.Context, request *openai.ChatCom text := resp.Choices[0].Message.Content + if strict { + resMap := make(map[string]any) + _ = json.Unmarshal([]byte(text), &resMap) + + cleanedText, _ := json.Marshal(resMap[structName]) + text = string(cleanedText) + } + return text, &resp, nil } @@ -195,6 +244,24 @@ Make sure to return an instance of the JSON, not the schema itself return msg } +func createOpenAITools(schema *Schema, strict bool) []openai.Tool { + tools := make([]openai.Tool, 0, len(schema.Functions)) + for _, function := range schema.Functions { + f := openai.FunctionDefinition{ + Name: function.Name, + Description: function.Description, + Parameters: function.Parameters, + Strict: strict, + } + t := openai.Tool{ + Type: "function", + Function: &f, + } + tools = append(tools, t) + } + return tools +} + func nilOpenaiRespWithUsage(resp *openai.ChatCompletionResponse) *openai.ChatCompletionResponse { if resp == nil { return nil diff --git a/pkg/instructor/openai_chat_stream.go b/pkg/instructor/openai_chat_stream.go index 8a46c0e..72120cf 100644 --- a/pkg/instructor/openai_chat_stream.go +++ b/pkg/instructor/openai_chat_stream.go @@ -36,7 +36,9 @@ func (i *InstructorOpenAI) chatStream(ctx context.Context, request interface{}, switch i.Mode() { case ModeToolCall: - return i.chatToolCallStream(ctx, &req, schema) + return i.chatToolCallStream(ctx, &req, schema, false) + case ModeToolCallStrict: + return i.chatToolCallStream(ctx, &req, schema, true) case ModeJSON: return i.chatJSONStream(ctx, &req, schema) case ModeJSONSchema: @@ -46,8 +48,8 @@ func (i *InstructorOpenAI) chatStream(ctx context.Context, request interface{}, } } -func (i *InstructorOpenAI) chatToolCallStream(ctx context.Context, request *openai.ChatCompletionRequest, schema *Schema) (<-chan string, error) { - request.Tools = createOpenAITools(schema) +func (i *InstructorOpenAI) chatToolCallStream(ctx context.Context, request *openai.ChatCompletionRequest, schema *Schema, strict bool) (<-chan string, error) { + request.Tools = createOpenAITools(schema, strict) return i.createStream(ctx, request) } @@ -63,23 +65,6 @@ func (i *InstructorOpenAI) chatJSONSchemaStream(ctx context.Context, request *op return i.createStream(ctx, request) } -func createOpenAITools(schema *Schema) []openai.Tool { - tools := make([]openai.Tool, 0, len(schema.Functions)) - for _, function := range schema.Functions { - f := openai.FunctionDefinition{ - Name: function.Name, - Description: function.Description, - Parameters: function.Parameters, - } - t := openai.Tool{ - Type: "function", - Function: &f, - } - tools = append(tools, t) - } - return tools -} - func createJSONMessageStream(schema *Schema) *openai.ChatCompletionMessage { message := fmt.Sprintf(` Please respond with a JSON array where the elements following JSON schema: diff --git a/pkg/instructor/schema.go b/pkg/instructor/schema.go index c5868ef..8b9605a 100644 --- a/pkg/instructor/schema.go +++ b/pkg/instructor/schema.go @@ -3,6 +3,7 @@ package instructor import ( "encoding/json" "reflect" + "strings" "github.com/invopop/jsonschema" ) @@ -69,3 +70,7 @@ func ToFunctionSchema(tType reflect.Type, tSchema *jsonschema.Schema) []Function return fds } + +func (s *Schema) NameFromRef() string { + return strings.Split(s.Ref, "/")[2] // ex: '#/$defs/MyStruct' +}