From 3ecb618fb8b08f89e351b041c8be5ea40421f335 Mon Sep 17 00:00:00 2001 From: enzowritescode <1328683+enzowritescode@users.noreply.github.com> Date: Wed, 21 Aug 2024 07:05:01 -0600 Subject: [PATCH] Anthropic request headers helper (#233) * Add helper function for preparing common request headers * Rename variable to not collide with imported package * Get rid of unnecessary variable colliding with imported package * Enhancements to Anthropic setRequestHeaders() * PR feedback --- server/ai/anthropic/client.go | 22 +++++++++++++--------- server/metrics.go | 10 +++++----- server/plugin.go | 16 ++++++++-------- 3 files changed, 26 insertions(+), 22 deletions(-) diff --git a/server/ai/anthropic/client.go b/server/ai/anthropic/client.go index 307fa5b0..2ef3eb63 100644 --- a/server/ai/anthropic/client.go +++ b/server/ai/anthropic/client.go @@ -15,7 +15,6 @@ import ( const ( MessageEndpoint = "https://api.anthropic.com/v1/messages" - APIKeyHeader = "X-API-Key" //nolint:gosec StopReasonStopSequence = "stop_sequence" StopReasonMaxTokens = "max_tokens" @@ -97,9 +96,7 @@ func (c *Client) MessageCompletionNoStream(completionRequest MessageRequest) (st return "", fmt.Errorf("could not create request: %w", err) } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("X-API-Key", c.apiKey) - req.Header.Set("anthropic-version", "2023-06-01") + c.setRequestHeaders(req, false) resp, err := c.httpClient.Do(req) if err != nil { @@ -135,11 +132,7 @@ func (c *Client) MessageCompletion(completionRequest MessageRequest) (*ai.TextSt return nil, err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("X-API-Key", c.apiKey) - req.Header.Set("Accept", "text/event-stream") - req.Header.Set("Connection", "keep-alive") - req.Header.Set("anthropic-version", "2023-06-01") + c.setRequestHeaders(req, true) output := make(chan string) errChan := make(chan error) @@ -203,3 +196,14 @@ func (c *Client) MessageCompletion(completionRequest MessageRequest) (*ai.TextSt return &ai.TextStreamResult{Stream: output, Err: errChan}, nil } + +func (c *Client) setRequestHeaders(req *http.Request, isStreaming bool) { + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-API-Key", c.apiKey) + req.Header.Set("anthropic-version", "2023-06-01") + + if isStreaming { + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Connection", "keep-alive") + } +} diff --git a/server/metrics.go b/server/metrics.go index cd5c3e48..adbfaf99 100644 --- a/server/metrics.go +++ b/server/metrics.go @@ -19,12 +19,12 @@ func (p *Plugin) GetMetrics() metrics.Metrics { } func (p *Plugin) metricsMiddleware(c *gin.Context) { - metrics := p.GetMetrics() - if metrics == nil { + llmMetrics := p.GetMetrics() + if llmMetrics == nil { c.Next() return } - p.GetMetrics().IncrementHTTPRequests() + llmMetrics.IncrementHTTPRequests() now := time.Now() c.Next() @@ -34,9 +34,9 @@ func (p *Plugin) metricsMiddleware(c *gin.Context) { status := c.Writer.Status() if status < 200 || status > 299 { - p.GetMetrics().IncrementHTTPErrors() + llmMetrics.IncrementHTTPErrors() } endpoint := c.HandlerName() - p.GetMetrics().ObserveAPIEndpointDuration(endpoint, c.Request.Method, strconv.Itoa(status), elapsed) + llmMetrics.ObserveAPIEndpointDuration(endpoint, c.Request.Method, strconv.Itoa(status), elapsed) } diff --git a/server/plugin.go b/server/plugin.go index 3dbbd6d3..04a621af 100644 --- a/server/plugin.go +++ b/server/plugin.go @@ -139,18 +139,18 @@ func (p *Plugin) OnDeactivate() error { } func (p *Plugin) getLLM(llmBotConfig ai.BotConfig) ai.LanguageModel { - metrics := p.metricsService.GetMetricsForAIService(llmBotConfig.Name) + llmMetrics := p.metricsService.GetMetricsForAIService(llmBotConfig.Name) var llm ai.LanguageModel switch llmBotConfig.Service.Type { case "openai": - llm = openai.New(llmBotConfig.Service, metrics) + llm = openai.New(llmBotConfig.Service, llmMetrics) case "openaicompatible": - llm = openai.NewCompatible(llmBotConfig.Service, metrics) + llm = openai.NewCompatible(llmBotConfig.Service, llmMetrics) case "anthropic": - llm = anthropic.New(llmBotConfig.Service, metrics) + llm = anthropic.New(llmBotConfig.Service, llmMetrics) case "asksage": - llm = asksage.New(llmBotConfig.Service, metrics) + llm = asksage.New(llmBotConfig.Service, llmMetrics) } cfg := p.getConfiguration() @@ -172,12 +172,12 @@ func (p *Plugin) getTranscribe() ai.Transcriber { break } } - metrics := p.metricsService.GetMetricsForAIService(botConfig.Name) + llmMetrics := p.metricsService.GetMetricsForAIService(botConfig.Name) switch botConfig.Service.Type { case "openai": - return openai.New(botConfig.Service, metrics) + return openai.New(botConfig.Service, llmMetrics) case "openaicompatible": - return openai.NewCompatible(botConfig.Service, metrics) + return openai.NewCompatible(botConfig.Service, llmMetrics) } return nil }