Skip to content

Commit

Permalink
#171: update chat.go
Browse files Browse the repository at this point in the history
  • Loading branch information
mkrueger12 committed Mar 24, 2024
1 parent 8c969df commit 0a50c1a
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 46 deletions.
40 changes: 4 additions & 36 deletions pkg/providers/cohere/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,38 +15,6 @@ import (
"go.uber.org/zap"
)

// ChatMessage is a single chat message with an OpenAI-style role/content
// shape.
// NOTE(review): appears unused in the visible code — ChatRequest carries a
// plain Message string plus []ChatHistory instead; confirm before removing.
type ChatMessage struct {
// Role of the message author (e.g. "user", "assistant").
Role string `json:"role"`
// Content is the message text.
Content string `json:"content"`
}

// ChatHistory is one prior conversation turn, serialized into the
// "chat_history" array of a Cohere chat request (see ChatRequest).
type ChatHistory struct {
// Role of the turn's author.
Role string `json:"role"`
// Message is the turn's text.
Message string `json:"message"`
// User is omitted from the JSON payload when empty.
User string `json:"user,omitempty"`
}

// ChatRequest is the payload sent to the Cohere chat completion endpoint.
// Field tags mirror the Cohere API parameter names.
type ChatRequest struct {
	Model            string        `json:"model"`
	Message          string        `json:"message"`
	Temperature      float64       `json:"temperature,omitempty"`
	PreambleOverride string        `json:"preamble_override,omitempty"`
	ChatHistory      []ChatHistory `json:"chat_history,omitempty"`
	ConversationID   string        `json:"conversation_id,omitempty"`
	PromptTruncation string        `json:"prompt_truncation,omitempty"`
	// NOTE(review): the Cohere API documents connectors as objects (see the
	// Connectors type below); confirm a bare []string of IDs is accepted.
	Connectors        []string `json:"connectors,omitempty"`
	SearchQueriesOnly bool     `json:"search_queries_only,omitempty"`
	// CitiationQuality keeps its misspelled Go name so existing callers
	// still compile, but the wire tag is corrected: the Cohere API expects
	// "citation_quality", so the previous "citiation_quality" tag caused
	// the setting to be silently ignored by the server.
	CitiationQuality string `json:"citation_quality,omitempty"`
	Stream           bool   `json:"stream,omitempty"`
}

// Connectors describes a Cohere connector attachment for a chat request.
// NOTE(review): appears unused in the visible code — ChatRequest.Connectors
// is []string, not []Connectors; confirm which shape the API expects.
type Connectors struct {
ID string `json:"id"`
UserAccessToken string `json:"user_access_token"`
// NOTE(review): the name "continue_on_failure" suggests a boolean, but the
// field is typed string — confirm against the Cohere API before changing.
ContOnFail string `json:"continue_on_failure"`
Options map[string]string `json:"options"`
}

// NewChatRequestFromConfig fills the struct from the config. Not using reflection because of performance penalty it gives
func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
Expand All @@ -67,7 +35,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
// Chat sends a chat request to the specified cohere model.
func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) {
// Create a new chat request
chatRequest := c.createChatRequestSchema(request)
chatRequest := c.createRequestSchema(request)

chatResponse, err := c.doChatRequest(ctx, chatRequest)
if err != nil {
Expand All @@ -81,9 +49,9 @@ func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schem
return chatResponse, nil
}

func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequest {
func (c *Client) createRequestSchema(request *schemas.ChatRequest) *ChatRequest {
// TODO: consider using objectpool to optimize memory allocation
chatRequest := c.chatRequestTemplate // hoping to get a copy of the template
chatRequest := *c.chatRequestTemplate // hoping to get a copy of the template
chatRequest.Message = request.Message.Content

// Build the Cohere specific ChatHistory
Expand All @@ -100,7 +68,7 @@ func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequ
}
}

return chatRequest
return &chatRequest
}

func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) {
Expand Down
20 changes: 10 additions & 10 deletions pkg/providers/cohere/chat_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func TestCohere_ChatStreamRequest(t *testing.T) {

for name, streamFile := range tests {
t.Run(name, func(t *testing.T) {
openAIMock := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cohereMock := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rawPayload, _ := io.ReadAll(r.Body)

var data interface{}
Expand All @@ -47,7 +47,7 @@ func TestCohere_ChatStreamRequest(t *testing.T) {

chatResponse, err := os.ReadFile(filepath.Clean(streamFile))
if err != nil {
t.Errorf("error reading openai chat mock response: %v", err)
t.Errorf("error reading cohere chat mock response: %v", err)
}

w.Header().Set("Content-Type", "text/event-stream")
Expand All @@ -58,14 +58,14 @@ func TestCohere_ChatStreamRequest(t *testing.T) {
}
})

openAIServer := httptest.NewServer(openAIMock)
defer openAIServer.Close()
cohereServer := httptest.NewServer(cohereMock)
defer cohereServer.Close()

ctx := context.Background()
providerCfg := DefaultConfig()
clientCfg := clients.DefaultClientConfig()

providerCfg.BaseURL = openAIServer.URL
providerCfg.BaseURL = cohereServer.URL

client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
require.NoError(t, err)
Expand Down Expand Up @@ -99,7 +99,7 @@ func TestCohere_ChatStreamRequestInterrupted(t *testing.T) {

for name, streamFile := range tests {
t.Run(name, func(t *testing.T) {
openAIMock := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cohereMock := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rawPayload, _ := io.ReadAll(r.Body)

var data interface{}
Expand All @@ -111,7 +111,7 @@ func TestCohere_ChatStreamRequestInterrupted(t *testing.T) {

chatResponse, err := os.ReadFile(filepath.Clean(streamFile))
if err != nil {
t.Errorf("error reading openai chat mock response: %v", err)
t.Errorf("error reading cohere chat mock response: %v", err)
}

w.Header().Set("Content-Type", "text/event-stream")
Expand All @@ -122,14 +122,14 @@ func TestCohere_ChatStreamRequestInterrupted(t *testing.T) {
}
})

openAIServer := httptest.NewServer(openAIMock)
defer openAIServer.Close()
cohereServer := httptest.NewServer(cohereMock)
defer cohereServer.Close()

ctx := context.Background()
providerCfg := DefaultConfig()
clientCfg := clients.DefaultClientConfig()

providerCfg.BaseURL = openAIServer.URL
providerCfg.BaseURL = cohereServer.URL

client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
require.NoError(t, err)
Expand Down
33 changes: 33 additions & 0 deletions pkg/providers/cohere/schemas.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,36 @@ type FinalResponse struct {
Meta Meta `json:"meta"`
FinishReason string `json:"finish_reason"`
}

// ChatMessage is a single chat message with an OpenAI-style role/content
// shape.
// NOTE(review): appears unused in the visible code — ChatRequest carries a
// plain Message string plus []ChatHistory instead; confirm before removing.
type ChatMessage struct {
// Role of the message author (e.g. "user", "assistant").
Role string `json:"role"`
// Content is the message text.
Content string `json:"content"`
}

// ChatHistory is one prior conversation turn, serialized into the
// "chat_history" array of a Cohere chat request (see ChatRequest).
type ChatHistory struct {
// Role of the turn's author.
Role string `json:"role"`
// Message is the turn's text.
Message string `json:"message"`
// User is omitted from the JSON payload when empty.
User string `json:"user,omitempty"`
}

// ChatRequest is the payload sent to the Cohere chat completion endpoint.
// Field tags mirror the Cohere API parameter names.
type ChatRequest struct {
	Model            string        `json:"model"`
	Message          string        `json:"message"`
	Temperature      float64       `json:"temperature,omitempty"`
	PreambleOverride string        `json:"preamble_override,omitempty"`
	ChatHistory      []ChatHistory `json:"chat_history,omitempty"`
	ConversationID   string        `json:"conversation_id,omitempty"`
	PromptTruncation string        `json:"prompt_truncation,omitempty"`
	// NOTE(review): the Cohere API documents connectors as objects (see the
	// Connectors type below); confirm a bare []string of IDs is accepted.
	Connectors        []string `json:"connectors,omitempty"`
	SearchQueriesOnly bool     `json:"search_queries_only,omitempty"`
	// CitiationQuality keeps its misspelled Go name so existing callers
	// still compile, but the wire tag is corrected: the Cohere API expects
	// "citation_quality", so the previous "citiation_quality" tag caused
	// the setting to be silently ignored by the server.
	CitiationQuality string `json:"citation_quality,omitempty"`
	Stream           bool   `json:"stream,omitempty"`
}

// Connectors describes a Cohere connector attachment for a chat request.
// NOTE(review): appears unused in the visible code — ChatRequest.Connectors
// is []string, not []Connectors; confirm which shape the API expects.
type Connectors struct {
ID string `json:"id"`
UserAccessToken string `json:"user_access_token"`
// NOTE(review): the name "continue_on_failure" suggests a boolean, but the
// field is typed string — confirm against the Cohere API before changing.
ContOnFail string `json:"continue_on_failure"`
Options map[string]string `json:"options"`
}

0 comments on commit 0a50c1a

Please sign in to comment.