From 184392610604933198d72f87770bda71d59562e2 Mon Sep 17 00:00:00 2001
From: snehalchennuru
Date: Thu, 6 Jul 2023 09:24:09 -0700
Subject: [PATCH] add support for function calling

---
 gpt3.go      |  8 +++++++-
 gpt3_test.go | 26 ++++++++++++++++++++++++++
 models.go    | 42 ++++++++++++++++++++++++++++++++++++++++--
 3 files changed, 73 insertions(+), 3 deletions(-)

diff --git a/gpt3.go b/gpt3.go
index edc298c..b86d4c4 100644
--- a/gpt3.go
+++ b/gpt3.go
@@ -32,6 +32,7 @@ type EmbeddingEngine string
 const (
 	GPT3Dot5Turbo            = "gpt-3.5-turbo"
 	GPT3Dot5Turbo0301        = "gpt-3.5-turbo-0301"
+	GPT3Dot5Turbo0613        = "gpt-3.5-turbo-0613"
 	TextSimilarityAda001     = "text-similarity-ada-001"
 	TextSimilarityBabbage001 = "text-similarity-babbage-001"
 	TextSimilarityCurie001   = "text-similarity-curie-001"
@@ -180,8 +181,13 @@ func (c *client) Engine(ctx context.Context, engine string) (*EngineObject, erro
 
 func (c *client) ChatCompletion(ctx context.Context, request ChatCompletionRequest) (*ChatCompletionResponse, error) {
 	if request.Model == "" {
-		request.Model = GPT3Dot5Turbo
+		if request.Functions == nil {
+			request.Model = GPT3Dot5Turbo
+		} else {
+			request.Model = GPT3Dot5Turbo0613
+		}
 	}
+
 	request.Stream = false
 
 	req, err := c.newRequest(ctx, "POST", "/chat/completions", request)
diff --git a/gpt3_test.go b/gpt3_test.go
index c268fbc..e71dac2 100644
--- a/gpt3_test.go
+++ b/gpt3_test.go
@@ -213,6 +213,32 @@ func TestResponses(t *testing.T) {
 				},
 			},
 		},
+		{
+			"ChatCompletionWithFunctionCall",
+			func() (interface{}, error) {
+				return client.ChatCompletion(ctx, gpt3.ChatCompletionRequest{})
+			},
+			&gpt3.ChatCompletionResponse{
+				ID:      "chatcmpl-123",
+				Object:  "messages",
+				Created: 123456789,
+				Model:   "gpt-3.5-turbo-0613",
+				Choices: []gpt3.ChatCompletionResponseChoice{
+					{
+						Index:        0,
+						FinishReason: "function_call",
+						Message: gpt3.ChatCompletionResponseMessage{
+							Role:    "assistant",
+							Content: "",
+							FunctionCall: &gpt3.Function{
+								Name:      "get_current_weather",
+								Arguments: `{"location": "Boston, MA"}`,
+							},
+						},
+					},
+				},
+			},
+		},
 		{
 			"Completion",
 			func() (interface{}, error) {
diff --git a/models.go b/models.go
index b81cba2..386e79f 100644
--- a/models.go
+++ b/models.go
@@ -39,6 +39,40 @@ type ChatCompletionRequestMessage struct {
 
 	// Content is the content of the message
 	Content string `json:"content"`
+
+	// FunctionCall is the name and arguments of a function that should be called, as generated by the model.
+	FunctionCall *Function `json:"function_call,omitempty"`
+
+	// Name is the name of the author of this message. `name` is required if role is `function`, and it should be the name of the function whose response is in the `content`.
+	Name string `json:"name,omitempty"`
+}
+
+// Function represents a function with a name and arguments.
+type Function struct {
+	Name      string `json:"name"`
+	Arguments string `json:"arguments"`
+}
+
+// ChatCompletionFunctions represents the functions the model may generate JSON inputs for.
+type ChatCompletionFunctions struct {
+	Name        string                           `json:"name"`
+	Description string                           `json:"description,omitempty"`
+	Parameters  ChatCompletionFunctionParameters `json:"parameters"`
+}
+
+// ChatCompletionFunctionParameters captures the metadata of the function parameters.
+type ChatCompletionFunctionParameters struct {
+	Type        string                                       `json:"type"`
+	Description string                                       `json:"description,omitempty"`
+	Properties  map[string]FunctionParameterPropertyMetadata `json:"properties"`
+	Required    []string                                     `json:"required"`
+}
+
+// FunctionParameterPropertyMetadata represents the metadata of a function parameter property.
+type FunctionParameterPropertyMetadata struct {
+	Type        string   `json:"type"`
+	Description string   `json:"description,omitempty"`
+	Enum        []string `json:"enum,omitempty"`
 }
 
 // ChatCompletionRequest is a request for the chat completion API
@@ -49,6 +83,9 @@ type ChatCompletionRequest struct {
 
 	// Messages is a list of messages to use as the context for the chat completion.
 	Messages []ChatCompletionRequestMessage `json:"messages"`
 
+	// Functions is a list of functions the model may generate JSON inputs for.
+	Functions []ChatCompletionFunctions `json:"functions,omitempty"`
+
 	// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic
 	Temperature *float32 `json:"temperature,omitempty"`
 
@@ -152,8 +189,9 @@ type LogprobResult struct {
 
 // ChatCompletionResponseMessage is a message returned in the response to the Chat Completions API
 type ChatCompletionResponseMessage struct {
-	Role    string `json:"role"`
-	Content string `json:"content"`
+	Role         string    `json:"role"`
+	Content      string    `json:"content"`
+	FunctionCall *Function `json:"function_call,omitempty"`
 }
 
 // ChatCompletionResponseChoice is one of the choices returned in the response to the Chat Completions API
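
Note for reviewers (not part of the patch): a minimal usage sketch of the new fields. The function name, its parameter schema, and the prompt are made up for illustration, the API key is assumed to come from an OPENAI_API_KEY environment variable, and the import path assumes the upstream github.com/PullRequestInc/go-gpt3 module.

package main

import (
	"context"
	"fmt"
	"log"
	"os"

	gpt3 "github.com/PullRequestInc/go-gpt3"
)

func main() {
	ctx := context.Background()
	client := gpt3.NewClient(os.Getenv("OPENAI_API_KEY"))

	// Describe one callable function so the model can answer with a function_call.
	weatherFn := gpt3.ChatCompletionFunctions{
		Name:        "get_current_weather",
		Description: "Get the current weather in a given location",
		Parameters: gpt3.ChatCompletionFunctionParameters{
			Type: "object",
			Properties: map[string]gpt3.FunctionParameterPropertyMetadata{
				"location": {
					Type:        "string",
					Description: "The city and state, e.g. Boston, MA",
				},
			},
			Required: []string{"location"},
		},
	}

	resp, err := client.ChatCompletion(ctx, gpt3.ChatCompletionRequest{
		// Model is left empty on purpose: because Functions is set, the client
		// defaults it to gpt-3.5-turbo-0613 (see the gpt3.go change above).
		Messages: []gpt3.ChatCompletionRequestMessage{
			{Role: "user", Content: "What is the weather like in Boston?"},
		},
		Functions: []gpt3.ChatCompletionFunctions{weatherFn},
	})
	if err != nil {
		log.Fatalf("chat completion failed: %v", err)
	}

	// When the model chooses to call the function, FinishReason is
	// "function_call" and Arguments holds the JSON-encoded arguments.
	msg := resp.Choices[0].Message
	if msg.FunctionCall != nil {
		fmt.Printf("call %s with %s\n", msg.FunctionCall.Name, msg.FunctionCall.Arguments)
	} else {
		fmt.Println(msg.Content)
	}
}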