Skip to content

Commit

Permalink
add support for function calling
Browse files Browse the repository at this point in the history
  • Loading branch information
snehalchennuru committed Jul 6, 2023
1 parent e348aa5 commit 1843926
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 3 deletions.
8 changes: 7 additions & 1 deletion gpt3.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type EmbeddingEngine string
const (
GPT3Dot5Turbo = "gpt-3.5-turbo"
GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301"
GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613"
TextSimilarityAda001 = "text-similarity-ada-001"
TextSimilarityBabbage001 = "text-similarity-babbage-001"
TextSimilarityCurie001 = "text-similarity-curie-001"
Expand Down Expand Up @@ -180,8 +181,13 @@ func (c *client) Engine(ctx context.Context, engine string) (*EngineObject, erro

func (c *client) ChatCompletion(ctx context.Context, request ChatCompletionRequest) (*ChatCompletionResponse, error) {
if request.Model == "" {
request.Model = GPT3Dot5Turbo
if request.Functions == nil {
request.Model = GPT3Dot5Turbo
} else {
request.Model = GPT3Dot5Turbo0613
}
}

request.Stream = false

req, err := c.newRequest(ctx, "POST", "/chat/completions", request)
Expand Down
26 changes: 26 additions & 0 deletions gpt3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,32 @@ func TestResponses(t *testing.T) {
},
},
},
{
"ChatCompletionWithFunctionCall",
func() (interface{}, error) {
return client.ChatCompletion(ctx, gpt3.ChatCompletionRequest{})
},
&gpt3.ChatCompletionResponse{
ID: "chatcmpl-123",
Object: "messages",
Created: 123456789,
Model: "gpt-3.5-turbo-0613",
Choices: []gpt3.ChatCompletionResponseChoice{
{
Index: 0,
FinishReason: "function_call",
Message: gpt3.ChatCompletionResponseMessage{
Role: "assistant",
Content: "",
FunctionCall: &gpt3.Function{
Name: "get_current_weather",
Arguments: `"{"location": "Boston, MA"}"`,
},
},
},
},
},
},
{
"Completion",
func() (interface{}, error) {
Expand Down
42 changes: 40 additions & 2 deletions models.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,40 @@ type ChatCompletionRequestMessage struct {

// Content is the content of the message
Content string `json:"content"`

// FunctionCall is the name and arguments of a function that should be called, as generated by the model.
FunctionCall *Function `json:"function_call,omitempty"`

// Name is the the name of the author of this message. `name` is required if role is `function`, and it should be the name of the function whose response is in the `content`.
Name string `json:"name,omitempty"`
}

// Function represents a function with a name and arguments.
type Function struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}

// ChatCompletionFunctions represents the functions the model may generate JSON inputs for.
type ChatCompletionFunctions struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Parameters ChatCompletionFunctionParameters `json:"parameters"`
}

// ChatCompletionFunctionParameters captures the metadata of the function parameter.
type ChatCompletionFunctionParameters struct {
Type string `json:"type"`
Description string `json:"description,omitempty"`
Properties map[string]FunctionParameterPropertyMetadata `json:"properties"`
Required []string `json:"required"`
}

// FunctionParameterPropertyMetadata represents the metadata of the function parameter property.
type FunctionParameterPropertyMetadata struct {
Type string `json:"type"`
Description string `json:"description,omitempty"`
Enum []string `json:"enum,omitempty"`
}

// ChatCompletionRequest is a request for the chat completion API
Expand All @@ -49,6 +83,9 @@ type ChatCompletionRequest struct {
// Messages is a list of messages to use as the context for the chat completion.
Messages []ChatCompletionRequestMessage `json:"messages"`

// Functions is a list of functions the model may generate JSON inputs for.
Functions []ChatCompletionFunctions `json:"functions"`

// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic
Temperature *float32 `json:"temperature,omitempty"`

Expand Down Expand Up @@ -152,8 +189,9 @@ type LogprobResult struct {

// ChatCompletionResponseMessage is a message returned in the response to the Chat Completions API
type ChatCompletionResponseMessage struct {
Role string `json:"role"`
Content string `json:"content"`
Role string `json:"role"`
Content string `json:"content"`
FunctionCall *Function `json:"function_call,omitempty"`
}

// ChatCompletionResponseChoice is one of the choices returned in the response to the Chat Completions API
Expand Down

0 comments on commit 1843926

Please sign in to comment.