diff --git a/CODEOWNERS b/CODEOWNERS index 3d36c596c3..c875a7968a 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -2,7 +2,8 @@ /envoy @gengleilei @johnlanni /istio @SpecialYang @johnlanni /pkg @SpecialYang @johnlanni @CH3CHO -/plugins @johnlanni @WeixinX @CH3CHO +/plugins @johnlanni @CH3CHO @rinfx +/plugins/wasm-go/extensions/ai-proxy @cr7258 @CH3CHO @rinfx /plugins/wasm-rust @007gzs @jizhuozhi /registry @NameHaibinZhang @2456868764 @johnlanni /test @Xunzhuo @2456868764 @CH3CHO diff --git a/README.md b/README.md index fd15371c6b..e27042f67c 100644 --- a/README.md +++ b/README.md @@ -6,9 +6,14 @@

AI Native API Gateway

+
+ [![Build Status](https://github.com/alibaba/higress/actions/workflows/build-and-test.yaml/badge.svg?branch=main)](https://github.com/alibaba/higress/actions) [![license](https://img.shields.io/github/license/alibaba/higress.svg)](https://www.apache.org/licenses/LICENSE-2.0.html) +alibaba%2Fhigress | Trendshift +
+ [**官网**](https://higress.cn/)   |   [**文档**](https://higress.cn/docs/latest/overview/what-is-higress/)   |   [**博客**](https://higress.cn/blog/)   | @@ -17,6 +22,7 @@   [**AI插件**](https://higress.cn/plugin/)   +

English | 中文 | 日本語

diff --git a/plugins/wasm-go/extensions/ai-cache/cache/provider.go b/plugins/wasm-go/extensions/ai-cache/cache/provider.go index 1238d21570..d68acd5099 100644 --- a/plugins/wasm-go/extensions/ai-cache/cache/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/cache/provider.go @@ -2,6 +2,7 @@ package cache import ( "errors" + "strings" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/tidwall/gjson" @@ -62,7 +63,12 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { c.serviceName = json.Get("serviceName").String() c.servicePort = int(json.Get("servicePort").Int()) if !json.Get("servicePort").Exists() { - c.servicePort = 6379 + if strings.HasSuffix(c.serviceName, ".static") { + // use default logic port which is 80 for static service + c.servicePort = 80 + } else { + c.servicePort = 6379 + } } c.serviceHost = json.Get("serviceHost").String() c.username = json.Get("username").String() diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go index 19a9b2b856..b46fd28e8e 100644 --- a/plugins/wasm-go/extensions/ai-cache/core.go +++ b/plugins/wasm-go/extensions/ai-cache/core.go @@ -74,6 +74,9 @@ func processCacheHit(key string, response string, stream bool, ctx wrapper.HttpC ctx.SetContext(CACHE_KEY_CONTEXT_KEY, nil) + ctx.SetUserAttribute("cache_status", "hit") + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + if stream { proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(c.StreamResponseTemplate, escapedResponse)), -1) } else { diff --git a/plugins/wasm-go/extensions/ai-cache/go.mod b/plugins/wasm-go/extensions/ai-cache/go.mod index e4aae265e0..56bea605f4 100644 --- a/plugins/wasm-go/extensions/ai-cache/go.mod +++ b/plugins/wasm-go/extensions/ai-cache/go.mod @@ -8,14 +8,14 @@ replace github.com/alibaba/higress/plugins/wasm-go => ../.. require ( github.com/alibaba/higress/plugins/wasm-go v1.4.2 - github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f + github.com/google/uuid v1.6.0 + github.com/higress-group/proxy-wasm-go-sdk v1.0.0 github.com/tidwall/gjson v1.17.3 github.com/tidwall/resp v0.1.1 // github.com/weaviate/weaviate-go-client/v4 v4.15.1 ) require ( - github.com/google/uuid v1.6.0 // indirect github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect github.com/magefile/mage v1.14.0 // indirect github.com/stretchr/testify v1.9.0 // indirect diff --git a/plugins/wasm-go/extensions/ai-cache/go.sum b/plugins/wasm-go/extensions/ai-cache/go.sum index 7ada0c8b70..0a3635868b 100644 --- a/plugins/wasm-go/extensions/ai-cache/go.sum +++ b/plugins/wasm-go/extensions/ai-cache/go.sum @@ -3,8 +3,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= +github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU= +github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0= github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/plugins/wasm-go/extensions/ai-cache/main.go b/plugins/wasm-go/extensions/ai-cache/main.go index 1aca29f0ec..62edb80dcb 100644 --- a/plugins/wasm-go/extensions/ai-cache/main.go +++ b/plugins/wasm-go/extensions/ai-cache/main.go @@ -128,9 +128,15 @@ func onHttpRequestBody(ctx wrapper.HttpContext, c config.PluginConfig, body []by func onHttpResponseHeaders(ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log) types.Action { skipCache := ctx.GetContext(SKIP_CACHE_HEADER) if skipCache != nil { + ctx.SetUserAttribute("cache_status", "skip") + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) ctx.DontReadResponseBody() return types.ActionContinue } + if ctx.GetContext(CACHE_KEY_CONTEXT_KEY) != nil { + ctx.SetUserAttribute("cache_status", "miss") + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + } contentType, _ := proxywasm.GetHttpResponseHeader("content-type") if strings.Contains(contentType, "text/event-stream") { ctx.SetContext(STREAM_CONTEXT_KEY, struct{}{}) diff --git a/plugins/wasm-go/extensions/ai-cache/util.go b/plugins/wasm-go/extensions/ai-cache/util.go index 983dfbb25a..7fbd4954e2 100644 --- a/plugins/wasm-go/extensions/ai-cache/util.go +++ b/plugins/wasm-go/extensions/ai-cache/util.go @@ -101,55 +101,58 @@ func processStreamLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chun } func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessage string, log wrapper.Log) (string, error) { - subMessages := strings.Split(sseMessage, "\n") - var message string - for _, msg := range subMessages { - if strings.HasPrefix(msg, "data:") { - message = msg - break + content := "" + for _, chunk := range strings.Split(sseMessage, "\n\n") { + log.Infof("chunk _ : %s", chunk) + subMessages := strings.Split(chunk, "\n") + var message string + for _, msg := range subMessages { + if strings.HasPrefix(msg, "data:") { + message = msg + break + } + } + if len(message) < 6 { + return content, fmt.Errorf("[processSSEMessage] invalid message: %s", message) } - } - if len(message) < 6 { - return "", fmt.Errorf("[processSSEMessage] invalid message: %s", message) - } - // skip the prefix "data:" - bodyJson := message[5:] + // skip the prefix "data:" + bodyJson := message[5:] - if strings.TrimSpace(bodyJson) == "[DONE]" { - return "", nil - } + if strings.TrimSpace(bodyJson) == "[DONE]" { + return content, nil + } - // Extract values from JSON fields - responseBody := gjson.Get(bodyJson, c.CacheStreamValueFrom) - toolCalls := gjson.Get(bodyJson, c.CacheToolCallsFrom) + // Extract values from JSON fields + responseBody := gjson.Get(bodyJson, c.CacheStreamValueFrom) + toolCalls := gjson.Get(bodyJson, c.CacheToolCallsFrom) - if toolCalls.Exists() { - // TODO: Temporarily store the tool_calls value in the context for processing - ctx.SetContext(TOOL_CALLS_CONTEXT_KEY, toolCalls.String()) - } - - // Check if the ResponseBody field exists - if !responseBody.Exists() { - if ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) != nil { - log.Debugf("[processSSEMessage] unable to extract content from message; cache content is not nil: %s", message) - return "", nil + if toolCalls.Exists() { + // TODO: Temporarily store the tool_calls value in the context for processing + ctx.SetContext(TOOL_CALLS_CONTEXT_KEY, toolCalls.String()) } - return "", fmt.Errorf("[processSSEMessage] unable to extract content from message; cache content is nil: %s", message) - } else { - tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) - // If there is no content in the cache, initialize and set the content - if tempContentI == nil { - content := responseBody.String() - ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content) - return content, nil - } + // Check if the ResponseBody field exists + if !responseBody.Exists() { + if ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) != nil { + log.Debugf("[processSSEMessage] unable to extract content from message; cache content is not nil: %s", message) + return content, nil + } + return content, fmt.Errorf("[processSSEMessage] unable to extract content from message; cache content is nil: %s", message) + } else { + tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) - // Update the content in the cache - appendMsg := responseBody.String() - content := tempContentI.(string) + appendMsg - ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content) - return content, nil + // If there is no content in the cache, initialize and set the content + if tempContentI == nil { + content = responseBody.String() + ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content) + } else { + // Update the content in the cache + appendMsg := responseBody.String() + content = tempContentI.(string) + appendMsg + ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content) + } + } } + return content, nil } diff --git a/plugins/wasm-go/extensions/ai-history/go.sum b/plugins/wasm-go/extensions/ai-history/go.sum index 6b1c2c3cd7..b4ab172fe2 100644 --- a/plugins/wasm-go/extensions/ai-history/go.sum +++ b/plugins/wasm-go/extensions/ai-history/go.sum @@ -3,15 +3,13 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= +github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU= github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0= github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= -github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94= github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= diff --git a/plugins/wasm-go/extensions/ai-history/main.go b/plugins/wasm-go/extensions/ai-history/main.go index 512e13f1c6..3f728dd96d 100644 --- a/plugins/wasm-go/extensions/ai-history/main.go +++ b/plugins/wasm-go/extensions/ai-history/main.go @@ -194,6 +194,12 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte ctx.SetContext(StreamContextKey, struct{}{}) } identityKey := ctx.GetStringContext(IdentityKey, "") + question := TrimQuote(bodyJson.Get(config.QuestionFrom.RequestBody).String()) + if question == "" { + log.Debug("parse question from request body failed") + return types.ActionContinue + } + ctx.SetContext(QuestionContextKey, question) err := config.redisClient.Get(config.CacheKeyPrefix+identityKey, func(response resp.Value) { if err := response.Error(); err != nil { log.Errorf("redis get failed, err:%v", err) @@ -230,13 +236,6 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte _ = proxywasm.SendHttpResponseWithDetail(200, "OK", [][2]string{{"content-type", "application/json; charset=utf-8"}}, res, -1) return } - question := TrimQuote(bodyJson.Get(config.QuestionFrom.RequestBody).String()) - if question == "" { - log.Debug("parse question from request body failed") - _ = proxywasm.ResumeHttpRequest() - return - } - ctx.SetContext(QuestionContextKey, question) fillHistoryCnt := getIntQueryParameter("fill_history_cnt", path, config.FillHistoryCnt) * 2 currJson := bodyJson.Get("messages").String() var currMessage []ChatHistory @@ -317,38 +316,39 @@ func getIntQueryParameter(name string, path string, defaultValue int) int { } func processSSEMessage(ctx wrapper.HttpContext, config PluginConfig, sseMessage string, log wrapper.Log) string { - subMessages := strings.Split(sseMessage, "\n") - var message string - for _, msg := range subMessages { - if strings.HasPrefix(msg, "data:") { - message = msg - break + content := "" + for _, chunk := range strings.Split(sseMessage, "\n\n") { + subMessages := strings.Split(chunk, "\n") + var message string + for _, msg := range subMessages { + if strings.HasPrefix(msg, "data:") { + message = msg + break + } } - } - if len(message) < 6 { - log.Errorf("invalid message:%s", message) - return "" - } - // skip the prefix "data:" - bodyJson := message[5:] - if gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Exists() { - tempContentI := ctx.GetContext(AnswerContentContextKey) - if tempContentI == nil { - content := TrimQuote(gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Raw) - ctx.SetContext(AnswerContentContextKey, content) + if len(message) < 6 { + log.Errorf("invalid message:%s", message) return content } - append := TrimQuote(gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Raw) - content := tempContentI.(string) + append - ctx.SetContext(AnswerContentContextKey, content) - return content - } else if gjson.Get(bodyJson, "choices.0.delta.content.tool_calls").Exists() { - // TODO: compatible with other providers - ctx.SetContext(ToolCallsContextKey, struct{}{}) - return "" + // skip the prefix "data:" + bodyJson := message[5:] + if gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Exists() { + tempContentI := ctx.GetContext(AnswerContentContextKey) + if tempContentI == nil { + content = TrimQuote(gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Raw) + ctx.SetContext(AnswerContentContextKey, content) + } else { + append := TrimQuote(gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Raw) + content = tempContentI.(string) + append + ctx.SetContext(AnswerContentContextKey, content) + } + } else if gjson.Get(bodyJson, "choices.0.delta.content.tool_calls").Exists() { + // TODO: compatible with other providers + ctx.SetContext(ToolCallsContextKey, struct{}{}) + } + log.Debugf("unknown message:%s", bodyJson) } - log.Debugf("unknown message:%s", bodyJson) - return "" + return content } func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action { diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md index 8317f653d4..80b7c2a890 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README.md +++ b/plugins/wasm-go/extensions/ai-proxy/README.md @@ -174,9 +174,10 @@ Mistral 所对应的 `type` 为 `mistral`。它并无特有的配置字段。 MiniMax所对应的 `type` 为 `minimax`。它特有的配置字段如下: -| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | -| ---------------- | -------- | ------------------------------------------------------------ | ------ | ------------------------------------------------------------ | -| `minimaxGroupId` | string | 当使用`abab6.5-chat`, `abab6.5s-chat`, `abab5.5s-chat`, `abab5.5-chat`四种模型时必填 | - | 当使用`abab6.5-chat`, `abab6.5s-chat`, `abab5.5s-chat`, `abab5.5-chat`四种模型时会使用ChatCompletion Pro,需要设置groupID | +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +| ---------------- | -------- | ------------------------------ | ------ |----------------------------------------------------------------| +| `minimaxApiType` | string | v2 和 pro 中选填一项 | v2 | v2 代表 ChatCompletion v2 API,pro 代表 ChatCompletion Pro API | +| `minimaxGroupId` | string | `minimaxApiType` 为 pro 时必填 | - | `minimaxApiType` 为 pro 时使用 ChatCompletion Pro API,需要设置 groupID | #### Anthropic Claude @@ -1000,17 +1001,16 @@ provider: apiTokens: - "YOUR_MINIMAX_API_TOKEN" modelMapping: - "gpt-3": "abab6.5g-chat" - "gpt-4": "abab6.5-chat" - "*": "abab6.5g-chat" - minimaxGroupId: "YOUR_MINIMAX_GROUP_ID" + "gpt-3": "abab6.5s-chat" + "gpt-4": "abab6.5g-chat" + "*": "abab6.5t-chat" ``` **请求示例** ```json { - "model": "gpt-4-turbo", + "model": "gpt-3", "messages": [ { "role": "user", @@ -1025,27 +1025,33 @@ provider: ```json { - "id": "02b2251f8c6c09d68c1743f07c72afd7", + "id": "03ac4fcfe1c6cc9c6a60f9d12046e2b4", "choices": [ { "finish_reason": "stop", "index": 0, "message": { - "content": "你好!我是MM智能助理,一款由MiniMax自研的大型语言模型。我可以帮助你解答问题,提供信息,进行对话等。有什么可以帮助你的吗?", - "role": "assistant" + "content": "你好,我是一个由MiniMax公司研发的大型语言模型,名为MM智能助理。我可以帮助回答问题、提供信息、进行对话和执行多种语言处理任务。如果你有任何问题或需要帮助,请随时告诉我!", + "role": "assistant", + "name": "MM智能助理", + "audio_content": "" } } ], - "created": 1717760544, + "created": 1734155471, "model": "abab6.5s-chat", "object": "chat.completion", "usage": { - "total_tokens": 106 + "total_tokens": 116, + "total_characters": 0, + "prompt_tokens": 70, + "completion_tokens": 46 }, "input_sensitive": false, "output_sensitive": false, "input_sensitive_type": 0, "output_sensitive_type": 0, + "output_sensitive_int": 0, "base_resp": { "status_code": 0, "status_msg": "" diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 0bc62175e2..3f4dc49bab 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -89,29 +89,35 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf } if apiName == "" { - log.Debugf("[onHttpRequestHeader] unsupported path: %s", path.Path) - // _ = util.SendResponse(404, "ai-proxy.unknown_api", util.MimeTypeTextPlain, "API not found: "+path.Path) - log.Debugf("[onHttpRequestHeader] no send response") + log.Warnf("[onHttpRequestHeader] unsupported path: %s", path.Path) return types.ActionContinue } + // Disable the route re-calculation since the plugin may modify some headers related to the chosen route. + ctx.DisableReroute() + ctx.SetContext(ctxKeyApiName, apiName) + _, needHandleBody := activeProvider.(provider.ResponseBodyHandler) + _, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler) + if needHandleBody || needHandleStreamingBody { + proxywasm.RemoveHttpRequestHeader("Accept-Encoding") + } + if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok { - // Disable the route re-calculation since the plugin may modify some headers related to the chosen route. - ctx.DisableReroute() // Set the apiToken for the current request. providerConfig.SetApiTokenInUse(ctx, log) hasRequestBody := wrapper.HasRequestBody() - action, err := handler.OnRequestHeaders(ctx, apiName, log) + err := handler.OnRequestHeaders(ctx, apiName, log) if err == nil { if hasRequestBody { + proxywasm.RemoveHttpRequestHeader("Content-Length") ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes) - // Always return types.HeaderStopIteration to support fallback routing, - // as long as onHttpRequestBody can be called. + // Delay the header processing to allow changing in OnRequestBody return types.HeaderStopIteration } - return action + ctx.DontReadRequestBody() + return types.ActionContinue } util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err)) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go index 6f42d570d0..fa5f1362c1 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go @@ -40,13 +40,13 @@ func (m *ai360Provider) GetProviderType() string { return providerTypeAi360 } -func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody - return types.HeaderStopIteration, nil + return nil } func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { @@ -58,7 +58,5 @@ func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, func (m *ai360Provider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { util.OverwriteRequestHostHeader(headers, ai360Domain) - util.OverwriteRequestAuthorizationHeader(headers, "Authorization "+m.config.GetApiTokenInUse(ctx)) - headers.Del("Accept-Encoding") - headers.Del("Content-Length") + util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx)) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go index b09cdd0951..9e02d0fd9a 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go @@ -53,12 +53,12 @@ func (m *azureProvider) GetProviderType() string { return providerTypeAzure } -func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { @@ -86,6 +86,6 @@ func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName util.OverwriteRequestPathHeader(headers, m.serviceUrl.RequestURI()) } util.OverwriteRequestHostHeader(headers, m.serviceUrl.Host) - util.OverwriteRequestAuthorizationHeader(headers, "api-key "+m.config.GetApiTokenInUse(ctx)) + headers.Set("api-key", m.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go index b43ba8ee26..759c2dd036 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go @@ -42,12 +42,12 @@ func (m *baichuanProvider) GetProviderType() string { return providerTypeBaichuan } -func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go index 0908836290..595ef3d4ff 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go @@ -63,12 +63,12 @@ func (g *baiduProvider) GetProviderType() string { return providerTypeBaidu } -func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } g.config.handleRequestHeaders(g, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go index 8b98d62d64..9943469749 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go @@ -102,27 +102,25 @@ func (c *claudeProvider) GetProviderType() string { return providerTypeClaude } -func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } c.config.handleRequestHeaders(c, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { util.OverwriteRequestPathHeader(headers, claudeChatCompletionPath) util.OverwriteRequestHostHeader(headers, claudeDomain) - headers.Add("x-api-key", c.config.GetApiTokenInUse(ctx)) + headers.Set("x-api-key", c.config.GetApiTokenInUse(ctx)) if c.config.claudeVersion == "" { c.config.claudeVersion = defaultVersion } - headers.Add("anthropic-version", c.config.claudeVersion) - headers.Del("Accept-Encoding") - headers.Del("Content-Length") + headers.Set("anthropic-version", c.config.claudeVersion) } func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go index 2f6108b0df..4340183ee4 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go @@ -42,12 +42,12 @@ func (c *cloudflareProvider) GetProviderType() string { return providerTypeCloudflare } -func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } c.config.handleRequestHeaders(c, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { @@ -61,6 +61,4 @@ func (c *cloudflareProvider) TransformRequestHeaders(ctx wrapper.HttpContext, ap util.OverwriteRequestPathHeader(headers, strings.Replace(cloudflareChatCompletionPath, "{account_id}", c.config.cloudflareAccountId, 1)) util.OverwriteRequestHostHeader(headers, cloudflareDomain) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+c.config.GetApiTokenInUse(ctx)) - headers.Del("Accept-Encoding") - headers.Del("Content-Length") } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go index 72dbaf280b..a3b930e7fb 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go @@ -3,11 +3,12 @@ package provider import ( "encoding/json" "errors" + "net/http" + "strings" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" - "net/http" - "strings" ) const ( @@ -54,12 +55,12 @@ func (m *cohereProvider) GetProviderType() string { return providerTypeCohere } -func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (m *cohereProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/coze.go b/plugins/wasm-go/extensions/ai-proxy/provider/coze.go index 878bbb9f9a..43cdca60fb 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/coze.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/coze.go @@ -6,7 +6,6 @@ import ( "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) const ( @@ -38,9 +37,9 @@ func (m *cozeProvider) GetProviderType() string { return providerTypeCoze } -func (m *cozeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *cozeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { m.config.handleRequestHeaders(m, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (m *cozeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go index bafe6b3dde..345a70c94a 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go @@ -76,19 +76,17 @@ func (d *deeplProvider) GetProviderType() string { return providerTypeDeepl } -func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } d.config.handleRequestHeaders(d, ctx, apiName, log) - return types.HeaderStopIteration, nil + return nil } func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { util.OverwriteRequestPathHeader(headers, deeplChatCompletionPath) util.OverwriteRequestAuthorizationHeader(headers, "DeepL-Auth-Key "+d.config.GetApiTokenInUse(ctx)) - headers.Del("Content-Length") - headers.Del("Accept-Encoding") } func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go index 9cad3928f5..7d240f09ae 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go @@ -2,10 +2,11 @@ package provider import ( "errors" + "net/http" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" - "net/http" ) // deepseekProvider is the provider for deepseek Ai service. @@ -41,12 +42,12 @@ func (m *deepseekProvider) GetProviderType() string { return providerTypeDeepSeek } -func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (m *deepseekProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go index 651b983206..96a4aab548 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go @@ -2,11 +2,12 @@ package provider import ( "errors" + "net/http" + "strings" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" - "net/http" - "strings" ) const ( @@ -39,12 +40,12 @@ func (m *doubaoProvider) GetProviderType() string { return providerTypeDoubao } -func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go index a4c1ef2cd9..7a9b0a3dd0 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go @@ -51,20 +51,18 @@ func (g *geminiProvider) GetProviderType() string { return providerTypeGemini } -func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } g.config.handleRequestHeaders(g, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody - return types.HeaderStopIteration, nil + return nil } func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { util.OverwriteRequestHostHeader(headers, geminiDomain) - headers.Add(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx)) - headers.Del("Accept-Encoding") - headers.Del("Content-Length") + headers.Set(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx)) } func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/github.go b/plugins/wasm-go/extensions/ai-proxy/provider/github.go index 0a2b0c84de..348134c0a5 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/github.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/github.go @@ -2,11 +2,12 @@ package provider import ( "errors" + "net/http" + "strings" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" - "net/http" - "strings" ) // githubProvider is the provider for GitHub OpenAI service. @@ -42,13 +43,13 @@ func (m *githubProvider) GetProviderType() string { return providerTypeGithub } -func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody - return types.HeaderStopIteration, nil + return nil } func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { @@ -67,8 +68,6 @@ func (m *githubProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam util.OverwriteRequestPathHeader(headers, githubEmbeddingPath) } util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx)) - headers.Del("Accept-Encoding") - headers.Del("Content-Length") } func (m *githubProvider) GetApiName(path string) ApiName { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go index dfbd971261..5f2734519d 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go @@ -41,12 +41,12 @@ func (g *groqProvider) GetProviderType() string { return providerTypeGroq } -func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } g.config.handleRequestHeaders(g, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (g *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go index b6a49eb551..4b10a4d7c5 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go @@ -114,13 +114,13 @@ func (m *hunyuanProvider) GetProviderType() string { return providerTypeHunyuan } -func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody - return types.HeaderStopIteration, nil + return nil } func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { @@ -128,11 +128,8 @@ func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa util.OverwriteRequestPathHeader(headers, hunyuanRequestPath) // 添加 hunyuan 需要的自定义字段 - headers.Add(actionKey, hunyuanChatCompletionTCAction) - headers.Add(versionKey, versionValue) - - headers.Del("Accept-Encoding") - headers.Del("Content-Length") + headers.Set(actionKey, hunyuanChatCompletionTCAction) + headers.Set(versionKey, versionValue) } // hunyuan 的 OnRequestBody 逻辑中包含了对 headers 签名的逻辑,并且插入 context 以后还要重新计算签名,因此无法复用 handleRequestBody 方法 diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go index 0bcf7ac326..9531edcf11 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go @@ -11,47 +11,37 @@ import ( "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) // minimaxProvider is the provider for minimax service. const ( - minimaxDomain = "api.minimax.chat" - // minimaxChatCompletionV2Path 接口请求响应格式与OpenAI相同 - // 接口文档: https://platform.minimaxi.com/document/guides/chat-model/V2?id=65e0736ab2845de20908e2dd + minimaxApiTypeV2 = "v2" // minimaxApiTypeV2 represents chat completion V2 API. + minimaxApiTypePro = "pro" // minimaxApiTypePro represents chat completion Pro API. + minimaxDomain = "api.minimax.chat" + // minimaxChatCompletionV2Path represents the API path for chat completion V2 API which has a response format similar to OpenAI's. minimaxChatCompletionV2Path = "/v1/text/chatcompletion_v2" - // minimaxChatCompletionProPath 接口请求响应格式与OpenAI不同 - // 接口文档: https://platform.minimaxi.com/document/guides/chat-model/pro/api?id=6569c85948bc7b684b30377e + // minimaxChatCompletionProPath represents the API path for chat completion Pro API which has a different response format from OpenAI's. minimaxChatCompletionProPath = "/v1/text/chatcompletion_pro" - senderTypeUser string = "USER" // 用户发送的内容 - senderTypeBot string = "BOT" // 模型生成的内容 + senderTypeUser string = "USER" // Content sent by the user. + senderTypeBot string = "BOT" // Content generated by the model. - // 默认机器人设置 + // Default bot settings. defaultBotName string = "MM智能助理" defaultBotSettingContent string = "MM智能助理是一款由MiniMax自研的,没有调用其他产品的接口的大型语言模型。MiniMax是一家中国科技公司,一直致力于进行大模型相关的研究。" defaultSenderName string = "小明" ) -// chatCompletionProModels 这些模型对应接口为ChatCompletion Pro -var chatCompletionProModels = map[string]struct{}{ - "abab6.5-chat": {}, - "abab6.5s-chat": {}, - "abab5.5s-chat": {}, - "abab5.5-chat": {}, -} - type minimaxProviderInitializer struct { } func (m *minimaxProviderInitializer) ValidateConfig(config ProviderConfig) error { - // 如果存在模型对应接口为ChatCompletion Pro必须配置minimaxGroupId - if len(config.modelMapping) > 0 && config.minimaxGroupId == "" { - for _, minimaxModel := range config.modelMapping { - if _, exists := chatCompletionProModels[minimaxModel]; exists { - return errors.New(fmt.Sprintf("missing minimaxGroupId in provider config when %s model is provided", minimaxModel)) - } - } + // If using the chat completion Pro API, a group ID must be set. + if minimaxApiTypePro == config.minimaxApiType && config.minimaxGroupId == "" { + return errors.New(fmt.Sprintf("missing minimaxGroupId in provider config when minimaxApiType is %s", minimaxApiTypePro)) } if config.apiTokens == nil || len(config.apiTokens) == 0 { return errors.New("no apiToken found in provider config") @@ -75,13 +65,13 @@ func (m *minimaxProvider) GetProviderType() string { return providerTypeMinimax } -func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody - return types.HeaderStopIteration, nil + return nil } func (m *minimaxProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { @@ -94,23 +84,11 @@ func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - // 解析并映射模型,设置上下文 - model, err := m.parseModel(body) - if err != nil { - return types.ActionContinue, err - } - ctx.SetContext(ctxKeyOriginalRequestModel, model) - mappedModel := getMappedModel(model, m.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") - } - ctx.SetContext(ctxKeyFinalRequestModel, mappedModel) - _, ok := chatCompletionProModels[mappedModel] - if ok { - // 使用ChatCompletion Pro接口 + if minimaxApiTypePro == m.config.minimaxApiType { + // Use chat completion Pro API. return m.handleRequestBodyByChatCompletionPro(body, log) } else { - // 使用ChatCompletion v2接口 + // Use chat completion V2 API. return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } } @@ -119,14 +97,14 @@ func (m *minimaxProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, a return m.handleRequestBodyByChatCompletionV2(body, headers, log) } -// handleRequestBodyByChatCompletionPro 使用ChatCompletion Pro接口处理请求体 +// handleRequestBodyByChatCompletionPro processes the request body using the chat completion Pro API. func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log wrapper.Log) (types.Action, error) { request := &chatCompletionRequest{} if err := decodeChatCompletionRequest(body, request); err != nil { return types.ActionContinue, err } - // 映射模型重写requestPath + // Map the model and rewrite the request path. request.Model = getMappedModel(request.Model, m.config.modelMapping, log) _ = util.OverwriteRequestPath(fmt.Sprintf("%s?GroupId=%s", minimaxChatCompletionProPath, m.config.minimaxGroupId)) @@ -143,9 +121,9 @@ func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log log.Errorf("failed to load context file: %v", err) util.ErrorHandler("ai-proxy.minimax.load_ctx_failed", fmt.Errorf("failed to load context file: %v", err)) } - // 由于 minimaxChatCompletionV2(格式和 OpenAI 一致)和 minimaxChatCompletionPro(格式和 OpenAI 不一致)中 insertHttpContextMessage 的逻辑不同,无法做到同一个 provider 统一 - // 因此对于 minimaxChatCompletionPro 需要手动处理 context 消息 - // minimaxChatCompletionV2 交给默认的 defaultInsertHttpContextMessage 方法插入 context 消息 + // Since minimaxChatCompletionV2 (format consistent with OpenAI) and minimaxChatCompletionPro (different format from OpenAI) have different logic for insertHttpContextMessage, we cannot unify them within one provider. + // For minimaxChatCompletionPro, we need to manually handle context messages. + // minimaxChatCompletionV2 uses the default defaultInsertHttpContextMessage method to insert context messages. minimaxRequest := m.buildMinimaxChatCompletionV2Request(request, content) if err := replaceJsonRequestBody(minimaxRequest, log); err != nil { util.ErrorHandler("ai-proxy.minimax.insert_ctx_failed", fmt.Errorf("failed to replace Request body: %v", err)) @@ -157,54 +135,45 @@ func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log return types.ActionContinue, err } -// handleRequestBodyByChatCompletionV2 使用ChatCompletion v2接口处理请求体 +// handleRequestBodyByChatCompletionV2 processes the request body using the chat completion V2 API. func (m *minimaxProvider) handleRequestBodyByChatCompletionV2(body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return nil, err - } - - // 映射模型重写requestPath - request.Model = getMappedModel(request.Model, m.config.modelMapping, log) util.OverwriteRequestPathHeader(headers, minimaxChatCompletionV2Path) - return body, nil + rawModel := gjson.GetBytes(body, "model").String() + mappedModel := getMappedModel(rawModel, m.config.modelMapping, log) + return sjson.SetBytes(body, "model", mappedModel) } func (m *minimaxProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { - // 使用minimax接口协议,跳过OnStreamingResponseBody()和OnResponseBody() + // Skip OnStreamingResponseBody() and OnResponseBody() when using original protocol. if m.config.protocol == protocolOriginal { ctx.DontReadResponseBody() return types.ActionContinue, nil } - // 模型对应接口为ChatCompletion v2,跳过OnStreamingResponseBody()和OnResponseBody() - model := ctx.GetStringContext(ctxKeyFinalRequestModel, "") - if model != "" { - _, ok := chatCompletionProModels[model] - if !ok { - ctx.DontReadResponseBody() - return types.ActionContinue, nil - } + // Skip OnStreamingResponseBody() and OnResponseBody() when the model corresponds to the chat completion V2 interface. + if minimaxApiTypePro != m.config.minimaxApiType { + ctx.DontReadResponseBody() + return types.ActionContinue, nil } _ = proxywasm.RemoveHttpResponseHeader("Content-Length") return types.ActionContinue, nil } -// OnStreamingResponseBody 只处理使用OpenAI协议 且 模型对应接口为ChatCompletion Pro的流式响应 +// OnStreamingResponseBody handles streaming response chunks from the Minimax service only for requests using the OpenAI protocol and corresponding to the chat completion Pro API. func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { if isLastChunk || len(chunk) == 0 { return nil, nil } - // sample event response: + // Sample event response: // data: {"created":1689747645,"model":"abab6.5s-chat","reply":"","choices":[{"messages":[{"sender_type":"BOT","sender_name":"MM智能助理","text":"am from China."}]}],"output_sensitive":false} - // sample end event response: + // Sample end event response: // data: {"created":1689747645,"model":"abab6.5s-chat","reply":"I am from China.","choices":[{"finish_reason":"stop","messages":[{"sender_type":"BOT","sender_name":"MM智能助理","text":"I am from China."}]}],"usage":{"total_tokens":187},"input_sensitive":false,"output_sensitive":false,"id":"0106b3bc9fd844a9f3de1aa06004e2ab","base_resp":{"status_code":0,"status_msg":""}} responseBuilder := &strings.Builder{} lines := strings.Split(string(chunk), "\n") for _, data := range lines { if len(data) < 6 { - // ignore blank line or wrong format + // Ignore blank line or improperly formatted lines. continue } data = data[6:] @@ -226,7 +195,7 @@ func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name return []byte(modifiedResponseChunk), nil } -// OnResponseBody 只处理使用OpenAI协议 且 模型对应接口为ChatCompletion Pro的流式响应 +// OnResponseBody handles the final response body from the Minimax service only for requests using the OpenAI protocol and corresponding to the chat completion Pro API. func (m *minimaxProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { minimaxResp := &minimaxChatCompletionV2Resp{} if err := json.Unmarshal(body, minimaxResp); err != nil { @@ -239,39 +208,39 @@ func (m *minimaxProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiNam return types.ActionContinue, replaceJsonResponseBody(response, log) } -// minimaxChatCompletionV2Request 表示ChatCompletion V2请求的结构体 +// minimaxChatCompletionV2Request represents the structure of a chat completion V2 request. type minimaxChatCompletionV2Request struct { Model string `json:"model"` Stream bool `json:"stream,omitempty"` TokensToGenerate int64 `json:"tokens_to_generate,omitempty"` Temperature float64 `json:"temperature,omitempty"` TopP float64 `json:"top_p,omitempty"` - MaskSensitiveInfo bool `json:"mask_sensitive_info"` // 是否开启隐私信息打码,默认true + MaskSensitiveInfo bool `json:"mask_sensitive_info"` // Whether to mask sensitive information, defaults to true. Messages []minimaxMessage `json:"messages"` BotSettings []minimaxBotSetting `json:"bot_setting"` ReplyConstraints minimaxReplyConstraints `json:"reply_constraints"` } -// minimaxMessage 表示对话中的消息 +// minimaxMessage represents a message in the conversation. type minimaxMessage struct { SenderType string `json:"sender_type"` SenderName string `json:"sender_name"` Text string `json:"text"` } -// minimaxBotSetting 表示机器人的设置 +// minimaxBotSetting represents the bot's settings. type minimaxBotSetting struct { BotName string `json:"bot_name"` Content string `json:"content"` } -// minimaxReplyConstraints 表示模型回复要求 +// minimaxReplyConstraints represents requirements for model replies. type minimaxReplyConstraints struct { SenderType string `json:"sender_type"` SenderName string `json:"sender_name"` } -// minimaxChatCompletionV2Resp Minimax Chat Completion V2响应结构体 +// minimaxChatCompletionV2Resp represents the structure of a Minimax Chat Completion V2 response. type minimaxChatCompletionV2Resp struct { Created int64 `json:"created"` Model string `json:"model"` @@ -286,20 +255,20 @@ type minimaxChatCompletionV2Resp struct { BaseResp minimaxBaseResp `json:"base_resp"` } -// minimaxBaseResp 包含错误状态码和详情 +// minimaxBaseResp contains error status code and details. type minimaxBaseResp struct { StatusCode int64 `json:"status_code"` StatusMsg string `json:"status_msg"` } -// minimaxChoice 结果选项 +// minimaxChoice represents a result option. type minimaxChoice struct { Messages []minimaxMessage `json:"messages"` Index int64 `json:"index"` FinishReason string `json:"finish_reason"` } -// minimaxUsage 令牌使用情况 +// minimaxUsage represents token usage statistics. type minimaxUsage struct { TotalTokens int64 `json:"total_tokens"` } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go index 3e5323a60c..041665f9dd 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go @@ -2,10 +2,11 @@ package provider import ( "errors" + "net/http" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" - "net/http" ) const ( @@ -37,12 +38,12 @@ func (m *mistralProvider) GetProviderType() string { return providerTypeMistral } -func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (m *mistralProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go index 38d99ae0eb..733cc038b4 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go @@ -56,12 +56,12 @@ func (m *moonshotProvider) GetProviderType() string { return providerTypeMoonshot } -func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go index 5339083819..1bed639f33 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go @@ -3,10 +3,11 @@ package provider import ( "errors" "fmt" + "net/http" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" - "net/http" ) // ollamaProvider is the provider for Ollama service. @@ -48,12 +49,12 @@ func (m *ollamaProvider) GetProviderType() string { return providerTypeOllama } -func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (m *ollamaProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go index 60c835cd49..480fdda571 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go @@ -57,9 +57,9 @@ func (m *openaiProvider) GetProviderType() string { return providerTypeOpenAI } -func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { m.config.handleRequestHeaders(m, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 478f7a24b6..0f482732aa 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -118,7 +118,7 @@ type ApiNameHandler interface { } type RequestHeadersHandler interface { - OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) + OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error } type TransformRequestHeadersHandler interface { @@ -206,8 +206,11 @@ type ProviderConfig struct { // @Title zh-CN hunyuan api id for authorization // @Description zh-CN 仅适用于Hun Yuan AI服务鉴权 hunyuanAuthId string `required:"false" yaml:"hunyuanAuthId" json:"hunyuanAuthId"` + // @Title zh-CN minimax API type + // @Description zh-CN 仅适用于 minimax 服务。minimax API 类型,v2 和 pro 中选填一项,默认值为 v2 + minimaxApiType string `required:"false" yaml:"minimaxApiType" json:"minimaxApiType"` // @Title zh-CN minimax group id - // @Description zh-CN 仅适用于minimax使用ChatCompletion Pro接口的模型 + // @Description zh-CN 仅适用于 minimax 服务。minimax API 类型为 pro 时必填 minimaxGroupId string `required:"false" yaml:"minimaxGroupId" json:"minimaxGroupId"` // @Title zh-CN 模型名称映射表 // @Description zh-CN 用于将请求中的模型名称映射为目标AI服务商支持的模型名称。支持通过“*”来配置全局映射 @@ -303,6 +306,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { c.claudeVersion = json.Get("claudeVersion").String() c.hunyuanAuthId = json.Get("hunyuanAuthId").String() c.hunyuanAuthKey = json.Get("hunyuanAuthKey").String() + c.minimaxApiType = json.Get("minimaxApiType").String() c.minimaxGroupId = json.Get("minimaxGroupId").String() c.cloudflareAccountId = json.Get("cloudflareAccountId").String() if c.typ == providerTypeGemini { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go index a4a727724e..95fe28e4bd 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go @@ -27,6 +27,7 @@ const ( qwenChatCompletionPath = "/api/v1/services/aigc/text-generation/generation" qwenTextEmbeddingPath = "/api/v1/services/embeddings/text-embedding/text-embedding" qwenCompatiblePath = "/compatible-mode/v1/chat/completions" + qwenBailianPath = "/api/v1/apps" qwenMultimodalGenerationPath = "/api/v1/services/aigc/multimodal-generation/generation" qwenTopPMin = 0.000001 @@ -71,16 +72,14 @@ func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName } util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) - if m.config.qwenEnableCompatible { + if m.config.IsOriginal() { + } else if m.config.qwenEnableCompatible { util.OverwriteRequestPathHeader(headers, qwenCompatiblePath) } else if apiName == ApiNameChatCompletion { util.OverwriteRequestPathHeader(headers, qwenChatCompletionPath) } else if apiName == ApiNameEmbeddings { util.OverwriteRequestPathHeader(headers, qwenTextEmbeddingPath) } - - headers.Del("Accept-Encoding") - headers.Del("Content-Length") } func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { @@ -95,20 +94,19 @@ func (m *qwenProvider) GetProviderType() string { return providerTypeQwen } -func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) if m.config.protocol == protocolOriginal { ctx.DontReadRequestBody() - return types.ActionContinue, nil + return nil } - // Delay the header processing to allow changing streaming mode in OnRequestBody - return types.HeaderStopIteration, nil + return nil } func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { @@ -762,6 +760,7 @@ func (m *qwenProvider) GetApiName(path string) ApiName { switch { case strings.Contains(path, qwenChatCompletionPath), strings.Contains(path, qwenMultimodalGenerationPath), + strings.Contains(path, qwenBailianPath), strings.Contains(path, qwenCompatiblePath): return ApiNameChatCompletion case strings.Contains(path, qwenTextEmbeddingPath): diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go index c2e013643c..f44b9e3c0f 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go @@ -67,12 +67,12 @@ func (p *sparkProvider) GetProviderType() string { return providerTypeSpark } -func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } p.config.handleRequestHeaders(p, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { @@ -177,6 +177,4 @@ func (p *sparkProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName util.OverwriteRequestPathHeader(headers, sparkChatCompletionPath) util.OverwriteRequestHostHeader(headers, sparkHost) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+p.config.GetApiTokenInUse(ctx)) - headers.Del("Accept-Encoding") - headers.Del("Content-Length") } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go index 1ee01abe62..4f642c5f6c 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go @@ -2,10 +2,11 @@ package provider import ( "errors" + "net/http" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" - "net/http" ) const ( @@ -39,12 +40,12 @@ func (m *stepfunProvider) GetProviderType() string { return providerTypeStepfun } -func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (m *stepfunProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go index 7cb05a9388..e80148ca0c 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go @@ -40,12 +40,12 @@ func (m *yiProvider) GetProviderType() string { return providerTypeYi } -func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (m *yiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go index 40fbe4ef88..9c30adb10d 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go @@ -40,12 +40,12 @@ func (m *zhipuAiProvider) GetProviderType() string { return providerTypeZhipuAi } -func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error { if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName + return errUnsupportedApiName } m.config.handleRequestHeaders(m, ctx, apiName, log) - return types.ActionContinue, nil + return nil } func (m *zhipuAiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { diff --git a/plugins/wasm-go/extensions/ai-security-guard/README.md b/plugins/wasm-go/extensions/ai-security-guard/README.md index 68eeeae202..a005299da3 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/README.md +++ b/plugins/wasm-go/extensions/ai-security-guard/README.md @@ -31,6 +31,7 @@ description: 阿里云内容安全检测 | `denyMessage` | string | optional | openai格式的流式/非流式响应 | 指定内容非法时的响应内容 | | `protocol` | string | optional | openai | 协议格式,非openai协议填`original` | | `riskLevelBar` | string | optional | high | 拦截风险等级,取值为 max, high, medium, low | +| `timeout` | int | optional | 2000 | 调用内容安全服务时的超时时间 | 补充说明一下 `denyMessage`,对非法请求的处理逻辑为: - 如果配置了 `denyMessage`,返回内容为 `denyMessage` 配置内容,格式为openai格式的流式/非流式响应 diff --git a/plugins/wasm-go/extensions/ai-security-guard/main.go b/plugins/wasm-go/extensions/ai-security-guard/main.go index f4aee5632b..0e0a747fa1 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/main.go +++ b/plugins/wasm-go/extensions/ai-security-guard/main.go @@ -53,6 +53,7 @@ const ( DefaultStreamingResponseJsonPath = "choices.0.delta.content" DefaultDenyCode = 200 DefaultDenyMessage = "很抱歉,我无法回答您的问题" + DefaultTimeout = 2000 AliyunUserAgent = "CIPFrom/AIGateway" LengthLimit = 1800 @@ -100,6 +101,7 @@ type AISecurityConfig struct { denyMessage string protocolOriginal bool riskLevelBar string + timeout uint32 metrics map[string]proxywasm.MetricCounter } @@ -225,6 +227,11 @@ func parseConfig(json gjson.Result, config *AISecurityConfig, log wrapper.Log) e } else { config.riskLevelBar = HighRisk } + if obj := json.Get("timeout"); obj.Exists() { + config.timeout = uint32(obj.Int()) + } else { + config.timeout = DefaultTimeout + } config.client = wrapper.NewClusterClient(wrapper.FQDNCluster{ FQDN: serviceName, Port: servicePort, @@ -253,6 +260,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action { log.Debugf("checking request body...") + startTime := time.Now().UnixMilli() content := gjson.GetBytes(body, config.requestContentJsonPath).String() model := gjson.GetBytes(body, "model").String() ctx.SetContext("requestModel", model) @@ -279,6 +287,10 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] } if riskLevelToInt(response.Data.RiskLevel) < riskLevelToInt(config.riskLevelBar) { if contentIndex >= len(content) { + endTime := time.Now().UnixMilli() + ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) + ctx.SetUserAttribute("safecheck_status", "request pass") + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) proxywasm.ResumeHttpRequest() } else { singleCall() @@ -305,7 +317,9 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] } ctx.DontReadResponseBody() config.incrementCounter("ai_sec_request_deny", 1) - ctx.SetUserAttribute("safecheck_status", "request deny") + endTime := time.Now().UnixMilli() + ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) + ctx.SetUserAttribute("safecheck_status", "reqeust deny") if response.Data.Advice != nil { ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label) ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords) @@ -345,7 +359,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] reqParams.Add(k, v) } reqParams.Add("Signature", signature) - err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback) + err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback, config.timeout) if err != nil { log.Errorf("failed call the safe check service: %v", err) proxywasm.ResumeHttpRequest() @@ -364,45 +378,26 @@ func convertHeaders(hs [][2]string) map[string][]string { return ret } -// headers: map[string][]string -> [][2]string -func reconvertHeaders(hs map[string][]string) [][2]string { - var ret [][2]string - for k, vs := range hs { - for _, v := range vs { - ret = append(ret, [2]string{k, v}) - } - } - sort.SliceStable(ret, func(i, j int) bool { - return ret[i][0] < ret[j][0] - }) - return ret -} - func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action { if !config.checkResponse { log.Debugf("response checking is disabled") ctx.DontReadResponseBody() return types.ActionContinue } - headers, err := proxywasm.GetHttpResponseHeaders() - if err != nil { - log.Warnf("failed to get response headers: %v", err) - return types.ActionContinue - } - hdsMap := convertHeaders(headers) - if !strings.Contains(strings.Join(hdsMap[":status"], ";"), "200") { + statusCode, _ := proxywasm.GetHttpResponseHeader(":status") + if statusCode != "200" { log.Debugf("response is not 200, skip response body check") ctx.DontReadResponseBody() return types.ActionContinue } - ctx.SetContext("headers", hdsMap) return types.HeaderStopIteration } func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action { log.Debugf("checking response body...") - hdsMap := ctx.GetContext("headers").(map[string][]string) - isStreamingResponse := strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") + startTime := time.Now().UnixMilli() + contentType, _ := proxywasm.GetHttpResponseHeader("content-type") + isStreamingResponse := strings.Contains(contentType, "event-stream") model := ctx.GetStringContext("requestModel", "unknown") var content string if isStreamingResponse { @@ -433,6 +428,10 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ } if riskLevelToInt(response.Data.RiskLevel) < riskLevelToInt(config.riskLevelBar) { if contentIndex >= len(content) { + endTime := time.Now().UnixMilli() + ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime) + ctx.SetUserAttribute("safecheck_status", "response pass") + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) proxywasm.ResumeHttpResponse() } else { singleCall() @@ -458,6 +457,8 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) } config.incrementCounter("ai_sec_response_deny", 1) + endTime := time.Now().UnixMilli() + ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime) ctx.SetUserAttribute("safecheck_status", "response deny") if response.Data.Advice != nil { ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label) @@ -498,7 +499,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ reqParams.Add(k, v) } reqParams.Add("Signature", signature) - err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback) + err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback, config.timeout) if err != nil { log.Errorf("failed call the safe check service: %v", err) proxywasm.ResumeHttpResponse() diff --git a/plugins/wasm-go/extensions/ai-statistics/go.sum b/plugins/wasm-go/extensions/ai-statistics/go.sum index 6b1c2c3cd7..b4ab172fe2 100644 --- a/plugins/wasm-go/extensions/ai-statistics/go.sum +++ b/plugins/wasm-go/extensions/ai-statistics/go.sum @@ -3,15 +3,13 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= +github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU= github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0= github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= -github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94= github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= diff --git a/plugins/wasm-go/extensions/ai-statistics/main.go b/plugins/wasm-go/extensions/ai-statistics/main.go index 14fcc4d2ab..363f59194e 100644 --- a/plugins/wasm-go/extensions/ai-statistics/main.go +++ b/plugins/wasm-go/extensions/ai-statistics/main.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "strconv" "strings" "time" @@ -28,14 +27,15 @@ func main() { } const ( - // Trace span prefix - TracePrefix = "trace_span_tag." // Context consts StatisticsRequestStartTime = "ai-statistics-request-start-time" StatisticsFirstTokenTime = "ai-statistics-first-token-time" CtxGeneralAtrribute = "attributes" CtxLogAtrribute = "logAttributes" CtxStreamingBodyBuffer = "streamingBodyBuffer" + RouteName = "route" + ClusterName = "cluster" + APIName = "api" // Source Type FixedValue = "fixed_value" @@ -46,12 +46,14 @@ const ( ResponseBody = "response_body" // Inner metric & log attributes name - Model = "model" - InputToken = "input_token" - OutputToken = "output_token" - LLMFirstTokenDuration = "llm_first_token_duration" - LLMServiceDuration = "llm_service_duration" - LLMDurationCount = "llm_duration_count" + Model = "model" + InputToken = "input_token" + OutputToken = "output_token" + LLMFirstTokenDuration = "llm_first_token_duration" + LLMServiceDuration = "llm_service_duration" + LLMDurationCount = "llm_duration_count" + LLMStreamDurationCount = "llm_stream_duration_count" + ResponseType = "response_type" // Extract Rule RuleFirst = "first" @@ -91,6 +93,19 @@ func getRouteName() (string, error) { } } +func getAPIName() (string, error) { + if raw, err := proxywasm.GetProperty([]string{"route_name"}); err != nil { + return "-", err + } else { + parts := strings.Split(string(raw), "@") + if len(parts) != 5 { + return "-", errors.New("not api type") + } else { + return strings.Join(parts[:3], "@"), nil + } + } +} + func getClusterName() (string, error) { if raw, err := proxywasm.GetProperty([]string{"cluster_name"}); err != nil { return "-", err @@ -133,8 +148,15 @@ func parseConfig(configJson gjson.Result, config *AIStatisticsConfig, log wrappe } func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) types.Action { - ctx.SetContext(CtxGeneralAtrribute, map[string]string{}) - ctx.SetContext(CtxLogAtrribute, map[string]string{}) + route, _ := getRouteName() + cluster, _ := getClusterName() + api, api_error := getAPIName() + if api_error == nil { + route = api + } + ctx.SetContext(RouteName, route) + ctx.SetContext(ClusterName, cluster) + ctx.SetUserAttribute(APIName, api) ctx.SetContext(StatisticsRequestStartTime, time.Now().UnixMilli()) // Set user defined log & span attributes which type is fixed_value @@ -149,6 +171,9 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, lo func onHttpRequestBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log wrapper.Log) types.Action { // Set user defined log & span attributes. setAttributeBySource(ctx, config, RequestBody, body, log) + + // Write log + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) return types.ActionContinue } @@ -177,6 +202,8 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat ctx.SetContext(CtxStreamingBodyBuffer, streamingBodyBuffer) } + ctx.SetUserAttribute(ResponseType, "stream") + // Get requestStartTime from http context requestStartTime, ok := ctx.GetContext(StatisticsRequestStartTime).(int64) if !ok { @@ -188,28 +215,19 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat if ctx.GetContext(StatisticsFirstTokenTime) == nil { firstTokenTime := time.Now().UnixMilli() ctx.SetContext(StatisticsFirstTokenTime, firstTokenTime) - attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string) - attributes[LLMFirstTokenDuration] = fmt.Sprint(firstTokenTime - requestStartTime) - ctx.SetContext(CtxGeneralAtrribute, attributes) + ctx.SetUserAttribute(LLMFirstTokenDuration, firstTokenTime-requestStartTime) } // Set information about this request - if model, inputToken, outputToken, ok := getUsage(data); ok { - attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string) - // Record Log Attributes - attributes[Model] = model - attributes[InputToken] = fmt.Sprint(inputToken) - attributes[OutputToken] = fmt.Sprint(outputToken) - // Set attributes to http context - ctx.SetContext(CtxGeneralAtrribute, attributes) + ctx.SetUserAttribute(Model, model) + ctx.SetUserAttribute(InputToken, inputToken) + ctx.SetUserAttribute(OutputToken, outputToken) } // If the end of the stream is reached, record metrics/logs/spans. if endOfStream { responseEndTime := time.Now().UnixMilli() - attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string) - attributes[LLMServiceDuration] = fmt.Sprint(responseEndTime - requestStartTime) - ctx.SetContext(CtxGeneralAtrribute, attributes) + ctx.SetUserAttribute(LLMServiceDuration, responseEndTime-requestStartTime) // Set user defined log & span attributes. if config.shouldBufferStreamingBody { @@ -220,11 +238,8 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat setAttributeBySource(ctx, config, ResponseStreamingBody, streamingBodyBuffer, log) } - // Write inner filter states which can be used by other plugins such as ai-token-ratelimit - writeFilterStates(ctx, log) - // Write log - writeLog(ctx, log) + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) // Write metrics writeMetric(ctx, config, log) @@ -233,33 +248,26 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat } func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log wrapper.Log) types.Action { - // Get attributes from http context - attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string) - // Get requestStartTime from http context requestStartTime, _ := ctx.GetContext(StatisticsRequestStartTime).(int64) responseEndTime := time.Now().UnixMilli() - attributes[LLMServiceDuration] = fmt.Sprint(responseEndTime - requestStartTime) + ctx.SetUserAttribute(LLMServiceDuration, responseEndTime-requestStartTime) + + ctx.SetUserAttribute(ResponseType, "normal") // Set information about this request - model, inputToken, outputToken, ok := getUsage(body) - if ok { - attributes[Model] = model - attributes[InputToken] = fmt.Sprint(inputToken) - attributes[OutputToken] = fmt.Sprint(outputToken) - // Update attributes - ctx.SetContext(CtxGeneralAtrribute, attributes) + if model, inputToken, outputToken, ok := getUsage(body); ok { + ctx.SetUserAttribute(Model, model) + ctx.SetUserAttribute(InputToken, inputToken) + ctx.SetUserAttribute(OutputToken, outputToken) } // Set user defined log & span attributes. setAttributeBySource(ctx, config, ResponseBody, body, log) - // Write inner filter states which can be used by other plugins such as ai-token-ratelimit - writeFilterStates(ctx, log) - // Write log - writeLog(ctx, log) + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) // Write metrics writeMetric(ctx, config, log) @@ -294,67 +302,49 @@ func getUsage(data []byte) (model string, inputTokenUsage int64, outputTokenUsag // fetches the tracing span value from the specified source. func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, source string, body []byte, log wrapper.Log) { - attributes, ok := ctx.GetContext(CtxGeneralAtrribute).(map[string]string) - if !ok { - log.Error("failed to get attributes from http context") - return - } for _, attribute := range config.attributes { + var key string + var value interface{} if source == attribute.ValueSource { + key = attribute.Key switch source { case FixedValue: - log.Debugf("[attribute] source type: %s, key: %s, value: %s", source, attribute.Key, attribute.Value) - attributes[attribute.Key] = attribute.Value + value = attribute.Value case RequestHeader: - if value, err := proxywasm.GetHttpRequestHeader(attribute.Value); err == nil { - log.Debugf("[attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value) - attributes[attribute.Key] = value - } + value, _ = proxywasm.GetHttpRequestHeader(attribute.Value) case RequestBody: - raw := gjson.GetBytes(body, attribute.Value).Raw - var value string - if len(raw) > 2 { - value = raw[1 : len(raw)-1] - } - log.Debugf("[attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value) - attributes[attribute.Key] = value + value = gjson.GetBytes(body, attribute.Value).Value() case ResponseHeader: - if value, err := proxywasm.GetHttpResponseHeader(attribute.Value); err == nil { - log.Debugf("[log attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value) - attributes[attribute.Key] = value - } + value, _ = proxywasm.GetHttpResponseHeader(attribute.Value) case ResponseStreamingBody: - value := extractStreamingBodyByJsonPath(body, attribute.Value, attribute.Rule, log) - log.Debugf("[log attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value) - attributes[attribute.Key] = value + value = extractStreamingBodyByJsonPath(body, attribute.Value, attribute.Rule, log) case ResponseBody: - value := gjson.GetBytes(body, attribute.Value).Raw - if len(value) > 2 && value[0] == '"' && value[len(value)-1] == '"' { - value = value[1 : len(value)-1] - } - log.Debugf("[log attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value) - attributes[attribute.Key] = value + value = gjson.GetBytes(body, attribute.Value).Value() default: } - } - if attribute.ApplyToLog { - setLogAttribute(ctx, attribute.Key, attributes[attribute.Key], log) - } - if attribute.ApplyToSpan { - setSpanAttribute(attribute.Key, attributes[attribute.Key], log) + log.Debugf("[attribute] source type: %s, key: %s, value: %+v", source, key, value) + if attribute.ApplyToLog { + ctx.SetUserAttribute(key, value) + } + // for metrics + if key == Model || key == InputToken || key == OutputToken { + ctx.SetContext(key, value) + } + if attribute.ApplyToSpan { + setSpanAttribute(key, value, log) + } } } - ctx.SetContext(CtxGeneralAtrribute, attributes) } -func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, log wrapper.Log) string { +func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, log wrapper.Log) interface{} { chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n")) - var value string + var value interface{} if rule == RuleFirst { for _, chunk := range chunks { jsonObj := gjson.GetBytes(chunk, jsonPath) if jsonObj.Exists() { - value = jsonObj.String() + value = jsonObj.Value() break } } @@ -362,140 +352,116 @@ func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, l for _, chunk := range chunks { jsonObj := gjson.GetBytes(chunk, jsonPath) if jsonObj.Exists() { - value = jsonObj.String() + value = jsonObj.Value() } } } else if rule == RuleAppend { // extract llm response + var strValue string for _, chunk := range chunks { - raw := gjson.GetBytes(chunk, jsonPath).Raw - if len(raw) > 2 && raw[0] == '"' && raw[len(raw)-1] == '"' { - value += raw[1 : len(raw)-1] + jsonObj := gjson.GetBytes(chunk, jsonPath) + if jsonObj.Exists() { + strValue += jsonObj.String() } } + value = strValue } else { log.Errorf("unsupported rule type: %s", rule) } return value } -func setFilterState(key, value string, log wrapper.Log) { - if value != "" { - if e := proxywasm.SetProperty([]string{key}, []byte(fmt.Sprint(value))); e != nil { - log.Errorf("failed to set %s in filter state: %v", key, e) - } - } else { - log.Debugf("failed to write filter state [%s], because it's value is empty") - } -} - // Set the tracing span with value. -func setSpanAttribute(key, value string, log wrapper.Log) { +func setSpanAttribute(key string, value interface{}, log wrapper.Log) { if value != "" { - traceSpanTag := TracePrefix + key - if e := proxywasm.SetProperty([]string{traceSpanTag}, []byte(value)); e != nil { - log.Errorf("failed to set %s in filter state: %v", traceSpanTag, e) + traceSpanTag := wrapper.TraceSpanTagPrefix + key + if e := proxywasm.SetProperty([]string{traceSpanTag}, []byte(fmt.Sprint(value))); e != nil { + log.Warnf("failed to set %s in filter state: %v", traceSpanTag, e) } } else { log.Debugf("failed to write span attribute [%s], because it's value is empty") } } -// fetches the tracing span value from the specified source. -func setLogAttribute(ctx wrapper.HttpContext, key string, value interface{}, log wrapper.Log) { - logAttributes, ok := ctx.GetContext(CtxLogAtrribute).(map[string]string) +func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) { + // Generate usage metrics + var ok bool + var route, cluster, model string + var inputToken, outputToken uint64 + route, ok = ctx.GetContext(RouteName).(string) if !ok { - log.Error("failed to get logAttributes from http context") + log.Warnf("RouteName typd assert failed, skip metric record") return } - logAttributes[key] = fmt.Sprint(value) - ctx.SetContext(CtxLogAtrribute, logAttributes) -} - -func writeFilterStates(ctx wrapper.HttpContext, log wrapper.Log) { - attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string) - setFilterState(Model, attributes[Model], log) - setFilterState(InputToken, attributes[InputToken], log) - setFilterState(OutputToken, attributes[OutputToken], log) -} - -func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) { - attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string) - route, _ := getRouteName() - cluster, _ := getClusterName() - model, ok := attributes["model"] + cluster, ok = ctx.GetContext(ClusterName).(string) if !ok { - log.Errorf("Get model failed") + log.Warnf("ClusterName typd assert failed, skip metric record") return } - if inputToken, ok := attributes[InputToken]; ok { - inputTokenUint64, err := strconv.ParseUint(inputToken, 10, 0) - if err != nil || inputTokenUint64 == 0 { - log.Errorf("inputToken convert failed, value is %d, err msg is [%v]", inputTokenUint64, err) - return - } - config.incrementCounter(generateMetricName(route, cluster, model, InputToken), inputTokenUint64) + if ctx.GetUserAttribute(Model) == nil || ctx.GetUserAttribute(InputToken) == nil || ctx.GetUserAttribute(OutputToken) == nil { + log.Warnf("get usage information failed, skip metric record") + return } - if outputToken, ok := attributes[OutputToken]; ok { - outputTokenUint64, err := strconv.ParseUint(outputToken, 10, 0) - if err != nil || outputTokenUint64 == 0 { - log.Errorf("outputToken convert failed, value is %d, err msg is [%v]", outputTokenUint64, err) - return - } - config.incrementCounter(generateMetricName(route, cluster, model, OutputToken), outputTokenUint64) + model, ok = ctx.GetUserAttribute(Model).(string) + if !ok { + log.Warnf("Model typd assert failed, skip metric record") + return } - if llmFirstTokenDuration, ok := attributes[LLMFirstTokenDuration]; ok { - llmFirstTokenDurationUint64, err := strconv.ParseUint(llmFirstTokenDuration, 10, 0) - if err != nil || llmFirstTokenDurationUint64 == 0 { - log.Errorf("llmFirstTokenDuration convert failed, value is %d, err msg is [%v]", llmFirstTokenDurationUint64, err) + inputToken, ok = convertToUInt(ctx.GetUserAttribute(InputToken)) + if !ok { + log.Warnf("InputToken typd assert failed, skip metric record") + return + } + outputToken, ok = convertToUInt(ctx.GetUserAttribute(OutputToken)) + if !ok { + log.Warnf("OutputToken typd assert failed, skip metric record") + return + } + if inputToken == 0 || outputToken == 0 { + log.Warnf("inputToken and outputToken cannot equal to 0, skip metric record") + return + } + config.incrementCounter(generateMetricName(route, cluster, model, InputToken), inputToken) + config.incrementCounter(generateMetricName(route, cluster, model, OutputToken), outputToken) + + // Generate duration metrics + var llmFirstTokenDuration, llmServiceDuration uint64 + // Is stream response + if ctx.GetUserAttribute(LLMFirstTokenDuration) != nil { + llmFirstTokenDuration, ok = convertToUInt(ctx.GetUserAttribute(LLMFirstTokenDuration)) + if !ok { + log.Warnf("LLMFirstTokenDuration typd assert failed") return } - config.incrementCounter(generateMetricName(route, cluster, model, LLMFirstTokenDuration), llmFirstTokenDurationUint64) + config.incrementCounter(generateMetricName(route, cluster, model, LLMFirstTokenDuration), llmFirstTokenDuration) + config.incrementCounter(generateMetricName(route, cluster, model, LLMStreamDurationCount), 1) } - if llmServiceDuration, ok := attributes[LLMServiceDuration]; ok { - llmServiceDurationUint64, err := strconv.ParseUint(llmServiceDuration, 10, 0) - if err != nil || llmServiceDurationUint64 == 0 { - log.Errorf("llmServiceDuration convert failed, value is %d, err msg is [%v]", llmServiceDurationUint64, err) + if ctx.GetUserAttribute(LLMServiceDuration) != nil { + llmServiceDuration, ok = convertToUInt(ctx.GetUserAttribute(LLMServiceDuration)) + if !ok { + log.Warnf("LLMServiceDuration typd assert failed") return } - config.incrementCounter(generateMetricName(route, cluster, model, LLMServiceDuration), llmServiceDurationUint64) + config.incrementCounter(generateMetricName(route, cluster, model, LLMServiceDuration), llmServiceDuration) + config.incrementCounter(generateMetricName(route, cluster, model, LLMDurationCount), 1) } - config.incrementCounter(generateMetricName(route, cluster, model, LLMDurationCount), 1) } -func writeLog(ctx wrapper.HttpContext, log wrapper.Log) { - attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string) - logAttributes, _ := ctx.GetContext(CtxLogAtrribute).(map[string]string) - // Set inner log fields - if attributes[Model] != "" { - logAttributes[Model] = attributes[Model] - } - if attributes[InputToken] != "" { - logAttributes[InputToken] = attributes[InputToken] - } - if attributes[OutputToken] != "" { - logAttributes[OutputToken] = attributes[OutputToken] - } - if attributes[LLMFirstTokenDuration] != "" { - logAttributes[LLMFirstTokenDuration] = attributes[LLMFirstTokenDuration] - } - if attributes[LLMServiceDuration] != "" { - logAttributes[LLMServiceDuration] = attributes[LLMServiceDuration] - } - // Traverse log fields - items := []string{} - for k, v := range logAttributes { - items = append(items, fmt.Sprintf(`"%s":"%s"`, k, v)) - } - aiLogField := fmt.Sprintf(`{%s}`, strings.Join(items, ",")) - // log.Infof("ai request json log: %s", aiLogField) - jsonMap := map[string]string{ - "ai_log": aiLogField, - } - serialized, _ := json.Marshal(jsonMap) - jsonLogRaw := gjson.GetBytes(serialized, "ai_log").Raw - jsonLog := jsonLogRaw[1 : len(jsonLogRaw)-1] - if err := proxywasm.SetProperty([]string{"ai_log"}, []byte(jsonLog)); err != nil { - log.Errorf("failed to set ai_log in filter state: %v", err) +func convertToUInt(val interface{}) (uint64, bool) { + switch v := val.(type) { + case float32: + return uint64(v), true + case float64: + return uint64(v), true + case int32: + return uint64(v), true + case int64: + return uint64(v), true + case uint32: + return uint64(v), true + case uint64: + return v, true + default: + return 0, false } } diff --git a/plugins/wasm-go/extensions/ai-token-ratelimit/go.sum b/plugins/wasm-go/extensions/ai-token-ratelimit/go.sum index 4bc7bb7527..7b8c22894a 100644 --- a/plugins/wasm-go/extensions/ai-token-ratelimit/go.sum +++ b/plugins/wasm-go/extensions/ai-token-ratelimit/go.sum @@ -5,8 +5,7 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= +github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU= github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0= github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= @@ -14,8 +13,7 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/tetratelabs/wazero v1.7.1 h1:QtSfd6KLc41DIMpDYlJdoMc6k7QTN246DM2+n2Y/Dx8= github.com/tetratelabs/wazero v1.7.1/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y= -github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= -github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94= github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= diff --git a/plugins/wasm-go/extensions/ai-token-ratelimit/main.go b/plugins/wasm-go/extensions/ai-token-ratelimit/main.go index afe463a12d..6877ae5c23 100644 --- a/plugins/wasm-go/extensions/ai-token-ratelimit/main.go +++ b/plugins/wasm-go/extensions/ai-token-ratelimit/main.go @@ -15,6 +15,7 @@ package main import ( + "bytes" "fmt" "net" "net/url" @@ -61,9 +62,9 @@ const ( ConsumerHeader string = "x-mse-consumer" // LimitByConsumer从该request header获取consumer的名字 CookieHeader string = "cookie" - RateLimitLimitHeader string = "X-RateLimit-Limit" // 限制的总请求数 - RateLimitRemainingHeader string = "X-RateLimit-Remaining" // 剩余还可以发送的请求数 - RateLimitResetHeader string = "X-RateLimit-Reset" // 限流重置时间(触发限流时返回) + RateLimitLimitHeader string = "X-TokenRateLimit-Limit" // 限制的总请求数 + RateLimitRemainingHeader string = "X-TokenRateLimit-Remaining" // 剩余还可以发送的请求数 + RateLimitResetHeader string = "X-TokenRateLimit-Reset" // 限流重置时间(触发限流时返回) ) type LimitContext struct { @@ -124,6 +125,8 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitCon } if context.remaining < 0 { // 触发限流 + ctx.SetUserAttribute("token_ratelimit_status", "limited") + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) rejected(config, context) } else { proxywasm.ResumeHttpRequest() @@ -137,39 +140,49 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitCon } func onHttpStreamingBody(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, data []byte, endOfStream bool, log wrapper.Log) []byte { - if !endOfStream { - return data + var inputToken, outputToken int64 + if inputToken, outputToken, ok := getUsage(data); ok { + ctx.SetContext("input_token", inputToken) + ctx.SetContext("output_token", outputToken) } - inputTokenStr, err := proxywasm.GetProperty([]string{"filter_state", "wasm.input_token"}) - if err != nil { - return data - } - outputTokenStr, err := proxywasm.GetProperty([]string{"filter_state", "wasm.output_token"}) - if err != nil { - return data - } - inputToken, err := strconv.Atoi(string(inputTokenStr)) - if err != nil { - return data - } - outputToken, err := strconv.Atoi(string(outputTokenStr)) - if err != nil { - return data - } - limitRedisContext, ok := ctx.GetContext(LimitRedisContextKey).(LimitRedisContext) - if !ok { - return data + if endOfStream { + if ctx.GetContext("input_token") == nil || ctx.GetContext("output_token") == nil { + return data + } + inputToken = ctx.GetContext("input_token").(int64) + outputToken = ctx.GetContext("output_token").(int64) + limitRedisContext, ok := ctx.GetContext(LimitRedisContextKey).(LimitRedisContext) + if !ok { + return data + } + keys := []interface{}{limitRedisContext.key} + args := []interface{}{limitRedisContext.count, limitRedisContext.window, inputToken + outputToken} + err := config.redisClient.Eval(ResponsePhaseFixedWindowScript, 1, keys, args, nil) + if err != nil { + log.Errorf("redis call failed: %v", err) + } } - keys := []interface{}{limitRedisContext.key} - args := []interface{}{limitRedisContext.count, limitRedisContext.window, inputToken + outputToken} + return data +} - err = config.redisClient.Eval(ResponsePhaseFixedWindowScript, 1, keys, args, nil) - if err != nil { - log.Errorf("redis call failed: %v", err) - return data - } else { - return data +func getUsage(data []byte) (inputTokenUsage int64, outputTokenUsage int64, ok bool) { + chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n")) + for _, chunk := range chunks { + // the feature strings are used to identify the usage data, like: + // {"model":"gpt2","usage":{"prompt_tokens":1,"completion_tokens":1}} + if !bytes.Contains(chunk, []byte("prompt_tokens")) || !bytes.Contains(chunk, []byte("completion_tokens")) { + continue + } + inputTokenObj := gjson.GetBytes(chunk, "usage.prompt_tokens") + outputTokenObj := gjson.GetBytes(chunk, "usage.completion_tokens") + if inputTokenObj.Exists() && outputTokenObj.Exists() { + inputTokenUsage = inputTokenObj.Int() + outputTokenUsage = outputTokenObj.Int() + ok = true + return + } } + return } func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRuleItem, log wrapper.Log) (string, *LimitRuleItem, *LimitConfigItem) { diff --git a/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go b/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go index be9144adfc..8b342d57b5 100644 --- a/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go +++ b/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go @@ -45,6 +45,8 @@ type HttpContext interface { GetStringContext(key, defaultValue string) string GetUserAttribute(key string) interface{} SetUserAttribute(key string, value interface{}) + SetUserAttributeMap(kvmap map[string]interface{}) + GetUserAttributeMap() map[string]interface{} // You can call this function to set custom log WriteUserAttributeToLog() error // You can call this function to set custom log with your specific key @@ -403,6 +405,14 @@ func (ctx *CommonHttpCtx[PluginConfig]) GetUserAttribute(key string) interface{} return ctx.userAttribute[key] } +func (ctx *CommonHttpCtx[PluginConfig]) SetUserAttributeMap(kvmap map[string]interface{}) { + ctx.userAttribute = kvmap +} + +func (ctx *CommonHttpCtx[PluginConfig]) GetUserAttributeMap() map[string]interface{} { + return ctx.userAttribute +} + func (ctx *CommonHttpCtx[PluginConfig]) WriteUserAttributeToLog() error { return ctx.WriteUserAttributeToLogWithKey(CustomLogKey) } diff --git a/plugins/wasm-go/pkg/wrapper/redis_wrapper.go b/plugins/wasm-go/pkg/wrapper/redis_wrapper.go index c619c3e191..f4b42e67e7 100644 --- a/plugins/wasm-go/pkg/wrapper/redis_wrapper.go +++ b/plugins/wasm-go/pkg/wrapper/redis_wrapper.go @@ -17,6 +17,7 @@ package wrapper import ( "bytes" "encoding/base64" + "errors" "fmt" "io" @@ -28,7 +29,7 @@ import ( type RedisResponseCallback func(response resp.Value) type RedisClient interface { - Init(username, password string, timeout int64) error + Init(username, password string, timeout int64, opts ...optionFunc) error // with this function, you can call redis as if you are using redis-cli Command(cmds []interface{}, callback RedisResponseCallback) error Eval(script string, numkeys int, keys, args []interface{}, callback RedisResponseCallback) error @@ -103,15 +104,31 @@ type RedisClient interface { } type RedisClusterClient[C Cluster] struct { - cluster C + cluster C + ready bool + checkReadyFunc func() error + option redisOption } -func NewRedisClusterClient[C Cluster](cluster C) *RedisClusterClient[C] { - return &RedisClusterClient[C]{cluster: cluster} +type redisOption struct { + dataBase int } -func RedisInit(cluster Cluster, username, password string, timeout uint32) error { - return proxywasm.RedisInit(cluster.ClusterName(), username, password, timeout) +type optionFunc func(*redisOption) + +func WithDataBase(dataBase int) optionFunc { + return func(o *redisOption) { + o.dataBase = dataBase + } +} + +func NewRedisClusterClient[C Cluster](cluster C) *RedisClusterClient[C] { + return &RedisClusterClient[C]{ + cluster: cluster, + checkReadyFunc: func() error { + return errors.New("redis client is not ready, please call Init() first") + }, + } } func RedisCall(cluster Cluster, respQuery []byte, callback RedisResponseCallback) error { @@ -165,19 +182,46 @@ func respString(args []interface{}) []byte { return buf.Bytes() } -func (c RedisClusterClient[C]) Init(username, password string, timeout int64) error { - err := RedisInit(c.cluster, username, password, uint32(timeout)) +func (c *RedisClusterClient[C]) Init(username, password string, timeout int64, opts ...optionFunc) error { + for _, opt := range opts { + opt(&c.option) + } + clusterName := c.cluster.ClusterName() + if c.option.dataBase != 0 { + clusterName = fmt.Sprintf("%s?db=%d", clusterName, c.option.dataBase) + } + err := proxywasm.RedisInit(clusterName, username, password, uint32(timeout)) if err != nil { - proxywasm.LogCriticalf("failed to init redis: %v", err) + c.checkReadyFunc = func() error { + if c.ready { + return nil + } + initErr := proxywasm.RedisInit(clusterName, username, password, uint32(timeout)) + if initErr != nil { + return initErr + } + c.ready = true + return nil + } + proxywasm.LogWarnf("failed to init redis: %v, will retry after", err) + return nil } - return err + c.checkReadyFunc = func() error { return nil } + c.ready = true + return nil } -func (c RedisClusterClient[C]) Command(cmds []interface{}, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) Command(cmds []interface{}, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } return RedisCall(c.cluster, respString(cmds), callback) } -func (c RedisClusterClient[C]) Eval(script string, numkeys int, keys, args []interface{}, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) Eval(script string, numkeys int, keys, args []interface{}, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } params := make([]interface{}, 0) params = append(params, "eval") params = append(params, script) @@ -188,21 +232,30 @@ func (c RedisClusterClient[C]) Eval(script string, numkeys int, keys, args []int } // Key -func (c RedisClusterClient[C]) Del(key string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) Del(key string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "del") args = append(args, key) return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) Exists(key string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) Exists(key string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "exists") args = append(args, key) return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) Expire(key string, ttl int, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) Expire(key string, ttl int, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "expire") args = append(args, key) @@ -210,7 +263,10 @@ func (c RedisClusterClient[C]) Expire(key string, ttl int, callback RedisRespons return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) Persist(key string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) Persist(key string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "persist") args = append(args, key) @@ -218,14 +274,20 @@ func (c RedisClusterClient[C]) Persist(key string, callback RedisResponseCallbac } // String -func (c RedisClusterClient[C]) Get(key string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) Get(key string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "get") args = append(args, key) return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) Set(key string, value interface{}, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) Set(key string, value interface{}, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "set") args = append(args, key) @@ -233,7 +295,10 @@ func (c RedisClusterClient[C]) Set(key string, value interface{}, callback Redis return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) SetEx(key string, value interface{}, ttl int, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) SetEx(key string, value interface{}, ttl int, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "set") args = append(args, key) @@ -243,7 +308,10 @@ func (c RedisClusterClient[C]) SetEx(key string, value interface{}, ttl int, cal return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) MGet(keys []string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) MGet(keys []string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "mget") for _, k := range keys { @@ -252,7 +320,10 @@ func (c RedisClusterClient[C]) MGet(keys []string, callback RedisResponseCallbac return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) MSet(kvMap map[string]interface{}, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) MSet(kvMap map[string]interface{}, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "mset") for k, v := range kvMap { @@ -262,21 +333,30 @@ func (c RedisClusterClient[C]) MSet(kvMap map[string]interface{}, callback Redis return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) Incr(key string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) Incr(key string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "incr") args = append(args, key) return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) Decr(key string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) Decr(key string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "decr") args = append(args, key) return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) IncrBy(key string, delta int, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) IncrBy(key string, delta int, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "incrby") args = append(args, key) @@ -284,7 +364,10 @@ func (c RedisClusterClient[C]) IncrBy(key string, delta int, callback RedisRespo return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) DecrBy(key string, delta int, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) DecrBy(key string, delta int, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "decrby") args = append(args, key) @@ -293,14 +376,20 @@ func (c RedisClusterClient[C]) DecrBy(key string, delta int, callback RedisRespo } // List -func (c RedisClusterClient[C]) LLen(key string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) LLen(key string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "llen") args = append(args, key) return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) RPush(key string, vals []interface{}, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) RPush(key string, vals []interface{}, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "rpush") args = append(args, key) @@ -310,14 +399,20 @@ func (c RedisClusterClient[C]) RPush(key string, vals []interface{}, callback Re return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) RPop(key string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) RPop(key string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "rpop") args = append(args, key) return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) LPush(key string, vals []interface{}, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) LPush(key string, vals []interface{}, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "lpush") args = append(args, key) @@ -327,14 +422,20 @@ func (c RedisClusterClient[C]) LPush(key string, vals []interface{}, callback Re return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) LPop(key string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) LPop(key string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "lpop") args = append(args, key) return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) LIndex(key string, index int, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) LIndex(key string, index int, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "lindex") args = append(args, key) @@ -342,7 +443,10 @@ func (c RedisClusterClient[C]) LIndex(key string, index int, callback RedisRespo return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) LRange(key string, start, stop int, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) LRange(key string, start, stop int, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "lrange") args = append(args, key) @@ -351,7 +455,10 @@ func (c RedisClusterClient[C]) LRange(key string, start, stop int, callback Redi return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) LRem(key string, count int, value interface{}, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) LRem(key string, count int, value interface{}, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "lrem") args = append(args, key) @@ -360,7 +467,10 @@ func (c RedisClusterClient[C]) LRem(key string, count int, value interface{}, ca return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) LInsertBefore(key string, pivot, value interface{}, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) LInsertBefore(key string, pivot, value interface{}, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "linsert") args = append(args, key) @@ -370,7 +480,10 @@ func (c RedisClusterClient[C]) LInsertBefore(key string, pivot, value interface{ return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) LInsertAfter(key string, pivot, value interface{}, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) LInsertAfter(key string, pivot, value interface{}, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "linsert") args = append(args, key) @@ -381,7 +494,10 @@ func (c RedisClusterClient[C]) LInsertAfter(key string, pivot, value interface{} } // Hash -func (c RedisClusterClient[C]) HExists(key, field string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) HExists(key, field string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "hexists") args = append(args, key) @@ -389,7 +505,10 @@ func (c RedisClusterClient[C]) HExists(key, field string, callback RedisResponse return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) HDel(key string, fields []string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) HDel(key string, fields []string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "hdel") args = append(args, key) @@ -399,14 +518,20 @@ func (c RedisClusterClient[C]) HDel(key string, fields []string, callback RedisR return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) HLen(key string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) HLen(key string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "hlen") args = append(args, key) return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) HGet(key, field string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) HGet(key, field string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "hget") args = append(args, key) @@ -414,7 +539,10 @@ func (c RedisClusterClient[C]) HGet(key, field string, callback RedisResponseCal return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) HSet(key, field string, value interface{}, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) HSet(key, field string, value interface{}, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "hset") args = append(args, key) @@ -423,7 +551,10 @@ func (c RedisClusterClient[C]) HSet(key, field string, value interface{}, callba return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) HMGet(key string, fields []string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) HMGet(key string, fields []string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "hmget") args = append(args, key) @@ -433,7 +564,10 @@ func (c RedisClusterClient[C]) HMGet(key string, fields []string, callback Redis return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) HMSet(key string, kvMap map[string]interface{}, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) HMSet(key string, kvMap map[string]interface{}, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "hmset") args = append(args, key) @@ -444,28 +578,40 @@ func (c RedisClusterClient[C]) HMSet(key string, kvMap map[string]interface{}, c return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) HKeys(key string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) HKeys(key string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "hkeys") args = append(args, key) return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) HVals(key string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) HVals(key string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "hvals") args = append(args, key) return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) HGetAll(key string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) HGetAll(key string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "hgetall") args = append(args, key) return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) HIncrBy(key, field string, delta int, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) HIncrBy(key, field string, delta int, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "hincrby") args = append(args, key) @@ -474,7 +620,10 @@ func (c RedisClusterClient[C]) HIncrBy(key, field string, delta int, callback Re return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) HIncrByFloat(key, field string, delta float64, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) HIncrByFloat(key, field string, delta float64, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "hincrbyfloat") args = append(args, key) @@ -484,14 +633,20 @@ func (c RedisClusterClient[C]) HIncrByFloat(key, field string, delta float64, ca } // Set -func (c RedisClusterClient[C]) SCard(key string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) SCard(key string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "scard") args = append(args, key) return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) SAdd(key string, vals []interface{}, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) SAdd(key string, vals []interface{}, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "sadd") args = append(args, key) @@ -501,7 +656,10 @@ func (c RedisClusterClient[C]) SAdd(key string, vals []interface{}, callback Red return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) SRem(key string, vals []interface{}, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) SRem(key string, vals []interface{}, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "srem") args = append(args, key) @@ -511,7 +669,10 @@ func (c RedisClusterClient[C]) SRem(key string, vals []interface{}, callback Red return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) SIsMember(key string, value interface{}, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) SIsMember(key string, value interface{}, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "sismember") args = append(args, key) @@ -519,14 +680,20 @@ func (c RedisClusterClient[C]) SIsMember(key string, value interface{}, callback return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) SMembers(key string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) SMembers(key string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "smembers") args = append(args, key) return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) SDiff(key1, key2 string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) SDiff(key1, key2 string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "sdiff") args = append(args, key1) @@ -534,7 +701,10 @@ func (c RedisClusterClient[C]) SDiff(key1, key2 string, callback RedisResponseCa return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) SDiffStore(destination, key1, key2 string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) SDiffStore(destination, key1, key2 string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "sdiffstore") args = append(args, destination) @@ -543,7 +713,10 @@ func (c RedisClusterClient[C]) SDiffStore(destination, key1, key2 string, callba return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) SInter(key1, key2 string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) SInter(key1, key2 string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "sinter") args = append(args, key1) @@ -551,7 +724,10 @@ func (c RedisClusterClient[C]) SInter(key1, key2 string, callback RedisResponseC return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) SInterStore(destination, key1, key2 string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) SInterStore(destination, key1, key2 string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "sinterstore") args = append(args, destination) @@ -560,7 +736,10 @@ func (c RedisClusterClient[C]) SInterStore(destination, key1, key2 string, callb return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) SUnion(key1, key2 string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) SUnion(key1, key2 string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "sunion") args = append(args, key1) @@ -568,7 +747,10 @@ func (c RedisClusterClient[C]) SUnion(key1, key2 string, callback RedisResponseC return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) SUnionStore(destination, key1, key2 string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) SUnionStore(destination, key1, key2 string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "sunionstore") args = append(args, destination) @@ -578,14 +760,20 @@ func (c RedisClusterClient[C]) SUnionStore(destination, key1, key2 string, callb } // ZSet -func (c RedisClusterClient[C]) ZCard(key string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) ZCard(key string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "zcard") args = append(args, key) return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) ZAdd(key string, msMap map[string]interface{}, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) ZAdd(key string, msMap map[string]interface{}, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "zadd") args = append(args, key) @@ -596,7 +784,10 @@ func (c RedisClusterClient[C]) ZAdd(key string, msMap map[string]interface{}, ca return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) ZCount(key string, min interface{}, max interface{}, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) ZCount(key string, min interface{}, max interface{}, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "zcount") args = append(args, key) @@ -605,7 +796,10 @@ func (c RedisClusterClient[C]) ZCount(key string, min interface{}, max interface return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) ZIncrBy(key string, member string, delta interface{}, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) ZIncrBy(key string, member string, delta interface{}, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "zincrby") args = append(args, key) @@ -614,7 +808,10 @@ func (c RedisClusterClient[C]) ZIncrBy(key string, member string, delta interfac return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) ZScore(key, member string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) ZScore(key, member string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "zscore") args = append(args, key) @@ -622,7 +819,10 @@ func (c RedisClusterClient[C]) ZScore(key, member string, callback RedisResponse return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) ZRank(key, member string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) ZRank(key, member string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "zrank") args = append(args, key) @@ -630,7 +830,10 @@ func (c RedisClusterClient[C]) ZRank(key, member string, callback RedisResponseC return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) ZRevRank(key, member string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) ZRevRank(key, member string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "zrevrank") args = append(args, key) @@ -638,7 +841,10 @@ func (c RedisClusterClient[C]) ZRevRank(key, member string, callback RedisRespon return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) ZRem(key string, members []string, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) ZRem(key string, members []string, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "zrem") args = append(args, key) @@ -648,7 +854,10 @@ func (c RedisClusterClient[C]) ZRem(key string, members []string, callback Redis return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) ZRange(key string, start, stop int, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) ZRange(key string, start, stop int, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "zrange") args = append(args, key) @@ -657,7 +866,10 @@ func (c RedisClusterClient[C]) ZRange(key string, start, stop int, callback Redi return RedisCall(c.cluster, respString(args), callback) } -func (c RedisClusterClient[C]) ZRevRange(key string, start, stop int, callback RedisResponseCallback) error { +func (c *RedisClusterClient[C]) ZRevRange(key string, start, stop int, callback RedisResponseCallback) error { + if err := c.checkReadyFunc(); err != nil { + return err + } args := make([]interface{}, 0) args = append(args, "zrevrange") args = append(args, key)