diff --git a/go.mod b/go.mod index ae3ddf4d..baea63b3 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/prometheus/client_golang v1.19.1 github.com/r3labs/sse/v2 v2.10.0 github.com/rudderlabs/analytics-go v3.3.3+incompatible - github.com/sashabaranov/go-openai v1.25.0 + github.com/sashabaranov/go-openai v1.29.1 github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.8.4 golang.org/x/text v0.16.0 diff --git a/go.sum b/go.sum index 53200437..d8dab06d 100644 --- a/go.sum +++ b/go.sum @@ -229,6 +229,8 @@ github.com/rudderlabs/analytics-go v3.3.3+incompatible/go.mod h1:LF8/ty9kUX4PTY3 github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/sashabaranov/go-openai v1.25.0 h1:3h3DtJ55zQJqc+BR4y/iTcPhLk4pewJpyO+MXW2RdW0= github.com/sashabaranov/go-openai v1.25.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= +github.com/sashabaranov/go-openai v1.29.1 h1:AlB+vwpg1tibwr83OKXLsI4V1rnafVyTlw0BjR+6WUM= +github.com/sashabaranov/go-openai v1.29.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/segmentio/backo-go v1.0.1 h1:68RQccglxZeyURy93ASB/2kc9QudzgIDexJ927N++y4= github.com/segmentio/backo-go v1.0.1/go.mod h1:9/Rh6yILuLysoQnZ2oNooD2g7aBnvM7r/fNVxRNWfBc= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= diff --git a/server/ai/configuration.go b/server/ai/configuration.go index 3a33ffe4..32515ac6 100644 --- a/server/ai/configuration.go +++ b/server/ai/configuration.go @@ -27,7 +27,7 @@ func (c *BotConfig) IsValid() bool { isInvalid := c.Name == "" || c.DisplayName == "" || c.Service.Type == "" || - (c.Service.Type == "openaicompatible" && c.Service.APIURL == "") || - (c.Service.Type != "asksage" && c.Service.Type != "openaicompatible" && c.Service.APIKey == "") + ((c.Service.Type == "openaicompatible" || c.Service.Type == "azure") && c.Service.APIURL == "") || + (c.Service.Type != "asksage" && c.Service.Type != "openaicompatible" && c.Service.Type != "azure" && c.Service.APIKey == "") return !isInvalid } diff --git a/server/ai/openai/openai.go b/server/ai/openai/openai.go index ffffcdfd..30f2fbb9 100644 --- a/server/ai/openai/openai.go +++ b/server/ai/openai/openai.go @@ -10,7 +10,6 @@ import ( "image/png" "io" "net/http" - "net/url" "strings" "time" @@ -39,40 +38,49 @@ const OpenAIMaxImageSize = 20 * 1024 * 1024 // 20 MB var ErrStreamingTimeout = errors.New("timeout streaming") -func NewCompatible(llmService ai.ServiceConfig, httpClient *http.Client, metricsService metrics.LLMetrics) *OpenAI { - apiKey := llmService.APIKey - endpointURL := strings.TrimSuffix(llmService.APIURL, "/") - defaultModel := llmService.DefaultModel - config := openaiClient.DefaultConfig(apiKey) - config.BaseURL = endpointURL - config.HTTPClient = httpClient - - parsedURL, err := url.Parse(endpointURL) - if err == nil && strings.HasSuffix(parsedURL.Host, "openai.azure.com") { - config = openaiClient.DefaultAzureConfig(apiKey, endpointURL) - config.APIVersion = "2023-07-01-preview" - } +func NewAzure(llmService ai.ServiceConfig, httpClient *http.Client, metricsService metrics.LLMetrics) *OpenAI { + return newOpenAI(llmService, httpClient, metricsService, + func(apiKey string) openaiClient.ClientConfig { + config := openaiClient.DefaultAzureConfig(apiKey, strings.TrimSuffix(llmService.APIURL, "/")) + config.APIVersion = "2024-06-01" + return config + }, + ) +} - streamingTimeout := StreamingTimeoutDefault - if llmService.StreamingTimeoutSeconds > 0 { - streamingTimeout = time.Duration(llmService.StreamingTimeoutSeconds) * time.Second - } - return &OpenAI{ - client: openaiClient.NewClientWithConfig(config), - defaultModel: defaultModel, - tokenLimit: llmService.TokenLimit, - streamingTimeout: streamingTimeout, - metricsService: metricsService, - } +func NewCompatible(llmService ai.ServiceConfig, httpClient *http.Client, metricsService metrics.LLMetrics) *OpenAI { + return newOpenAI(llmService, httpClient, metricsService, + func(apiKey string) openaiClient.ClientConfig { + config := openaiClient.DefaultConfig(apiKey) + config.BaseURL = strings.TrimSuffix(llmService.APIURL, "/") + return config + }, + ) } func New(llmService ai.ServiceConfig, httpClient *http.Client, metricsService metrics.LLMetrics) *OpenAI { + return newOpenAI(llmService, httpClient, metricsService, + func(apiKey string) openaiClient.ClientConfig { + config := openaiClient.DefaultConfig(apiKey) + config.OrgID = llmService.OrgID + return config + }, + ) +} + +func newOpenAI( + llmService ai.ServiceConfig, + httpClient *http.Client, + metricsService metrics.LLMetrics, + baseConfigFunc func(apiKey string) openaiClient.ClientConfig, +) *OpenAI { + apiKey := llmService.APIKey defaultModel := llmService.DefaultModel if defaultModel == "" { defaultModel = openaiClient.GPT3Dot5Turbo } - config := openaiClient.DefaultConfig(llmService.APIKey) - config.OrgID = llmService.OrgID + + config := baseConfigFunc(apiKey) config.HTTPClient = httpClient streamingTimeout := StreamingTimeoutDefault diff --git a/server/plugin.go b/server/plugin.go index 0f40afc1..bfbc6b1b 100644 --- a/server/plugin.go +++ b/server/plugin.go @@ -154,6 +154,8 @@ func (p *Plugin) getLLM(llmBotConfig ai.BotConfig) ai.LanguageModel { llm = openai.New(llmBotConfig.Service, p.llmUpstreamHTTPClient, llmMetrics) case "openaicompatible": llm = openai.NewCompatible(llmBotConfig.Service, p.llmUpstreamHTTPClient, llmMetrics) + case "azure": + llm = openai.NewAzure(llmBotConfig.Service, p.llmUpstreamHTTPClient, llmMetrics) case "anthropic": llm = anthropic.New(llmBotConfig.Service, p.llmUpstreamHTTPClient, llmMetrics) case "asksage": @@ -185,6 +187,8 @@ func (p *Plugin) getTranscribe() ai.Transcriber { return openai.New(botConfig.Service, p.llmUpstreamHTTPClient, llmMetrics) case "openaicompatible": return openai.NewCompatible(botConfig.Service, p.llmUpstreamHTTPClient, llmMetrics) + case "azure": + return openai.NewAzure(botConfig.Service, p.llmUpstreamHTTPClient, llmMetrics) } return nil } diff --git a/webapp/src/components/system_console/bot.tsx b/webapp/src/components/system_console/bot.tsx index 7a76df0a..108c1e45 100644 --- a/webapp/src/components/system_console/bot.tsx +++ b/webapp/src/components/system_console/bot.tsx @@ -44,6 +44,7 @@ type Props = { const mapServiceTypeToDisplayName = new Map([ ['openai', 'OpenAI'], ['openaicompatible', 'OpenAI Compatible'], + ['azure', 'Azure'], ['anthropic', 'Anthropic'], ['asksage', 'Ask Sage'], ]); @@ -58,8 +59,8 @@ const Bot = (props: Props) => { const missingInfo = props.bot.name === '' || props.bot.displayName === '' || props.bot.service.type === '' || - (props.bot.service.type !== 'asksage' && props.bot.service.type !== 'openaicompatible' && props.bot.service.apiKey === '') || - (props.bot.service.type === 'openaicompatible' && props.bot.service.apiURL === ''); + (props.bot.service.type !== 'asksage' && props.bot.service.type !== 'openaicompatible' && props.bot.service.type !== 'azure' && props.bot.service.apiKey === '') || + ((props.bot.service.type === 'openaicompatible' || props.bot.service.type === 'azure') && props.bot.service.apiURL === ''); const invalidUsername = props.bot.name !== '' && (!(/^[a-z0-9.\-_]+$/).test(props.bot.name) || !(/[a-z]/).test(props.bot.name.charAt(0))); return ( @@ -121,6 +122,7 @@ const Bot = (props: Props) => { > {'OpenAI'} {'OpenAI Compatible'} + {'Azure'} {'Anthropic'} {'Ask Sage (Experimental)'} @@ -135,7 +137,7 @@ const Bot = (props: Props) => { value={props.bot.customInstructions} onChange={(e) => props.onChange({...props.bot, customInstructions: e.target.value})} /> - { (props.bot.service.type === 'openai' || props.bot.service.type === 'openaicompatible') && ( + { (props.bot.service.type === 'openai' || props.bot.service.type === 'openaicompatible' || props.bot.service.type === 'azure') && ( <> { const type = props.service.type; const intl = useIntl(); const hasAPIKey = type !== 'asksage'; - const isOpenAIType = type === 'openai' || type === 'openaicompatible'; + const isOpenAIType = type === 'openai' || type === 'openaicompatible' || type === 'azure'; return ( <> - {type === 'openaicompatible' && ( + {(type === 'openaicompatible' || type === 'azure') && (