Skip to content

Commit

Permalink
Merge branch 'master' into add-raw-response-schema
Browse files Browse the repository at this point in the history
  • Loading branch information
h0rv committed Aug 22, 2024
2 parents 03508ae + 6d02119 commit e2224ea
Show file tree
Hide file tree
Showing 19 changed files with 423 additions and 73 deletions.
76 changes: 76 additions & 0 deletions api_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,82 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) {
}
}

func TestChatCompletionStructuredOutputsFunctionCalling(t *testing.T) {
apiToken := os.Getenv("OPENAI_TOKEN")
if apiToken == "" {
t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.")
}

var err error
c := openai.NewClient(apiToken)
ctx := context.Background()

resp, err := c.CreateChatCompletion(
ctx,
openai.ChatCompletionRequest{
Model: openai.GPT4oMini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
Content: "Please enter a string, and we will convert it into the following naming conventions:" +
"1. PascalCase: Each word starts with an uppercase letter, with no spaces or separators." +
"2. CamelCase: The first word starts with a lowercase letter, " +
"and subsequent words start with an uppercase letter, with no spaces or separators." +
"3. KebabCase: All letters are lowercase, with words separated by hyphens `-`." +
"4. SnakeCase: All letters are lowercase, with words separated by underscores `_`.",
},
{
Role: openai.ChatMessageRoleUser,
Content: "Hello World",
},
},
Tools: []openai.Tool{
{
Type: openai.ToolTypeFunction,
Function: &openai.FunctionDefinition{
Name: "display_cases",
Strict: true,
Parameters: &jsonschema.Definition{
Type: jsonschema.Object,
Properties: map[string]jsonschema.Definition{
"PascalCase": {
Type: jsonschema.String,
},
"CamelCase": {
Type: jsonschema.String,
},
"KebabCase": {
Type: jsonschema.String,
},
"SnakeCase": {
Type: jsonschema.String,
},
},
Required: []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"},
AdditionalProperties: false,
},
},
},
},
ToolChoice: openai.ToolChoice{
Type: openai.ToolTypeFunction,
Function: openai.ToolFunction{
Name: "display_cases",
},
},
},
)
checks.NoError(t, err, "CreateChatCompletion (use structured outputs response) returned error")
var result = make(map[string]string)
err = json.Unmarshal([]byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments), &result)
checks.NoError(t, err, "CreateChatCompletion (use structured outputs response) unmarshal error")
for _, key := range []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"} {
if _, ok := result[key]; !ok {
t.Errorf("key:%s does not exist.", key)
}
}
}

func TestChatCompletionResponseFormat_JSONSchemaRaw(t *testing.T) {
apiToken := os.Getenv("OPENAI_TOKEN")
if apiToken == "" {
Expand Down
24 changes: 19 additions & 5 deletions api_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,15 @@ func TestAzureFullURL(t *testing.T) {
Name string
BaseURL string
AzureModelMapper map[string]string
Suffix string
Model string
Expect string
}{
{
"AzureBaseURLWithSlashAutoStrip",
"https://httpbin.org/",
nil,
"/chat/completions",
"chatgpt-demo",
"https://httpbin.org/" +
"openai/deployments/chatgpt-demo" +
Expand All @@ -128,19 +130,28 @@ func TestAzureFullURL(t *testing.T) {
"AzureBaseURLWithoutSlashOK",
"https://httpbin.org",
nil,
"/chat/completions",
"chatgpt-demo",
"https://httpbin.org/" +
"openai/deployments/chatgpt-demo" +
"/chat/completions?api-version=2023-05-15",
},
{
"",
"https://httpbin.org",
nil,
"/assistants?limit=10",
"chatgpt-demo",
"https://httpbin.org/openai/assistants?api-version=2023-05-15&limit=10",
},
}

for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
az := DefaultAzureConfig("dummy", c.BaseURL)
cli := NewClientWithConfig(az)
// /openai/deployments/{engine}/chat/completions?api-version={api_version}
actual := cli.fullURL("/chat/completions", c.Model)
actual := cli.fullURL(c.Suffix, withModel(c.Model))
if actual != c.Expect {
t.Errorf("Expected %s, got %s", c.Expect, actual)
}
Expand All @@ -153,19 +164,22 @@ func TestCloudflareAzureFullURL(t *testing.T) {
cases := []struct {
Name string
BaseURL string
Suffix string
Expect string
}{
{
"CloudflareAzureBaseURLWithSlashAutoStrip",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/",
"/chat/completions",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" +
"chat/completions?api-version=2023-05-15",
},
{
"CloudflareAzureBaseURLWithoutSlashOK",
"",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" +
"chat/completions?api-version=2023-05-15",
"/assistants?limit=10",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo" +
"/assistants?api-version=2023-05-15&limit=10",
},
}

Expand All @@ -176,7 +190,7 @@ func TestCloudflareAzureFullURL(t *testing.T) {

cli := NewClientWithConfig(az)

actual := cli.fullURL("/chat/completions")
actual := cli.fullURL(c.Suffix)
if actual != c.Expect {
t.Errorf("Expected %s, got %s", c.Expect, actual)
}
Expand Down
9 changes: 7 additions & 2 deletions audio.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,13 @@ func (c *Client) callAudioAPI(
}

urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix)
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model),
withBody(&formBody), withContentType(builder.FormDataContentType()))
req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL(urlSuffix, withModel(request.Model)),
withBody(&formBody),
withContentType(builder.FormDataContentType()),
)
if err != nil {
return AudioResponse{}, err
}
Expand Down
8 changes: 7 additions & 1 deletion chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ type ToolFunction struct {
type FunctionDefinition struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Strict bool `json:"strict,omitempty"`
// Parameters is an object describing the function.
// You can pass json.RawMessage to describe the schema,
// or you can pass in a struct which serializes to the proper JSON schema.
Expand Down Expand Up @@ -381,7 +382,12 @@ func (c *Client) CreateChatCompletion(
return
}

req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request))
req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL(urlSuffix, withModel(request.Model)),
withBody(request),
)
if err != nil {
return
}
Expand Down
7 changes: 6 additions & 1 deletion chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,12 @@ func (c *Client) CreateChatCompletionStream(
}

request.Stream = true
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request))
req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL(urlSuffix, withModel(request.Model)),
withBody(request),
)
if err != nil {
return nil, err
}
Expand Down
26 changes: 26 additions & 0 deletions chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,32 @@ func TestChatCompletionsFunctions(t *testing.T) {
})
checks.NoError(t, err, "CreateChatCompletion with functions error")
})
t.Run("StructuredOutputs", func(t *testing.T) {
type testMessage struct {
Count int `json:"count"`
Words []string `json:"words"`
}
msg := testMessage{
Count: 2,
Words: []string{"hello", "world"},
}
_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
MaxTokens: 5,
Model: openai.GPT3Dot5Turbo0613,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
},
Functions: []openai.FunctionDefinition{{
Name: "test",
Strict: true,
Parameters: &msg,
}},
})
checks.NoError(t, err, "CreateChatCompletion with functions error")
})
}

func TestAzureChatCompletions(t *testing.T) {
Expand Down
84 changes: 54 additions & 30 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,42 +222,66 @@ func decodeString(body io.Reader, output *string) error {
return nil
}

type fullURLOptions struct {
model string
}

type fullURLOption func(*fullURLOptions)

func withModel(model string) fullURLOption {
return func(args *fullURLOptions) {
args.model = model
}
}

var azureDeploymentsEndpoints = []string{
"/completions",
"/embeddings",
"/chat/completions",
"/audio/transcriptions",
"/audio/translations",
"/audio/speech",
"/images/generations",
}

// fullURL returns full URL for request.
// args[0] is model name, if API type is Azure, model name is required to get deployment name.
func (c *Client) fullURL(suffix string, args ...any) string {
// /openai/deployments/{model}/chat/completions?api-version={api_version}
func (c *Client) fullURL(suffix string, setters ...fullURLOption) string {
baseURL := strings.TrimRight(c.config.BaseURL, "/")
args := fullURLOptions{}
for _, setter := range setters {
setter(&args)
}

if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD {
baseURL := c.config.BaseURL
baseURL = strings.TrimRight(baseURL, "/")
parseURL, _ := url.Parse(baseURL)
query := parseURL.Query()
query.Add("api-version", c.config.APIVersion)
// if suffix is /models change to {endpoint}/openai/models?api-version=2022-12-01
// https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP
if containsSubstr([]string{"/models", "/assistants", "/threads", "/files"}, suffix) {
return fmt.Sprintf("%s/%s%s?%s", baseURL, azureAPIPrefix, suffix, query.Encode())
}
azureDeploymentName := "UNKNOWN"
if len(args) > 0 {
model, ok := args[0].(string)
if ok {
azureDeploymentName = c.config.GetAzureDeploymentByModel(model)
}
}
return fmt.Sprintf("%s/%s/%s/%s%s?%s",
baseURL, azureAPIPrefix, azureDeploymentsPrefix,
azureDeploymentName, suffix, query.Encode(),
)
baseURL = c.baseURLWithAzureDeployment(baseURL, suffix, args.model)
}

if c.config.APIVersion != "" {
suffix = c.suffixWithAPIVersion(suffix)
}
return fmt.Sprintf("%s%s", baseURL, suffix)
}

// https://developers.cloudflare.com/ai-gateway/providers/azureopenai/
if c.config.APIType == APITypeCloudflareAzure {
baseURL := c.config.BaseURL
baseURL = strings.TrimRight(baseURL, "/")
return fmt.Sprintf("%s%s?api-version=%s", baseURL, suffix, c.config.APIVersion)
func (c *Client) suffixWithAPIVersion(suffix string) string {
parsedSuffix, err := url.Parse(suffix)
if err != nil {
panic("failed to parse url suffix")
}
query := parsedSuffix.Query()
query.Add("api-version", c.config.APIVersion)
return fmt.Sprintf("%s?%s", parsedSuffix.Path, query.Encode())
}

return fmt.Sprintf("%s%s", c.config.BaseURL, suffix)
func (c *Client) baseURLWithAzureDeployment(baseURL, suffix, model string) (newBaseURL string) {
baseURL = fmt.Sprintf("%s/%s", strings.TrimRight(baseURL, "/"), azureAPIPrefix)
if containsSubstr(azureDeploymentsEndpoints, suffix) {
azureDeploymentName := c.config.GetAzureDeploymentByModel(model)
if azureDeploymentName == "" {
azureDeploymentName = "UNKNOWN"
}
baseURL = fmt.Sprintf("%s/%s/%s", baseURL, azureDeploymentsPrefix, azureDeploymentName)
}
return baseURL
}

func (c *Client) handleErrorResp(resp *http.Response) error {
Expand Down
Loading

0 comments on commit e2224ea

Please sign in to comment.