Skip to content

Commit

Permalink
Merge pull request #31 from bakks/main
Browse files Browse the repository at this point in the history
Add streaming chat completions
  • Loading branch information
tylermann authored Mar 4, 2023
2 parents 8543706 + 529277f commit ab408b5
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 0 deletions.
59 changes: 59 additions & 0 deletions gpt3.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ type Client interface {
// is what powers the ChatGPT experience.
ChatCompletion(ctx context.Context, request ChatCompletionRequest) (*ChatCompletionResponse, error)

// ChatCompletionStream creates a streamed completion with the Chat completion
// endpoint which is what powers the ChatGPT experience. onData is invoked
// once per received stream chunk.
ChatCompletionStream(ctx context.Context, request ChatCompletionRequest, onData func(*ChatCompletionStreamResponse)) error

// Completion creates a completion with the default engine. This is the main endpoint of the API
// which auto-completes based on the given prompt.
Completion(ctx context.Context, request CompletionRequest) (*CompletionResponse, error)
Expand Down Expand Up @@ -166,6 +170,11 @@ func (c *client) Engine(ctx context.Context, engine string) (*EngineObject, erro
}

func (c *client) ChatCompletion(ctx context.Context, request ChatCompletionRequest) (*ChatCompletionResponse, error) {
if request.Model == "" {
request.Model = GPT3Dot5Turbo
}
request.Stream = false

req, err := c.newRequest(ctx, "POST", "/chat/completions", request)
if err != nil {
return nil, err
Expand All @@ -183,6 +192,56 @@ func (c *client) ChatCompletion(ctx context.Context, request ChatCompletionReque
return output, nil
}

// ChatCompletionStream creates a chat completion against the
// /chat/completions endpoint and streams incremental results back through the
// onData callback. If request.Model is empty it defaults to GPT3Dot5Turbo.
// onData is invoked once per server-sent "data:" event until the stream is
// terminated by the [DONE] sentinel; any request, read, or decode error
// aborts the stream and is returned to the caller.
func (c *client) ChatCompletionStream(
	ctx context.Context,
	request ChatCompletionRequest,
	onData func(*ChatCompletionStreamResponse)) error {
	if request.Model == "" {
		request.Model = GPT3Dot5Turbo
	}
	// The API only emits a server-sent event stream when Stream is set.
	request.Stream = true

	req, err := c.newRequest(ctx, "POST", "/chat/completions", request)
	if err != nil {
		return err
	}

	resp, err := c.performRequest(req)
	if err != nil {
		return err
	}
	// Close the body on every return path (including early error returns)
	// so the underlying connection can be reused.
	defer resp.Body.Close()

	reader := bufio.NewReader(resp.Body)
	for {
		line, err := reader.ReadBytes('\n')
		if err != nil {
			// Includes io.EOF: a stream that ends without the [DONE]
			// sentinel is surfaced to the caller as an error.
			return err
		}

		// make sure there isn't any extra whitespace before or after
		line = bytes.TrimSpace(line)
		// the completion API only returns data events; skip anything else
		// (blank keep-alive lines, SSE comments, etc.)
		if !bytes.HasPrefix(line, dataPrefix) {
			continue
		}
		line = bytes.TrimPrefix(line, dataPrefix)

		// the stream is completed when terminated by [DONE]
		if bytes.HasPrefix(line, doneSequence) {
			break
		}

		output := new(ChatCompletionStreamResponse)
		// Wrap with %w so callers can unwrap the underlying json error
		// with errors.Is / errors.As.
		if err := json.Unmarshal(line, output); err != nil {
			return fmt.Errorf("invalid json stream data: %w", err)
		}
		onData(output)
	}

	return nil
}

// Completion creates a completion with the default engine. It is a thin
// convenience wrapper that delegates to CompletionWithEngine using the
// client's configured default engine.
func (c *client) Completion(ctx context.Context, request CompletionRequest) (*CompletionResponse, error) {
	engine := c.defaultEngine
	return c.CompletionWithEngine(ctx, engine, request)
}
Expand Down
46 changes: 46 additions & 0 deletions models.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,36 @@ type ChatCompletionRequest struct {

// Messages is a list of messages to use as the context for the chat completion.
Messages []ChatCompletionRequestMessage `json:"messages"`

// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic
Temperature float32 `json:"temperature,omitempty"`

// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
TopP float32 `json:"top_p,omitempty"`

// Number of responses to generate
N int `json:"n,omitempty"`

// Whether or not to stream responses back as they are generated
Stream bool `json:"stream,omitempty"`

// Up to 4 sequences where the API will stop generating further tokens.
Stop []string `json:"stop,omitempty"`

// MaxTokens is the maximum number of tokens to return.
MaxTokens int `json:"max_tokens,omitempty"`

// (-2, 2) Positive values penalize tokens that have already appeared in the history, encouraging the model to talk about new topics.
PresencePenalty float32 `json:"presence_penalty,omitempty"`

// (-2, 2) Penalize tokens that appear too frequently in the history.
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`

// Modify the probability of specific tokens appearing in the completion.
LogitBias map[string]float32 `json:"logit_bias,omitempty"`

// Can be used to identify an end-user
User string `json:"user,omitempty"`
}

// CompletionRequest is a request for the completions API
Expand Down Expand Up @@ -133,6 +163,13 @@ type ChatCompletionResponseChoice struct {
Message ChatCompletionResponseMessage `json:"message"`
}

// ChatCompletionStreamResponseChoice is one of the choices returned in a
// streamed response chunk from the Chat Completions API. Delta carries only
// the newly generated portion of the message for this chunk.
type ChatCompletionStreamResponseChoice struct {
	Index        int                           `json:"index"`
	FinishReason string                        `json:"finish_reason"`
	Delta        ChatCompletionResponseMessage `json:"delta"`
}

// ChatCompletionsResponseUsage is the object that returns how many tokens the completion's request used
type ChatCompletionsResponseUsage struct {
PromptTokens int `json:"prompt_tokens"`
Expand All @@ -150,6 +187,15 @@ type ChatCompletionResponse struct {
Usage ChatCompletionsResponseUsage `json:"usage"`
}

// ChatCompletionStreamResponse is a single chunk returned by the streaming
// Chat Completions API; its Choices carry incremental deltas rather than
// complete messages.
type ChatCompletionStreamResponse struct {
	ID      string                               `json:"id"`
	Object  string                               `json:"object"`
	Created int                                  `json:"created"`
	Model   string                               `json:"model"`
	Choices []ChatCompletionStreamResponseChoice `json:"choices"`
	Usage   ChatCompletionsResponseUsage         `json:"usage"`
}

// CompletionResponseChoice is one of the choices returned in the response to the Completions API
type CompletionResponseChoice struct {
Text string `json:"text"`
Expand Down

0 comments on commit ab408b5

Please sign in to comment.