Skip to content

Commit

Permalink
Merge pull request #10 from JaderDias/main
Browse files Browse the repository at this point in the history
Add Edits endpoint
  • Loading branch information
tylermann authored Jun 30, 2022
2 parents 1e36ea2 + 522e092 commit 9e6c11b
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 9 deletions.
10 changes: 10 additions & 0 deletions cmd/test/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module example.com/main

go 1.18

replace github.com/PullRequestInc/go-gpt3 => ../../

require (
github.com/PullRequestInc/go-gpt3 v0.0.0-00010101000000-000000000000
github.com/joho/godotenv v1.4.0
)
51 changes: 51 additions & 0 deletions cmd/test/go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/joefitzgerald/rainbow-reporter v0.1.0/go.mod h1:481CNgqmVHQZzdIbN52CupLJyoVwB10FQ/IQlF1pdL8=
github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg=
github.com/joho/godotenv v1.4.0 h1:3l4+N6zfMWnkbPEXKng2o2/MR5mSwTrBih4ZEkkz1lg=
github.com/joho/godotenv v1.4.0/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/maxbrunsfeld/counterfeiter/v6 v6.2.3/go.mod h1:1ftk08SazyElaaNvmqAfZWGwJzshjCfBXDLoQtPAMNk=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/gomega v1.9.0/go.mod h1:Ho0h+IUsWyvy1OpqCwxlQ/21gkhVunqlU8fDGcoTdcA=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sclevine/spec v1.2.0/go.mod h1:W4J29eT/Kzv7/b9IWLB055Z+qvVC9vt0Arko24q7p+U=
github.com/sclevine/spec v1.4.0/go.mod h1:LvpgJaFyvQzRvc1kaDs0bulYwzC70PbiYjC4QnFHkOM=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190628185345-da137c7871d7 h1:rTIdg5QFRR7XCaK4LCjBiPbx8j4DQRpdYMnGn/bJUEU=
golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20200301222351-066e0c02454c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
12 changes: 12 additions & 0 deletions cmd/test/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,16 @@ func main() {
if err != nil {
log.Fatalln(err)
}

fmt.Print("\n\nedits API:\n")

editsResponse, err := client.Edits(ctx, gpt3.EditsRequest{
Model: "text-davinci-edit-001",
Input: "What day of the wek is it?",
Instruction: "Fix the spelling mistakes",
})
if err != nil {
log.Fatalln(err)
}
log.Printf("%+v\n", editsResponse)
}
22 changes: 21 additions & 1 deletion gpt3.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ type Client interface {
// CompletionStreamWithEngine is the same as CompletionStream except allows overriding the default engine on the client
CompletionStreamWithEngine(ctx context.Context, engine string, request CompletionRequest, onData func(*CompletionResponse)) error

// Given a prompt and an instruction, the model will return an edited version of the prompt.
Edits(ctx context.Context, request EditsRequest) (*EditsResponse, error)

// Search performs a semantic search over a list of documents with the default engine.
Search(ctx context.Context, request SearchRequest) (*SearchResponse, error)

Expand Down Expand Up @@ -204,6 +207,23 @@ func (c *client) CompletionStreamWithEngine(
return nil
}

func (c *client) Edits(ctx context.Context, request EditsRequest) (*EditsResponse, error) {
req, err := c.newRequest(ctx, "POST", "/edits", request)
if err != nil {
return nil, err
}
resp, err := c.performRequest(req)
if err != nil {
return nil, err
}

output := new(EditsResponse)
if err := getResponseObject(resp, output); err != nil {
return nil, err
}
return output, nil
}

func (c *client) Search(ctx context.Context, request SearchRequest) (*SearchResponse, error) {
return c.SearchWithEngine(ctx, c.defaultEngine, request)
}
Expand Down Expand Up @@ -288,7 +308,7 @@ func (c *client) newRequest(ctx context.Context, method, path string, payload in
if err != nil {
return nil, err
}
if (len(c.idOrg) > 0) {
if len(c.idOrg) > 0 {
req.Header.Set("OpenAI-Organization", c.idOrg)
}
req.Header.Set("Content-type", "application/json")
Expand Down
22 changes: 14 additions & 8 deletions gpt3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,21 +47,21 @@ func TestRequestCreationFails(t *testing.T) {
func() (interface{}, error) {
return client.Engines(ctx)
},
"Get https://api.openai.com/v1/engines: request error",
"Get \"https://api.openai.com/v1/engines\": request error",
},
{
"Engine",
func() (interface{}, error) {
return client.Engine(ctx, gpt3.DefaultEngine)
},
"Get https://api.openai.com/v1/engines/davinci: request error",
"Get \"https://api.openai.com/v1/engines/davinci\": request error",
},
{
"Completion",
func() (interface{}, error) {
return client.Completion(ctx, gpt3.CompletionRequest{})
},
"Post https://api.openai.com/v1/engines/davinci/completions: request error",
"Post \"https://api.openai.com/v1/engines/davinci/completions\": request error",
}, {
"CompletionStream",
func() (interface{}, error) {
Expand All @@ -71,13 +71,13 @@ func TestRequestCreationFails(t *testing.T) {
}
return rsp, client.CompletionStream(ctx, gpt3.CompletionRequest{}, onData)
},
"Post https://api.openai.com/v1/engines/davinci/completions: request error",
"Post \"https://api.openai.com/v1/engines/davinci/completions\": request error",
}, {
"CompletionWithEngine",
func() (interface{}, error) {
return client.CompletionWithEngine(ctx, gpt3.AdaEngine, gpt3.CompletionRequest{})
},
"Post https://api.openai.com/v1/engines/ada/completions: request error",
"Post \"https://api.openai.com/v1/engines/ada/completions\": request error",
}, {
"CompletionStreamWithEngine",
func() (interface{}, error) {
Expand All @@ -87,19 +87,25 @@ func TestRequestCreationFails(t *testing.T) {
}
return rsp, client.CompletionStreamWithEngine(ctx, gpt3.AdaEngine, gpt3.CompletionRequest{}, onData)
},
"Post https://api.openai.com/v1/engines/ada/completions: request error",
"Post \"https://api.openai.com/v1/engines/ada/completions\": request error",
}, {
"Edits",
func() (interface{}, error) {
return client.Edits(ctx, gpt3.EditsRequest{})
},
"Post \"https://api.openai.com/v1/edits\": request error",
}, {
"Search",
func() (interface{}, error) {
return client.Search(ctx, gpt3.SearchRequest{})
},
"Post https://api.openai.com/v1/engines/davinci/search: request error",
"Post \"https://api.openai.com/v1/engines/davinci/search\": request error",
}, {
"SearchWithEngine",
func() (interface{}, error) {
return client.SearchWithEngine(ctx, gpt3.AdaEngine, gpt3.SearchRequest{})
},
"Post https://api.openai.com/v1/engines/ada/search: request error",
"Post \"https://api.openai.com/v1/engines/ada/search\": request error",
},
}

Expand Down
37 changes: 37 additions & 0 deletions models.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,22 @@ type CompletionRequest struct {
Stream bool `json:"stream,omitempty"`
}

// EditsRequest is a request for the edits API
type EditsRequest struct {
// ID of the model to use. You can use the List models API to see all of your available models, or see our Model overview for descriptions of them.
Model string `json:"model"`
// The input text to use as a starting point for the edit.
Input string `json:"input"`
// The instruction that tells the model how to edit the prompt.
Instruction string `json:"instruction"`
// Sampling temperature to use
Temperature *float32 `json:"temperature,omitempty"`
// Alternative to temperature for nucleus sampling
TopP *float32 `json:"top_p,omitempty"`
// How many edits to generate for the input and instruction. Defaults to 1
N *int `json:"n"`
}

// LogprobResult represents logprob result of Choice
type LogprobResult struct {
Tokens []string `json:"tokens"`
Expand All @@ -86,6 +102,27 @@ type CompletionResponse struct {
Choices []CompletionResponseChoice `json:"choices"`
}

// EditsResponse is the full response from a request to the edits API
type EditsResponse struct {
Object string `json:"object"`
Created int `json:"created"`
Choices []EditsResponseChoice `json:"choices"`
Usage EditsResponseUsage `json:"usage"`
}

// EditsResponseChoice is one of the choices returned in the response to the Edits API
type EditsResponseChoice struct {
Text string `json:"text"`
Index int `json:"index"`
}

// EditsResponseUsage is a structure used in the response from a request to the edits API
type EditsResponseUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}

// SearchRequest is a request for the document search API
type SearchRequest struct {
Documents []string `json:"documents"`
Expand Down

0 comments on commit 9e6c11b

Please sign in to comment.