Skip to content

Commit

Permalink
feat: support groq ai model (alibaba#967)
Browse files Browse the repository at this point in the history
  • Loading branch information
cr7258 authored May 17, 2024
1 parent 3119ec8 commit 31242c3
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 1 deletion.
85 changes: 85 additions & 0 deletions plugins/wasm-go/extensions/ai-proxy/provider/groq.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package provider

import (
"fmt"

"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)

// groqProvider is the provider for Groq service.
const (
groqDomain = "api.groq.com"
groqChatCompletionPath = "/openai/v1/chat/completions"
)

type groqProviderInitializer struct{}

func (m *groqProviderInitializer) ValidateConfig(config ProviderConfig) error {
return nil
}

func (m *groqProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
return &groqProvider{
config: config,
contextCache: createContextCache(&config),
}, nil
}

type groqProvider struct {
config ProviderConfig
contextCache *contextCache
}

func (m *groqProvider) GetProviderType() string {
return providerTypeGroq
}

func (m *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestPath(groqChatCompletionPath)
_ = util.OverwriteRequestHost(groqDomain)
_ = proxywasm.ReplaceHttpRequestHeader("Authorization", "Bearer "+m.config.GetRandomToken())

if m.contextCache == nil {
ctx.DontReadRequestBody()
} else {
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
}

return types.ActionContinue, nil
}

func (m *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
if m.contextCache == nil {
return types.ActionContinue, nil
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
}
2 changes: 1 addition & 1 deletion plugins/wasm-go/extensions/ai-proxy/provider/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)

// azureProvider is the provider for Azure OpenAI service.
// openaiProvider is the provider for OpenAI service.

const (
openaiDomain = "api.openai.com"
Expand Down
2 changes: 2 additions & 0 deletions plugins/wasm-go/extensions/ai-proxy/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ const (
providerTypeAzure = "azure"
providerTypeQwen = "qwen"
providerTypeOpenAI = "openai"
providerTypeGroq = "groq"

protocolOpenAI = "openai"
protocolOriginal = "original"
Expand Down Expand Up @@ -51,6 +52,7 @@ var (
providerTypeAzure: &azureProviderInitializer{},
providerTypeQwen: &qwenProviderInitializer{},
providerTypeOpenAI: &openaiProviderInitializer{},
providerTypeGroq: &groqProviderInitializer{},
}
)

Expand Down

0 comments on commit 31242c3

Please sign in to comment.