Skip to content

Commit

Permalink
add openai strict (structured outputs) mode
Browse files Browse the repository at this point in the history
  • Loading branch information
h0rv committed Aug 27, 2024
1 parent c0ae56f commit 22ce522
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 151 deletions.
2 changes: 1 addition & 1 deletion examples/auto_ticketer/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func main() {

client := instructor.FromOpenAI(
openai.NewClient(os.Getenv("OPENAI_API_KEY")),
instructor.WithMode(instructor.ModeStructuredOutputs),
instructor.WithMode(instructor.ModeJSONStrict),
instructor.WithMaxRetries(0),
)

Expand Down
4 changes: 1 addition & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@ require (
github.com/go-playground/validator/v10 v10.21.0
github.com/invopop/jsonschema v0.12.0
github.com/liushuangls/go-anthropic/v2 v2.1.0
github.com/sashabaranov/go-openai v1.28.1
github.com/sashabaranov/go-openai v1.29.0
)

replace github.com/sashabaranov/go-openai => ../go-openai

require (
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sashabaranov/go-openai v1.29.0 h1:eBH6LSjtX4md5ImDCX8hNhHQvaRf22zujiERoQpsvLo=
github.com/sashabaranov/go-openai v1.29.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
Expand Down
13 changes: 7 additions & 6 deletions pkg/instructor/mode_enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ package instructor
type Mode = string

const (
ModeToolCall Mode = "tool_call"
ModeJSON Mode = "json_mode"
ModeStructuredOutputs Mode = "structured_outputs_mode"
ModeJSONSchema Mode = "json_schema_mode"
ModeMarkdownJSON Mode = "markdown_json_mode"
ModeDefault Mode = ModeJSONSchema
ModeToolCall Mode = "tool_call_mode"
ModeToolCallStrict Mode = "tool_call_strict_mode"
ModeJSON Mode = "json_mode"
ModeJSONStrict Mode = "json_strict_mode"
ModeJSONSchema Mode = "json_schema_mode"
ModeMarkdownJSON Mode = "markdown_json_mode"
ModeDefault Mode = ModeJSONSchema
)
153 changes: 32 additions & 121 deletions pkg/instructor/openai_chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (

"github.com/invopop/jsonschema"
openai "github.com/sashabaranov/go-openai"
openaiJSONSchema "github.com/sashabaranov/go-openai/jsonschema"
)

func (i *InstructorOpenAI) CreateChatCompletion(
Expand Down Expand Up @@ -43,10 +42,12 @@ func (i *InstructorOpenAI) chat(ctx context.Context, request interface{}, schema

switch i.Mode() {
case ModeToolCall:
return i.chatToolCall(ctx, &req, schema)
return i.chatToolCall(ctx, &req, schema, false)
case ModeToolCallStrict:
return i.chatToolCall(ctx, &req, schema, true)
case ModeJSON:
return i.chatJSON(ctx, &req, schema)
case ModeStructuredOutputs:
case ModeJSONStrict:
return i.chatJSONStrict(ctx, &req, schema)
case ModeJSONSchema:
return i.chatJSONSchema(ctx, &req, schema)
Expand All @@ -55,9 +56,9 @@ func (i *InstructorOpenAI) chat(ctx context.Context, request interface{}, schema
}
}

func (i *InstructorOpenAI) chatToolCall(ctx context.Context, request *openai.ChatCompletionRequest, schema *Schema) (string, *openai.ChatCompletionResponse, error) {
func (i *InstructorOpenAI) chatToolCall(ctx context.Context, request *openai.ChatCompletionRequest, schema *Schema, strict bool) (string, *openai.ChatCompletionResponse, error) {

request.Tools = createOpenAITools(schema)
request.Tools = createOpenAITools(schema, strict)

resp, err := i.Client.CreateChatCompletion(ctx, *request)
if err != nil {
Expand Down Expand Up @@ -123,31 +124,6 @@ func (i *InstructorOpenAI) chatJSON(ctx context.Context, request *openai.ChatCom

func (i *InstructorOpenAI) chatJSONStrict(ctx context.Context, request *openai.ChatCompletionRequest, schema *Schema) (string, *openai.ChatCompletionResponse, error) {

// var s []byte
// s, _ = json.MarshalIndent(schema, "", " ")
// fmt.Println(string(s))

// oaiSchema := convertToOpenAIJSONSchema(schema.Schema)
// fmt.Printf(`
// Type %v
// Description %v
// Enum %v
// Properties %v
// Required %v
// Items %v
// AdditionalProperties %v
// `,
// oaiSchema.Type,
// oaiSchema.Description,
// oaiSchema.Enum,
// oaiSchema.Properties,
// oaiSchema.Required,
// oaiSchema.Items,
// oaiSchema.AdditionalProperties,
// )
// s, _ = json.MarshalIndent(oaiSchema, "", " ")
// fmt.Println(string(s))

structName := schema.NameFromRef()

type SchemaWrapper struct {
Expand All @@ -158,32 +134,25 @@ func (i *InstructorOpenAI) chatJSONStrict(ctx context.Context, request *openai.C
Definitions *jsonschema.Definitions `json:"$defs"`
}

required := []string{structName}
// // for p := schema.Definitions.; p != nil; p.Next() {
// for k := range schema.Definitions {
// required = append(required, k)
// }

properties := make(jsonschema.Definitions)
properties[structName] = schema.Definitions[structName]

schemaWrapper := SchemaWrapper{
Type: "object",
Required: required,
Type: "object",
Required: []string{structName},
Definitions: &schema.Schema.Definitions,
Properties: &jsonschema.Definitions{
structName: schema.Definitions[structName],
},
AdditionalProperties: false,
Properties: &properties,
Definitions: &schema.Schema.Definitions,
}

rawSchema, _ := json.Marshal(schemaWrapper)

request.ResponseFormat = &openai.ChatCompletionResponseFormat{
Type: openai.ChatCompletionResponseFormatTypeJSONSchema,
JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{
Name: schema.NameFromRef(),
Name: structName,
Description: schema.Description,
Schema: json.RawMessage(rawSchema),
Strict: true,
SchemaRaw: toPtr(rawSchema),
},
}

Expand All @@ -194,15 +163,6 @@ func (i *InstructorOpenAI) chatJSONStrict(ctx context.Context, request *openai.C

text := resp.Choices[0].Message.Content

// TODO:
/*
Get struct contents inside:
{
"MyStructName": {
... // what we want to marshall into struct
}
}
*/
resMap := make(map[string]any)
_ = json.Unmarshal([]byte(text), &resMap)

Expand All @@ -226,73 +186,6 @@ func (i *InstructorOpenAI) chatJSONSchema(ctx context.Context, request *openai.C
return text, &resp, nil
}

func convertToOpenAIJSONSchema(schema *jsonschema.Schema) *openaiJSONSchema.Definition {

oaiSchema := openaiJSONSchema.Definition{}

// Initialize properties map
oaiSchema.Properties = make(map[string]openaiJSONSchema.Definition)

// Convert type; default to object
if schema.Type != "" {
oaiSchema.Type = openaiJSONSchema.DataType(schema.Type)
} else {
oaiSchema.Type = openaiJSONSchema.Object
}

// Convert description
oaiSchema.Description = schema.Description

// Convert enum
if schema.Enum != nil {
oaiSchema.Enum = make([]string, len(schema.Enum))
for i, v := range schema.Enum {
oaiSchema.Enum[i] = fmt.Sprintf("%v", v)
}
}

// Convert properties
if schema.Properties != nil {
for p := schema.Properties.Oldest(); p != nil; p = p.Next() {
key, value := p.Key, p.Value
propertySchema := convertToOpenAIJSONSchema(value)
oaiSchema.Properties[key] = *propertySchema
}
}

// Convert items
if schema.Items != nil {
itemsSchema := convertToOpenAIJSONSchema(schema.Items)
oaiSchema.Items = itemsSchema
}

// Convert additional properties
if schema.AdditionalProperties != nil {
additionalPropertiesSchema := convertToOpenAIJSONSchema(schema.AdditionalProperties)
oaiSchema.AdditionalProperties = additionalPropertiesSchema
}

// Convert defintions
if schema.Definitions != nil {
for key, value := range schema.Definitions {
oaiSchema.Required = append(oaiSchema.Required, key)
fmt.Printf("%+v\n", oaiSchema.Required)

definitionSchema := convertToOpenAIJSONSchema(value)
oaiSchema.Properties[key] = *definitionSchema
}
}

if len(oaiSchema.Properties) > 0 {
oaiSchema.Required = []string{}
for key := range oaiSchema.Properties {
oaiSchema.Required = append(oaiSchema.Required, key)
}
}

return &oaiSchema
}

func (i *InstructorOpenAI) emptyResponseWithUsageSum(usage *UsageSum) interface{} {
return &openai.ChatCompletionResponse{
Usage: openai.Usage{
Expand Down Expand Up @@ -357,6 +250,24 @@ Make sure to return an instance of the JSON, not the schema itself
return msg
}

func createOpenAITools(schema *Schema, strict bool) []openai.Tool {
tools := make([]openai.Tool, 0, len(schema.Functions))
for _, function := range schema.Functions {
f := openai.FunctionDefinition{
Name: function.Name,
Description: function.Description,
Parameters: function.Parameters,
Strict: strict,
}
t := openai.Tool{
Type: "function",
Function: &f,
}
tools = append(tools, t)
}
return tools
}

func nilOpenaiRespWithUsage(resp *openai.ChatCompletionResponse) *openai.ChatCompletionResponse {
if resp == nil {
return nil
Expand Down
25 changes: 5 additions & 20 deletions pkg/instructor/openai_chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ func (i *InstructorOpenAI) chatStream(ctx context.Context, request interface{},

switch i.Mode() {
case ModeToolCall:
return i.chatToolCallStream(ctx, &req, schema)
return i.chatToolCallStream(ctx, &req, schema, false)
case ModeToolCallStrict:
return i.chatToolCallStream(ctx, &req, schema, true)
case ModeJSON:
return i.chatJSONStream(ctx, &req, schema)
case ModeJSONSchema:
Expand All @@ -46,8 +48,8 @@ func (i *InstructorOpenAI) chatStream(ctx context.Context, request interface{},
}
}

func (i *InstructorOpenAI) chatToolCallStream(ctx context.Context, request *openai.ChatCompletionRequest, schema *Schema) (<-chan string, error) {
request.Tools = createOpenAITools(schema)
func (i *InstructorOpenAI) chatToolCallStream(ctx context.Context, request *openai.ChatCompletionRequest, schema *Schema, strict bool) (<-chan string, error) {
request.Tools = createOpenAITools(schema, strict)
return i.createStream(ctx, request)
}

Expand All @@ -63,23 +65,6 @@ func (i *InstructorOpenAI) chatJSONSchemaStream(ctx context.Context, request *op
return i.createStream(ctx, request)
}

func createOpenAITools(schema *Schema) []openai.Tool {
tools := make([]openai.Tool, 0, len(schema.Functions))
for _, function := range schema.Functions {
f := openai.FunctionDefinition{
Name: function.Name,
Description: function.Description,
Parameters: function.Parameters,
}
t := openai.Tool{
Type: "function",
Function: &f,
}
tools = append(tools, t)
}
return tools
}

func createJSONMessageStream(schema *Schema) *openai.ChatCompletionMessage {
message := fmt.Sprintf(`
Please respond with a JSON array where the elements following JSON schema:
Expand Down

0 comments on commit 22ce522

Please sign in to comment.