diff --git a/CODEOWNERS b/CODEOWNERS
index 3d36c596c3..c875a7968a 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -2,7 +2,8 @@
/envoy @gengleilei @johnlanni
/istio @SpecialYang @johnlanni
/pkg @SpecialYang @johnlanni @CH3CHO
-/plugins @johnlanni @WeixinX @CH3CHO
+/plugins @johnlanni @CH3CHO @rinfx
+/plugins/wasm-go/extensions/ai-proxy @cr7258 @CH3CHO @rinfx
/plugins/wasm-rust @007gzs @jizhuozhi
/registry @NameHaibinZhang @2456868764 @johnlanni
/test @Xunzhuo @2456868764 @CH3CHO
diff --git a/README.md b/README.md
index fd15371c6b..e27042f67c 100644
--- a/README.md
+++ b/README.md
@@ -6,9 +6,14 @@
AI Native API Gateway
+
+
[![Build Status](https://github.com/alibaba/higress/actions/workflows/build-and-test.yaml/badge.svg?branch=main)](https://github.com/alibaba/higress/actions)
[![license](https://img.shields.io/github/license/alibaba/higress.svg)](https://www.apache.org/licenses/LICENSE-2.0.html)
+
+
+
[**官网**](https://higress.cn/) |
[**文档**](https://higress.cn/docs/latest/overview/what-is-higress/) |
[**博客**](https://higress.cn/blog/) |
@@ -17,6 +22,7 @@
[**AI插件**](https://higress.cn/plugin/)
+
English | 中文 | 日本語
diff --git a/plugins/wasm-go/extensions/ai-cache/cache/provider.go b/plugins/wasm-go/extensions/ai-cache/cache/provider.go
index 1238d21570..d68acd5099 100644
--- a/plugins/wasm-go/extensions/ai-cache/cache/provider.go
+++ b/plugins/wasm-go/extensions/ai-cache/cache/provider.go
@@ -2,6 +2,7 @@ package cache
import (
"errors"
+ "strings"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
@@ -62,7 +63,12 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
c.serviceName = json.Get("serviceName").String()
c.servicePort = int(json.Get("servicePort").Int())
if !json.Get("servicePort").Exists() {
- c.servicePort = 6379
+ if strings.HasSuffix(c.serviceName, ".static") {
+ // use default logic port which is 80 for static service
+ c.servicePort = 80
+ } else {
+ c.servicePort = 6379
+ }
}
c.serviceHost = json.Get("serviceHost").String()
c.username = json.Get("username").String()
diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go
index 19a9b2b856..b46fd28e8e 100644
--- a/plugins/wasm-go/extensions/ai-cache/core.go
+++ b/plugins/wasm-go/extensions/ai-cache/core.go
@@ -74,6 +74,9 @@ func processCacheHit(key string, response string, stream bool, ctx wrapper.HttpC
ctx.SetContext(CACHE_KEY_CONTEXT_KEY, nil)
+ ctx.SetUserAttribute("cache_status", "hit")
+ ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
+
if stream {
proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(c.StreamResponseTemplate, escapedResponse)), -1)
} else {
diff --git a/plugins/wasm-go/extensions/ai-cache/go.mod b/plugins/wasm-go/extensions/ai-cache/go.mod
index e4aae265e0..56bea605f4 100644
--- a/plugins/wasm-go/extensions/ai-cache/go.mod
+++ b/plugins/wasm-go/extensions/ai-cache/go.mod
@@ -8,14 +8,14 @@ replace github.com/alibaba/higress/plugins/wasm-go => ../..
require (
github.com/alibaba/higress/plugins/wasm-go v1.4.2
- github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f
+ github.com/google/uuid v1.6.0
+ github.com/higress-group/proxy-wasm-go-sdk v1.0.0
github.com/tidwall/gjson v1.17.3
github.com/tidwall/resp v0.1.1
// github.com/weaviate/weaviate-go-client/v4 v4.15.1
)
require (
- github.com/google/uuid v1.6.0 // indirect
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
github.com/magefile/mage v1.14.0 // indirect
github.com/stretchr/testify v1.9.0 // indirect
diff --git a/plugins/wasm-go/extensions/ai-cache/go.sum b/plugins/wasm-go/extensions/ai-cache/go.sum
index 7ada0c8b70..0a3635868b 100644
--- a/plugins/wasm-go/extensions/ai-cache/go.sum
+++ b/plugins/wasm-go/extensions/ai-cache/go.sum
@@ -3,8 +3,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
-github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
-github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
+github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU=
+github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
diff --git a/plugins/wasm-go/extensions/ai-cache/main.go b/plugins/wasm-go/extensions/ai-cache/main.go
index 1aca29f0ec..62edb80dcb 100644
--- a/plugins/wasm-go/extensions/ai-cache/main.go
+++ b/plugins/wasm-go/extensions/ai-cache/main.go
@@ -128,9 +128,15 @@ func onHttpRequestBody(ctx wrapper.HttpContext, c config.PluginConfig, body []by
func onHttpResponseHeaders(ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log) types.Action {
skipCache := ctx.GetContext(SKIP_CACHE_HEADER)
if skipCache != nil {
+ ctx.SetUserAttribute("cache_status", "skip")
+ ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
ctx.DontReadResponseBody()
return types.ActionContinue
}
+ if ctx.GetContext(CACHE_KEY_CONTEXT_KEY) != nil {
+ ctx.SetUserAttribute("cache_status", "miss")
+ ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
+ }
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
if strings.Contains(contentType, "text/event-stream") {
ctx.SetContext(STREAM_CONTEXT_KEY, struct{}{})
diff --git a/plugins/wasm-go/extensions/ai-cache/util.go b/plugins/wasm-go/extensions/ai-cache/util.go
index 983dfbb25a..7fbd4954e2 100644
--- a/plugins/wasm-go/extensions/ai-cache/util.go
+++ b/plugins/wasm-go/extensions/ai-cache/util.go
@@ -101,55 +101,58 @@ func processStreamLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chun
}
func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessage string, log wrapper.Log) (string, error) {
- subMessages := strings.Split(sseMessage, "\n")
- var message string
- for _, msg := range subMessages {
- if strings.HasPrefix(msg, "data:") {
- message = msg
- break
+ content := ""
+ for _, chunk := range strings.Split(sseMessage, "\n\n") {
+ log.Infof("chunk _ : %s", chunk)
+ subMessages := strings.Split(chunk, "\n")
+ var message string
+ for _, msg := range subMessages {
+ if strings.HasPrefix(msg, "data:") {
+ message = msg
+ break
+ }
+ }
+ if len(message) < 6 {
+ return content, fmt.Errorf("[processSSEMessage] invalid message: %s", message)
}
- }
- if len(message) < 6 {
- return "", fmt.Errorf("[processSSEMessage] invalid message: %s", message)
- }
- // skip the prefix "data:"
- bodyJson := message[5:]
+ // skip the prefix "data:"
+ bodyJson := message[5:]
- if strings.TrimSpace(bodyJson) == "[DONE]" {
- return "", nil
- }
+ if strings.TrimSpace(bodyJson) == "[DONE]" {
+ return content, nil
+ }
- // Extract values from JSON fields
- responseBody := gjson.Get(bodyJson, c.CacheStreamValueFrom)
- toolCalls := gjson.Get(bodyJson, c.CacheToolCallsFrom)
+ // Extract values from JSON fields
+ responseBody := gjson.Get(bodyJson, c.CacheStreamValueFrom)
+ toolCalls := gjson.Get(bodyJson, c.CacheToolCallsFrom)
- if toolCalls.Exists() {
- // TODO: Temporarily store the tool_calls value in the context for processing
- ctx.SetContext(TOOL_CALLS_CONTEXT_KEY, toolCalls.String())
- }
-
- // Check if the ResponseBody field exists
- if !responseBody.Exists() {
- if ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) != nil {
- log.Debugf("[processSSEMessage] unable to extract content from message; cache content is not nil: %s", message)
- return "", nil
+ if toolCalls.Exists() {
+ // TODO: Temporarily store the tool_calls value in the context for processing
+ ctx.SetContext(TOOL_CALLS_CONTEXT_KEY, toolCalls.String())
}
- return "", fmt.Errorf("[processSSEMessage] unable to extract content from message; cache content is nil: %s", message)
- } else {
- tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY)
- // If there is no content in the cache, initialize and set the content
- if tempContentI == nil {
- content := responseBody.String()
- ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content)
- return content, nil
- }
+ // Check if the ResponseBody field exists
+ if !responseBody.Exists() {
+ if ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) != nil {
+ log.Debugf("[processSSEMessage] unable to extract content from message; cache content is not nil: %s", message)
+ return content, nil
+ }
+ return content, fmt.Errorf("[processSSEMessage] unable to extract content from message; cache content is nil: %s", message)
+ } else {
+ tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY)
- // Update the content in the cache
- appendMsg := responseBody.String()
- content := tempContentI.(string) + appendMsg
- ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content)
- return content, nil
+ // If there is no content in the cache, initialize and set the content
+ if tempContentI == nil {
+ content = responseBody.String()
+ ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content)
+ } else {
+ // Update the content in the cache
+ appendMsg := responseBody.String()
+ content = tempContentI.(string) + appendMsg
+ ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content)
+ }
+ }
}
+ return content, nil
}
diff --git a/plugins/wasm-go/extensions/ai-history/go.sum b/plugins/wasm-go/extensions/ai-history/go.sum
index 6b1c2c3cd7..b4ab172fe2 100644
--- a/plugins/wasm-go/extensions/ai-history/go.sum
+++ b/plugins/wasm-go/extensions/ai-history/go.sum
@@ -3,15 +3,13 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
-github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
-github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
+github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU=
github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
-github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
-github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
+github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94=
github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
diff --git a/plugins/wasm-go/extensions/ai-history/main.go b/plugins/wasm-go/extensions/ai-history/main.go
index 512e13f1c6..3f728dd96d 100644
--- a/plugins/wasm-go/extensions/ai-history/main.go
+++ b/plugins/wasm-go/extensions/ai-history/main.go
@@ -194,6 +194,12 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte
ctx.SetContext(StreamContextKey, struct{}{})
}
identityKey := ctx.GetStringContext(IdentityKey, "")
+ question := TrimQuote(bodyJson.Get(config.QuestionFrom.RequestBody).String())
+ if question == "" {
+ log.Debug("parse question from request body failed")
+ return types.ActionContinue
+ }
+ ctx.SetContext(QuestionContextKey, question)
err := config.redisClient.Get(config.CacheKeyPrefix+identityKey, func(response resp.Value) {
if err := response.Error(); err != nil {
log.Errorf("redis get failed, err:%v", err)
@@ -230,13 +236,6 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte
_ = proxywasm.SendHttpResponseWithDetail(200, "OK", [][2]string{{"content-type", "application/json; charset=utf-8"}}, res, -1)
return
}
- question := TrimQuote(bodyJson.Get(config.QuestionFrom.RequestBody).String())
- if question == "" {
- log.Debug("parse question from request body failed")
- _ = proxywasm.ResumeHttpRequest()
- return
- }
- ctx.SetContext(QuestionContextKey, question)
fillHistoryCnt := getIntQueryParameter("fill_history_cnt", path, config.FillHistoryCnt) * 2
currJson := bodyJson.Get("messages").String()
var currMessage []ChatHistory
@@ -317,38 +316,39 @@ func getIntQueryParameter(name string, path string, defaultValue int) int {
}
func processSSEMessage(ctx wrapper.HttpContext, config PluginConfig, sseMessage string, log wrapper.Log) string {
- subMessages := strings.Split(sseMessage, "\n")
- var message string
- for _, msg := range subMessages {
- if strings.HasPrefix(msg, "data:") {
- message = msg
- break
+ content := ""
+ for _, chunk := range strings.Split(sseMessage, "\n\n") {
+ subMessages := strings.Split(chunk, "\n")
+ var message string
+ for _, msg := range subMessages {
+ if strings.HasPrefix(msg, "data:") {
+ message = msg
+ break
+ }
}
- }
- if len(message) < 6 {
- log.Errorf("invalid message:%s", message)
- return ""
- }
- // skip the prefix "data:"
- bodyJson := message[5:]
- if gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Exists() {
- tempContentI := ctx.GetContext(AnswerContentContextKey)
- if tempContentI == nil {
- content := TrimQuote(gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Raw)
- ctx.SetContext(AnswerContentContextKey, content)
+ if len(message) < 6 {
+ log.Errorf("invalid message:%s", message)
return content
}
- append := TrimQuote(gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Raw)
- content := tempContentI.(string) + append
- ctx.SetContext(AnswerContentContextKey, content)
- return content
- } else if gjson.Get(bodyJson, "choices.0.delta.content.tool_calls").Exists() {
- // TODO: compatible with other providers
- ctx.SetContext(ToolCallsContextKey, struct{}{})
- return ""
+ // skip the prefix "data:"
+ bodyJson := message[5:]
+ if gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Exists() {
+ tempContentI := ctx.GetContext(AnswerContentContextKey)
+ if tempContentI == nil {
+ content = TrimQuote(gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Raw)
+ ctx.SetContext(AnswerContentContextKey, content)
+ } else {
+ append := TrimQuote(gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Raw)
+ content = tempContentI.(string) + append
+ ctx.SetContext(AnswerContentContextKey, content)
+ }
+ } else if gjson.Get(bodyJson, "choices.0.delta.content.tool_calls").Exists() {
+ // TODO: compatible with other providers
+ ctx.SetContext(ToolCallsContextKey, struct{}{})
+ }
+ log.Debugf("unknown message:%s", bodyJson)
}
- log.Debugf("unknown message:%s", bodyJson)
- return ""
+ return content
}
func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md
index 8317f653d4..80b7c2a890 100644
--- a/plugins/wasm-go/extensions/ai-proxy/README.md
+++ b/plugins/wasm-go/extensions/ai-proxy/README.md
@@ -174,9 +174,10 @@ Mistral 所对应的 `type` 为 `mistral`。它并无特有的配置字段。
MiniMax所对应的 `type` 为 `minimax`。它特有的配置字段如下:
-| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
-| ---------------- | -------- | ------------------------------------------------------------ | ------ | ------------------------------------------------------------ |
-| `minimaxGroupId` | string | 当使用`abab6.5-chat`, `abab6.5s-chat`, `abab5.5s-chat`, `abab5.5-chat`四种模型时必填 | - | 当使用`abab6.5-chat`, `abab6.5s-chat`, `abab5.5s-chat`, `abab5.5-chat`四种模型时会使用ChatCompletion Pro,需要设置groupID |
+| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
+| ---------------- | -------- | ------------------------------ | ------ |----------------------------------------------------------------|
+| `minimaxApiType` | string | v2 和 pro 中选填一项 | v2 | v2 代表 ChatCompletion v2 API,pro 代表 ChatCompletion Pro API |
+| `minimaxGroupId` | string | `minimaxApiType` 为 pro 时必填 | - | `minimaxApiType` 为 pro 时使用 ChatCompletion Pro API,需要设置 groupID |
#### Anthropic Claude
@@ -1000,17 +1001,16 @@ provider:
apiTokens:
- "YOUR_MINIMAX_API_TOKEN"
modelMapping:
- "gpt-3": "abab6.5g-chat"
- "gpt-4": "abab6.5-chat"
- "*": "abab6.5g-chat"
- minimaxGroupId: "YOUR_MINIMAX_GROUP_ID"
+ "gpt-3": "abab6.5s-chat"
+ "gpt-4": "abab6.5g-chat"
+ "*": "abab6.5t-chat"
```
**请求示例**
```json
{
- "model": "gpt-4-turbo",
+ "model": "gpt-3",
"messages": [
{
"role": "user",
@@ -1025,27 +1025,33 @@ provider:
```json
{
- "id": "02b2251f8c6c09d68c1743f07c72afd7",
+ "id": "03ac4fcfe1c6cc9c6a60f9d12046e2b4",
"choices": [
{
"finish_reason": "stop",
"index": 0,
"message": {
- "content": "你好!我是MM智能助理,一款由MiniMax自研的大型语言模型。我可以帮助你解答问题,提供信息,进行对话等。有什么可以帮助你的吗?",
- "role": "assistant"
+ "content": "你好,我是一个由MiniMax公司研发的大型语言模型,名为MM智能助理。我可以帮助回答问题、提供信息、进行对话和执行多种语言处理任务。如果你有任何问题或需要帮助,请随时告诉我!",
+ "role": "assistant",
+ "name": "MM智能助理",
+ "audio_content": ""
}
}
],
- "created": 1717760544,
+ "created": 1734155471,
"model": "abab6.5s-chat",
"object": "chat.completion",
"usage": {
- "total_tokens": 106
+ "total_tokens": 116,
+ "total_characters": 0,
+ "prompt_tokens": 70,
+ "completion_tokens": 46
},
"input_sensitive": false,
"output_sensitive": false,
"input_sensitive_type": 0,
"output_sensitive_type": 0,
+ "output_sensitive_int": 0,
"base_resp": {
"status_code": 0,
"status_msg": ""
diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go
index 0bc62175e2..3f4dc49bab 100644
--- a/plugins/wasm-go/extensions/ai-proxy/main.go
+++ b/plugins/wasm-go/extensions/ai-proxy/main.go
@@ -89,29 +89,35 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
}
if apiName == "" {
- log.Debugf("[onHttpRequestHeader] unsupported path: %s", path.Path)
- // _ = util.SendResponse(404, "ai-proxy.unknown_api", util.MimeTypeTextPlain, "API not found: "+path.Path)
- log.Debugf("[onHttpRequestHeader] no send response")
+ log.Warnf("[onHttpRequestHeader] unsupported path: %s", path.Path)
return types.ActionContinue
}
+ // Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
+ ctx.DisableReroute()
+
ctx.SetContext(ctxKeyApiName, apiName)
+ _, needHandleBody := activeProvider.(provider.ResponseBodyHandler)
+ _, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler)
+ if needHandleBody || needHandleStreamingBody {
+ proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
+ }
+
if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok {
- // Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
- ctx.DisableReroute()
// Set the apiToken for the current request.
providerConfig.SetApiTokenInUse(ctx, log)
hasRequestBody := wrapper.HasRequestBody()
- action, err := handler.OnRequestHeaders(ctx, apiName, log)
+ err := handler.OnRequestHeaders(ctx, apiName, log)
if err == nil {
if hasRequestBody {
+ proxywasm.RemoveHttpRequestHeader("Content-Length")
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
- // Always return types.HeaderStopIteration to support fallback routing,
- // as long as onHttpRequestBody can be called.
+ // Delay the header processing to allow changing in OnRequestBody
return types.HeaderStopIteration
}
- return action
+ ctx.DontReadRequestBody()
+ return types.ActionContinue
}
util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err))
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go
index 6f42d570d0..fa5f1362c1 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go
@@ -40,13 +40,13 @@ func (m *ai360Provider) GetProviderType() string {
return providerTypeAi360
}
-func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
- return types.ActionContinue, errUnsupportedApiName
+ return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
- return types.HeaderStopIteration, nil
+ return nil
}
func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
@@ -58,7 +58,5 @@ func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
func (m *ai360Provider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestHostHeader(headers, ai360Domain)
- util.OverwriteRequestAuthorizationHeader(headers, "Authorization "+m.config.GetApiTokenInUse(ctx))
- headers.Del("Accept-Encoding")
- headers.Del("Content-Length")
+ util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx))
}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go
index b09cdd0951..9e02d0fd9a 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go
@@ -53,12 +53,12 @@ func (m *azureProvider) GetProviderType() string {
return providerTypeAzure
}
-func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
- return types.ActionContinue, errUnsupportedApiName
+ return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
- return types.ActionContinue, nil
+ return nil
}
func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
@@ -86,6 +86,6 @@ func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
util.OverwriteRequestPathHeader(headers, m.serviceUrl.RequestURI())
}
util.OverwriteRequestHostHeader(headers, m.serviceUrl.Host)
- util.OverwriteRequestAuthorizationHeader(headers, "api-key "+m.config.GetApiTokenInUse(ctx))
+ headers.Set("api-key", m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go
index b43ba8ee26..759c2dd036 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go
@@ -42,12 +42,12 @@ func (m *baichuanProvider) GetProviderType() string {
return providerTypeBaichuan
}
-func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
- return types.ActionContinue, errUnsupportedApiName
+ return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
- return types.ActionContinue, nil
+ return nil
}
func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go
index 0908836290..595ef3d4ff 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go
@@ -63,12 +63,12 @@ func (g *baiduProvider) GetProviderType() string {
return providerTypeBaidu
}
-func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
- return types.ActionContinue, errUnsupportedApiName
+ return errUnsupportedApiName
}
g.config.handleRequestHeaders(g, ctx, apiName, log)
- return types.ActionContinue, nil
+ return nil
}
func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go
index 8b98d62d64..9943469749 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go
@@ -102,27 +102,25 @@ func (c *claudeProvider) GetProviderType() string {
return providerTypeClaude
}
-func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
- return types.ActionContinue, errUnsupportedApiName
+ return errUnsupportedApiName
}
c.config.handleRequestHeaders(c, ctx, apiName, log)
- return types.ActionContinue, nil
+ return nil
}
func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, claudeChatCompletionPath)
util.OverwriteRequestHostHeader(headers, claudeDomain)
- headers.Add("x-api-key", c.config.GetApiTokenInUse(ctx))
+ headers.Set("x-api-key", c.config.GetApiTokenInUse(ctx))
if c.config.claudeVersion == "" {
c.config.claudeVersion = defaultVersion
}
- headers.Add("anthropic-version", c.config.claudeVersion)
- headers.Del("Accept-Encoding")
- headers.Del("Content-Length")
+ headers.Set("anthropic-version", c.config.claudeVersion)
}
func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go
index 2f6108b0df..4340183ee4 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go
@@ -42,12 +42,12 @@ func (c *cloudflareProvider) GetProviderType() string {
return providerTypeCloudflare
}
-func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
- return types.ActionContinue, errUnsupportedApiName
+ return errUnsupportedApiName
}
c.config.handleRequestHeaders(c, ctx, apiName, log)
- return types.ActionContinue, nil
+ return nil
}
func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
@@ -61,6 +61,4 @@ func (c *cloudflareProvider) TransformRequestHeaders(ctx wrapper.HttpContext, ap
util.OverwriteRequestPathHeader(headers, strings.Replace(cloudflareChatCompletionPath, "{account_id}", c.config.cloudflareAccountId, 1))
util.OverwriteRequestHostHeader(headers, cloudflareDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+c.config.GetApiTokenInUse(ctx))
- headers.Del("Accept-Encoding")
- headers.Del("Content-Length")
}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go
index 72dbaf280b..a3b930e7fb 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go
@@ -3,11 +3,12 @@ package provider
import (
"encoding/json"
"errors"
+ "net/http"
+ "strings"
+
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
- "net/http"
- "strings"
)
const (
@@ -54,12 +55,12 @@ func (m *cohereProvider) GetProviderType() string {
return providerTypeCohere
}
-func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
- return types.ActionContinue, errUnsupportedApiName
+ return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
- return types.ActionContinue, nil
+ return nil
}
func (m *cohereProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/coze.go b/plugins/wasm-go/extensions/ai-proxy/provider/coze.go
index 878bbb9f9a..43cdca60fb 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/coze.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/coze.go
@@ -6,7 +6,6 @@ import (
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
- "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
const (
@@ -38,9 +37,9 @@ func (m *cozeProvider) GetProviderType() string {
return providerTypeCoze
}
-func (m *cozeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+func (m *cozeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
m.config.handleRequestHeaders(m, ctx, apiName, log)
- return types.ActionContinue, nil
+ return nil
}
func (m *cozeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go
index bafe6b3dde..345a70c94a 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go
@@ -76,19 +76,17 @@ func (d *deeplProvider) GetProviderType() string {
return providerTypeDeepl
}
-func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
- return types.ActionContinue, errUnsupportedApiName
+ return errUnsupportedApiName
}
d.config.handleRequestHeaders(d, ctx, apiName, log)
- return types.HeaderStopIteration, nil
+ return nil
}
func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, deeplChatCompletionPath)
util.OverwriteRequestAuthorizationHeader(headers, "DeepL-Auth-Key "+d.config.GetApiTokenInUse(ctx))
- headers.Del("Content-Length")
- headers.Del("Accept-Encoding")
}
func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go
index 9cad3928f5..7d240f09ae 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go
@@ -2,10 +2,11 @@ package provider
import (
"errors"
+ "net/http"
+
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
- "net/http"
)
// deepseekProvider is the provider for deepseek Ai service.
@@ -41,12 +42,12 @@ func (m *deepseekProvider) GetProviderType() string {
return providerTypeDeepSeek
}
-func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
- return types.ActionContinue, errUnsupportedApiName
+ return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
- return types.ActionContinue, nil
+ return nil
}
func (m *deepseekProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go
index 651b983206..96a4aab548 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go
@@ -2,11 +2,12 @@ package provider
import (
"errors"
+ "net/http"
+ "strings"
+
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
- "net/http"
- "strings"
)
const (
@@ -39,12 +40,12 @@ func (m *doubaoProvider) GetProviderType() string {
return providerTypeDoubao
}
-func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
- return types.ActionContinue, errUnsupportedApiName
+ return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
- return types.ActionContinue, nil
+ return nil
}
func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go
index a4c1ef2cd9..7a9b0a3dd0 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go
@@ -51,20 +51,18 @@ func (g *geminiProvider) GetProviderType() string {
return providerTypeGemini
}
-func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
- return types.ActionContinue, errUnsupportedApiName
+ return errUnsupportedApiName
}
g.config.handleRequestHeaders(g, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
- return types.HeaderStopIteration, nil
+ return nil
}
func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestHostHeader(headers, geminiDomain)
- headers.Add(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx))
- headers.Del("Accept-Encoding")
- headers.Del("Content-Length")
+ headers.Set(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx))
}
func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/github.go b/plugins/wasm-go/extensions/ai-proxy/provider/github.go
index 0a2b0c84de..348134c0a5 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/github.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/github.go
@@ -2,11 +2,12 @@ package provider
import (
"errors"
+ "net/http"
+ "strings"
+
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
- "net/http"
- "strings"
)
// githubProvider is the provider for GitHub OpenAI service.
@@ -42,13 +43,13 @@ func (m *githubProvider) GetProviderType() string {
return providerTypeGithub
}
-func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
- return types.ActionContinue, errUnsupportedApiName
+ return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
- return types.HeaderStopIteration, nil
+ return nil
}
func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
@@ -67,8 +68,6 @@ func (m *githubProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
util.OverwriteRequestPathHeader(headers, githubEmbeddingPath)
}
util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx))
- headers.Del("Accept-Encoding")
- headers.Del("Content-Length")
}
func (m *githubProvider) GetApiName(path string) ApiName {
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go
index dfbd971261..5f2734519d 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go
@@ -41,12 +41,12 @@ func (g *groqProvider) GetProviderType() string {
return providerTypeGroq
}
-func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
- return types.ActionContinue, errUnsupportedApiName
+ return errUnsupportedApiName
}
g.config.handleRequestHeaders(g, ctx, apiName, log)
- return types.ActionContinue, nil
+ return nil
}
func (g *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go
index b6a49eb551..4b10a4d7c5 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go
@@ -114,13 +114,13 @@ func (m *hunyuanProvider) GetProviderType() string {
return providerTypeHunyuan
}
-func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
- return types.ActionContinue, errUnsupportedApiName
+ return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
- return types.HeaderStopIteration, nil
+ return nil
}
func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
@@ -128,11 +128,8 @@ func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa
util.OverwriteRequestPathHeader(headers, hunyuanRequestPath)
// 添加 hunyuan 需要的自定义字段
- headers.Add(actionKey, hunyuanChatCompletionTCAction)
- headers.Add(versionKey, versionValue)
-
- headers.Del("Accept-Encoding")
- headers.Del("Content-Length")
+ headers.Set(actionKey, hunyuanChatCompletionTCAction)
+ headers.Set(versionKey, versionValue)
}
// hunyuan 的 OnRequestBody 逻辑中包含了对 headers 签名的逻辑,并且插入 context 以后还要重新计算签名,因此无法复用 handleRequestBody 方法
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go
index 0bcf7ac326..9531edcf11 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go
@@ -11,47 +11,37 @@ import (
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
)
// minimaxProvider is the provider for minimax service.
const (
- minimaxDomain = "api.minimax.chat"
- // minimaxChatCompletionV2Path 接口请求响应格式与OpenAI相同
- // 接口文档: https://platform.minimaxi.com/document/guides/chat-model/V2?id=65e0736ab2845de20908e2dd
+ minimaxApiTypeV2 = "v2" // minimaxApiTypeV2 represents chat completion V2 API.
+ minimaxApiTypePro = "pro" // minimaxApiTypePro represents chat completion Pro API.
+ minimaxDomain = "api.minimax.chat"
+ // minimaxChatCompletionV2Path represents the API path for chat completion V2 API which has a response format similar to OpenAI's.
minimaxChatCompletionV2Path = "/v1/text/chatcompletion_v2"
- // minimaxChatCompletionProPath 接口请求响应格式与OpenAI不同
- // 接口文档: https://platform.minimaxi.com/document/guides/chat-model/pro/api?id=6569c85948bc7b684b30377e
+ // minimaxChatCompletionProPath represents the API path for chat completion Pro API which has a different response format from OpenAI's.
minimaxChatCompletionProPath = "/v1/text/chatcompletion_pro"
- senderTypeUser string = "USER" // 用户发送的内容
- senderTypeBot string = "BOT" // 模型生成的内容
+ senderTypeUser string = "USER" // Content sent by the user.
+ senderTypeBot string = "BOT" // Content generated by the model.
- // 默认机器人设置
+ // Default bot settings.
defaultBotName string = "MM智能助理"
defaultBotSettingContent string = "MM智能助理是一款由MiniMax自研的,没有调用其他产品的接口的大型语言模型。MiniMax是一家中国科技公司,一直致力于进行大模型相关的研究。"
defaultSenderName string = "小明"
)
-// chatCompletionProModels 这些模型对应接口为ChatCompletion Pro
-var chatCompletionProModels = map[string]struct{}{
- "abab6.5-chat": {},
- "abab6.5s-chat": {},
- "abab5.5s-chat": {},
- "abab5.5-chat": {},
-}
-
type minimaxProviderInitializer struct {
}
func (m *minimaxProviderInitializer) ValidateConfig(config ProviderConfig) error {
- // 如果存在模型对应接口为ChatCompletion Pro必须配置minimaxGroupId
- if len(config.modelMapping) > 0 && config.minimaxGroupId == "" {
- for _, minimaxModel := range config.modelMapping {
- if _, exists := chatCompletionProModels[minimaxModel]; exists {
- return errors.New(fmt.Sprintf("missing minimaxGroupId in provider config when %s model is provided", minimaxModel))
- }
- }
+ // If using the chat completion Pro API, a group ID must be set.
+ if minimaxApiTypePro == config.minimaxApiType && config.minimaxGroupId == "" {
+ return errors.New(fmt.Sprintf("missing minimaxGroupId in provider config when minimaxApiType is %s", minimaxApiTypePro))
}
if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
@@ -75,13 +65,13 @@ func (m *minimaxProvider) GetProviderType() string {
return providerTypeMinimax
}
-func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
- return types.ActionContinue, errUnsupportedApiName
+ return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
- return types.HeaderStopIteration, nil
+ return nil
}
func (m *minimaxProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
@@ -94,23 +84,11 @@ func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
- // 解析并映射模型,设置上下文
- model, err := m.parseModel(body)
- if err != nil {
- return types.ActionContinue, err
- }
- ctx.SetContext(ctxKeyOriginalRequestModel, model)
- mappedModel := getMappedModel(model, m.config.modelMapping, log)
- if mappedModel == "" {
- return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
- }
- ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
- _, ok := chatCompletionProModels[mappedModel]
- if ok {
- // 使用ChatCompletion Pro接口
+ if minimaxApiTypePro == m.config.minimaxApiType {
+ // Use chat completion Pro API.
return m.handleRequestBodyByChatCompletionPro(body, log)
} else {
- // 使用ChatCompletion v2接口
+ // Use chat completion V2 API.
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
}
@@ -119,14 +97,14 @@ func (m *minimaxProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, a
return m.handleRequestBodyByChatCompletionV2(body, headers, log)
}
-// handleRequestBodyByChatCompletionPro 使用ChatCompletion Pro接口处理请求体
+// handleRequestBodyByChatCompletionPro processes the request body using the chat completion Pro API.
func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log wrapper.Log) (types.Action, error) {
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
- // 映射模型重写requestPath
+ // Map the model and rewrite the request path.
request.Model = getMappedModel(request.Model, m.config.modelMapping, log)
_ = util.OverwriteRequestPath(fmt.Sprintf("%s?GroupId=%s", minimaxChatCompletionProPath, m.config.minimaxGroupId))
@@ -143,9 +121,9 @@ func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log
log.Errorf("failed to load context file: %v", err)
util.ErrorHandler("ai-proxy.minimax.load_ctx_failed", fmt.Errorf("failed to load context file: %v", err))
}
- // 由于 minimaxChatCompletionV2(格式和 OpenAI 一致)和 minimaxChatCompletionPro(格式和 OpenAI 不一致)中 insertHttpContextMessage 的逻辑不同,无法做到同一个 provider 统一
- // 因此对于 minimaxChatCompletionPro 需要手动处理 context 消息
- // minimaxChatCompletionV2 交给默认的 defaultInsertHttpContextMessage 方法插入 context 消息
+ // Since minimaxChatCompletionV2 (format consistent with OpenAI) and minimaxChatCompletionPro (different format from OpenAI) have different logic for insertHttpContextMessage, we cannot unify them within one provider.
+ // For minimaxChatCompletionPro, we need to manually handle context messages.
+ // minimaxChatCompletionV2 uses the default defaultInsertHttpContextMessage method to insert context messages.
minimaxRequest := m.buildMinimaxChatCompletionV2Request(request, content)
if err := replaceJsonRequestBody(minimaxRequest, log); err != nil {
util.ErrorHandler("ai-proxy.minimax.insert_ctx_failed", fmt.Errorf("failed to replace Request body: %v", err))
@@ -157,54 +135,45 @@ func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log
return types.ActionContinue, err
}
-// handleRequestBodyByChatCompletionV2 使用ChatCompletion v2接口处理请求体
+// handleRequestBodyByChatCompletionV2 processes the request body using the chat completion V2 API.
func (m *minimaxProvider) handleRequestBodyByChatCompletionV2(body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
- request := &chatCompletionRequest{}
- if err := decodeChatCompletionRequest(body, request); err != nil {
- return nil, err
- }
-
- // 映射模型重写requestPath
- request.Model = getMappedModel(request.Model, m.config.modelMapping, log)
util.OverwriteRequestPathHeader(headers, minimaxChatCompletionV2Path)
- return body, nil
+ rawModel := gjson.GetBytes(body, "model").String()
+ mappedModel := getMappedModel(rawModel, m.config.modelMapping, log)
+ return sjson.SetBytes(body, "model", mappedModel)
}
func (m *minimaxProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
- // 使用minimax接口协议,跳过OnStreamingResponseBody()和OnResponseBody()
+ // Skip OnStreamingResponseBody() and OnResponseBody() when using original protocol.
if m.config.protocol == protocolOriginal {
ctx.DontReadResponseBody()
return types.ActionContinue, nil
}
- // 模型对应接口为ChatCompletion v2,跳过OnStreamingResponseBody()和OnResponseBody()
- model := ctx.GetStringContext(ctxKeyFinalRequestModel, "")
- if model != "" {
- _, ok := chatCompletionProModels[model]
- if !ok {
- ctx.DontReadResponseBody()
- return types.ActionContinue, nil
- }
+ // Skip OnStreamingResponseBody() and OnResponseBody() when the model corresponds to the chat completion V2 interface.
+ if minimaxApiTypePro != m.config.minimaxApiType {
+ ctx.DontReadResponseBody()
+ return types.ActionContinue, nil
}
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
return types.ActionContinue, nil
}
-// OnStreamingResponseBody 只处理使用OpenAI协议 且 模型对应接口为ChatCompletion Pro的流式响应
+// OnStreamingResponseBody handles streaming response chunks from the Minimax service only for requests using the OpenAI protocol and corresponding to the chat completion Pro API.
func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
if isLastChunk || len(chunk) == 0 {
return nil, nil
}
- // sample event response:
+ // Sample event response:
// data: {"created":1689747645,"model":"abab6.5s-chat","reply":"","choices":[{"messages":[{"sender_type":"BOT","sender_name":"MM智能助理","text":"am from China."}]}],"output_sensitive":false}
- // sample end event response:
+ // Sample end event response:
// data: {"created":1689747645,"model":"abab6.5s-chat","reply":"I am from China.","choices":[{"finish_reason":"stop","messages":[{"sender_type":"BOT","sender_name":"MM智能助理","text":"I am from China."}]}],"usage":{"total_tokens":187},"input_sensitive":false,"output_sensitive":false,"id":"0106b3bc9fd844a9f3de1aa06004e2ab","base_resp":{"status_code":0,"status_msg":""}}
responseBuilder := &strings.Builder{}
lines := strings.Split(string(chunk), "\n")
for _, data := range lines {
if len(data) < 6 {
- // ignore blank line or wrong format
+ // Ignore blank line or improperly formatted lines.
continue
}
data = data[6:]
@@ -226,7 +195,7 @@ func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name
return []byte(modifiedResponseChunk), nil
}
-// OnResponseBody 只处理使用OpenAI协议 且 模型对应接口为ChatCompletion Pro的流式响应
+// OnResponseBody handles the final response body from the Minimax service only for requests using the OpenAI protocol and corresponding to the chat completion Pro API.
func (m *minimaxProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
minimaxResp := &minimaxChatCompletionV2Resp{}
if err := json.Unmarshal(body, minimaxResp); err != nil {
@@ -239,39 +208,39 @@ func (m *minimaxProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiNam
return types.ActionContinue, replaceJsonResponseBody(response, log)
}
-// minimaxChatCompletionV2Request 表示ChatCompletion V2请求的结构体
+// minimaxChatCompletionV2Request represents the structure of a chat completion V2 request.
type minimaxChatCompletionV2Request struct {
Model string `json:"model"`
Stream bool `json:"stream,omitempty"`
TokensToGenerate int64 `json:"tokens_to_generate,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
- MaskSensitiveInfo bool `json:"mask_sensitive_info"` // 是否开启隐私信息打码,默认true
+ MaskSensitiveInfo bool `json:"mask_sensitive_info"` // Whether to mask sensitive information, defaults to true.
Messages []minimaxMessage `json:"messages"`
BotSettings []minimaxBotSetting `json:"bot_setting"`
ReplyConstraints minimaxReplyConstraints `json:"reply_constraints"`
}
-// minimaxMessage 表示对话中的消息
+// minimaxMessage represents a message in the conversation.
type minimaxMessage struct {
SenderType string `json:"sender_type"`
SenderName string `json:"sender_name"`
Text string `json:"text"`
}
-// minimaxBotSetting 表示机器人的设置
+// minimaxBotSetting represents the bot's settings.
type minimaxBotSetting struct {
BotName string `json:"bot_name"`
Content string `json:"content"`
}
-// minimaxReplyConstraints 表示模型回复要求
+// minimaxReplyConstraints represents requirements for model replies.
type minimaxReplyConstraints struct {
SenderType string `json:"sender_type"`
SenderName string `json:"sender_name"`
}
-// minimaxChatCompletionV2Resp Minimax Chat Completion V2响应结构体
+// minimaxChatCompletionV2Resp represents the structure of a Minimax Chat Completion V2 response.
type minimaxChatCompletionV2Resp struct {
Created int64 `json:"created"`
Model string `json:"model"`
@@ -286,20 +255,20 @@ type minimaxChatCompletionV2Resp struct {
BaseResp minimaxBaseResp `json:"base_resp"`
}
-// minimaxBaseResp 包含错误状态码和详情
+// minimaxBaseResp contains error status code and details.
type minimaxBaseResp struct {
StatusCode int64 `json:"status_code"`
StatusMsg string `json:"status_msg"`
}
-// minimaxChoice 结果选项
+// minimaxChoice represents a result option.
type minimaxChoice struct {
Messages []minimaxMessage `json:"messages"`
Index int64 `json:"index"`
FinishReason string `json:"finish_reason"`
}
-// minimaxUsage 令牌使用情况
+// minimaxUsage represents token usage statistics.
type minimaxUsage struct {
TotalTokens int64 `json:"total_tokens"`
}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go
index 3e5323a60c..041665f9dd 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go
@@ -2,10 +2,11 @@ package provider
import (
"errors"
+ "net/http"
+
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
- "net/http"
)
const (
@@ -37,12 +38,12 @@ func (m *mistralProvider) GetProviderType() string {
return providerTypeMistral
}
-func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
- return types.ActionContinue, errUnsupportedApiName
+ return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
- return types.ActionContinue, nil
+ return nil
}
func (m *mistralProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go
index 38d99ae0eb..733cc038b4 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go
@@ -56,12 +56,12 @@ func (m *moonshotProvider) GetProviderType() string {
return providerTypeMoonshot
}
-func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
- return types.ActionContinue, errUnsupportedApiName
+ return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
- return types.ActionContinue, nil
+ return nil
}
func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go
index 5339083819..1bed639f33 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go
@@ -3,10 +3,11 @@ package provider
import (
"errors"
"fmt"
+ "net/http"
+
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
- "net/http"
)
// ollamaProvider is the provider for Ollama service.
@@ -48,12 +49,12 @@ func (m *ollamaProvider) GetProviderType() string {
return providerTypeOllama
}
-func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
- return types.ActionContinue, errUnsupportedApiName
+ return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
- return types.ActionContinue, nil
+ return nil
}
func (m *ollamaProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go
index 60c835cd49..480fdda571 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go
@@ -57,9 +57,9 @@ func (m *openaiProvider) GetProviderType() string {
return providerTypeOpenAI
}
-func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
m.config.handleRequestHeaders(m, ctx, apiName, log)
- return types.ActionContinue, nil
+ return nil
}
func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go
index 478f7a24b6..0f482732aa 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go
@@ -118,7 +118,7 @@ type ApiNameHandler interface {
}
type RequestHeadersHandler interface {
- OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error)
+ OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error
}
type TransformRequestHeadersHandler interface {
@@ -206,8 +206,11 @@ type ProviderConfig struct {
// @Title zh-CN hunyuan api id for authorization
// @Description zh-CN 仅适用于Hun Yuan AI服务鉴权
hunyuanAuthId string `required:"false" yaml:"hunyuanAuthId" json:"hunyuanAuthId"`
+ // @Title zh-CN minimax API type
+ // @Description zh-CN 仅适用于 minimax 服务。minimax API 类型,v2 和 pro 中选填一项,默认值为 v2
+ minimaxApiType string `required:"false" yaml:"minimaxApiType" json:"minimaxApiType"`
// @Title zh-CN minimax group id
- // @Description zh-CN 仅适用于minimax使用ChatCompletion Pro接口的模型
+ // @Description zh-CN 仅适用于 minimax 服务。minimax API 类型为 pro 时必填
minimaxGroupId string `required:"false" yaml:"minimaxGroupId" json:"minimaxGroupId"`
// @Title zh-CN 模型名称映射表
// @Description zh-CN 用于将请求中的模型名称映射为目标AI服务商支持的模型名称。支持通过“*”来配置全局映射
@@ -303,6 +306,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
c.claudeVersion = json.Get("claudeVersion").String()
c.hunyuanAuthId = json.Get("hunyuanAuthId").String()
c.hunyuanAuthKey = json.Get("hunyuanAuthKey").String()
+ c.minimaxApiType = json.Get("minimaxApiType").String()
c.minimaxGroupId = json.Get("minimaxGroupId").String()
c.cloudflareAccountId = json.Get("cloudflareAccountId").String()
if c.typ == providerTypeGemini {
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go
index a4a727724e..95fe28e4bd 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go
@@ -27,6 +27,7 @@ const (
qwenChatCompletionPath = "/api/v1/services/aigc/text-generation/generation"
qwenTextEmbeddingPath = "/api/v1/services/embeddings/text-embedding/text-embedding"
qwenCompatiblePath = "/compatible-mode/v1/chat/completions"
+ qwenBailianPath = "/api/v1/apps"
qwenMultimodalGenerationPath = "/api/v1/services/aigc/multimodal-generation/generation"
qwenTopPMin = 0.000001
@@ -71,16 +72,14 @@ func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
}
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
- if m.config.qwenEnableCompatible {
+ if m.config.IsOriginal() {
+ } else if m.config.qwenEnableCompatible {
util.OverwriteRequestPathHeader(headers, qwenCompatiblePath)
} else if apiName == ApiNameChatCompletion {
util.OverwriteRequestPathHeader(headers, qwenChatCompletionPath)
} else if apiName == ApiNameEmbeddings {
util.OverwriteRequestPathHeader(headers, qwenTextEmbeddingPath)
}
-
- headers.Del("Accept-Encoding")
- headers.Del("Content-Length")
}
func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
@@ -95,20 +94,19 @@ func (m *qwenProvider) GetProviderType() string {
return providerTypeQwen
}
-func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
- return types.ActionContinue, errUnsupportedApiName
+ return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
if m.config.protocol == protocolOriginal {
ctx.DontReadRequestBody()
- return types.ActionContinue, nil
+ return nil
}
- // Delay the header processing to allow changing streaming mode in OnRequestBody
- return types.HeaderStopIteration, nil
+ return nil
}
func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
@@ -762,6 +760,7 @@ func (m *qwenProvider) GetApiName(path string) ApiName {
switch {
case strings.Contains(path, qwenChatCompletionPath),
strings.Contains(path, qwenMultimodalGenerationPath),
+ strings.Contains(path, qwenBailianPath),
strings.Contains(path, qwenCompatiblePath):
return ApiNameChatCompletion
case strings.Contains(path, qwenTextEmbeddingPath):
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go
index c2e013643c..f44b9e3c0f 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go
@@ -67,12 +67,12 @@ func (p *sparkProvider) GetProviderType() string {
return providerTypeSpark
}
-func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
- return types.ActionContinue, errUnsupportedApiName
+ return errUnsupportedApiName
}
p.config.handleRequestHeaders(p, ctx, apiName, log)
- return types.ActionContinue, nil
+ return nil
}
func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
@@ -177,6 +177,4 @@ func (p *sparkProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
util.OverwriteRequestPathHeader(headers, sparkChatCompletionPath)
util.OverwriteRequestHostHeader(headers, sparkHost)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+p.config.GetApiTokenInUse(ctx))
- headers.Del("Accept-Encoding")
- headers.Del("Content-Length")
}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go
index 1ee01abe62..4f642c5f6c 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go
@@ -2,10 +2,11 @@ package provider
import (
"errors"
+ "net/http"
+
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
- "net/http"
)
const (
@@ -39,12 +40,12 @@ func (m *stepfunProvider) GetProviderType() string {
return providerTypeStepfun
}
-func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
- return types.ActionContinue, errUnsupportedApiName
+ return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
- return types.ActionContinue, nil
+ return nil
}
func (m *stepfunProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go
index 7cb05a9388..e80148ca0c 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go
@@ -40,12 +40,12 @@ func (m *yiProvider) GetProviderType() string {
return providerTypeYi
}
-func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
- return types.ActionContinue, errUnsupportedApiName
+ return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
- return types.ActionContinue, nil
+ return nil
}
func (m *yiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go
index 40fbe4ef88..9c30adb10d 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go
@@ -40,12 +40,12 @@ func (m *zhipuAiProvider) GetProviderType() string {
return providerTypeZhipuAi
}
-func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
- return types.ActionContinue, errUnsupportedApiName
+ return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
- return types.ActionContinue, nil
+ return nil
}
func (m *zhipuAiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
diff --git a/plugins/wasm-go/extensions/ai-security-guard/README.md b/plugins/wasm-go/extensions/ai-security-guard/README.md
index 68eeeae202..a005299da3 100644
--- a/plugins/wasm-go/extensions/ai-security-guard/README.md
+++ b/plugins/wasm-go/extensions/ai-security-guard/README.md
@@ -31,6 +31,7 @@ description: 阿里云内容安全检测
| `denyMessage` | string | optional | openai格式的流式/非流式响应 | 指定内容非法时的响应内容 |
| `protocol` | string | optional | openai | 协议格式,非openai协议填`original` |
| `riskLevelBar` | string | optional | high | 拦截风险等级,取值为 max, high, medium, low |
+| `timeout` | int | optional | 2000 | 调用内容安全服务时的超时时间 |
补充说明一下 `denyMessage`,对非法请求的处理逻辑为:
- 如果配置了 `denyMessage`,返回内容为 `denyMessage` 配置内容,格式为openai格式的流式/非流式响应
diff --git a/plugins/wasm-go/extensions/ai-security-guard/main.go b/plugins/wasm-go/extensions/ai-security-guard/main.go
index f4aee5632b..0e0a747fa1 100644
--- a/plugins/wasm-go/extensions/ai-security-guard/main.go
+++ b/plugins/wasm-go/extensions/ai-security-guard/main.go
@@ -53,6 +53,7 @@ const (
DefaultStreamingResponseJsonPath = "choices.0.delta.content"
DefaultDenyCode = 200
DefaultDenyMessage = "很抱歉,我无法回答您的问题"
+ DefaultTimeout = 2000
AliyunUserAgent = "CIPFrom/AIGateway"
LengthLimit = 1800
@@ -100,6 +101,7 @@ type AISecurityConfig struct {
denyMessage string
protocolOriginal bool
riskLevelBar string
+ timeout uint32
metrics map[string]proxywasm.MetricCounter
}
@@ -225,6 +227,11 @@ func parseConfig(json gjson.Result, config *AISecurityConfig, log wrapper.Log) e
} else {
config.riskLevelBar = HighRisk
}
+ if obj := json.Get("timeout"); obj.Exists() {
+ config.timeout = uint32(obj.Int())
+ } else {
+ config.timeout = DefaultTimeout
+ }
config.client = wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: serviceName,
Port: servicePort,
@@ -253,6 +260,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log
func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
log.Debugf("checking request body...")
+ startTime := time.Now().UnixMilli()
content := gjson.GetBytes(body, config.requestContentJsonPath).String()
model := gjson.GetBytes(body, "model").String()
ctx.SetContext("requestModel", model)
@@ -279,6 +287,10 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
}
if riskLevelToInt(response.Data.RiskLevel) < riskLevelToInt(config.riskLevelBar) {
if contentIndex >= len(content) {
+ endTime := time.Now().UnixMilli()
+ ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
+ ctx.SetUserAttribute("safecheck_status", "request pass")
+ ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
proxywasm.ResumeHttpRequest()
} else {
singleCall()
@@ -305,7 +317,9 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
}
ctx.DontReadResponseBody()
config.incrementCounter("ai_sec_request_deny", 1)
- ctx.SetUserAttribute("safecheck_status", "request deny")
+ endTime := time.Now().UnixMilli()
+ ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
+ ctx.SetUserAttribute("safecheck_status", "reqeust deny")
if response.Data.Advice != nil {
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
@@ -345,7 +359,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
reqParams.Add(k, v)
}
reqParams.Add("Signature", signature)
- err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback)
+ err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback, config.timeout)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
proxywasm.ResumeHttpRequest()
@@ -364,45 +378,26 @@ func convertHeaders(hs [][2]string) map[string][]string {
return ret
}
-// headers: map[string][]string -> [][2]string
-func reconvertHeaders(hs map[string][]string) [][2]string {
- var ret [][2]string
- for k, vs := range hs {
- for _, v := range vs {
- ret = append(ret, [2]string{k, v})
- }
- }
- sort.SliceStable(ret, func(i, j int) bool {
- return ret[i][0] < ret[j][0]
- })
- return ret
-}
-
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action {
if !config.checkResponse {
log.Debugf("response checking is disabled")
ctx.DontReadResponseBody()
return types.ActionContinue
}
- headers, err := proxywasm.GetHttpResponseHeaders()
- if err != nil {
- log.Warnf("failed to get response headers: %v", err)
- return types.ActionContinue
- }
- hdsMap := convertHeaders(headers)
- if !strings.Contains(strings.Join(hdsMap[":status"], ";"), "200") {
+ statusCode, _ := proxywasm.GetHttpResponseHeader(":status")
+ if statusCode != "200" {
log.Debugf("response is not 200, skip response body check")
ctx.DontReadResponseBody()
return types.ActionContinue
}
- ctx.SetContext("headers", hdsMap)
return types.HeaderStopIteration
}
func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
log.Debugf("checking response body...")
- hdsMap := ctx.GetContext("headers").(map[string][]string)
- isStreamingResponse := strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream")
+ startTime := time.Now().UnixMilli()
+ contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
+ isStreamingResponse := strings.Contains(contentType, "event-stream")
model := ctx.GetStringContext("requestModel", "unknown")
var content string
if isStreamingResponse {
@@ -433,6 +428,10 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
}
if riskLevelToInt(response.Data.RiskLevel) < riskLevelToInt(config.riskLevelBar) {
if contentIndex >= len(content) {
+ endTime := time.Now().UnixMilli()
+ ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
+ ctx.SetUserAttribute("safecheck_status", "response pass")
+ ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
proxywasm.ResumeHttpResponse()
} else {
singleCall()
@@ -458,6 +457,8 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
config.incrementCounter("ai_sec_response_deny", 1)
+ endTime := time.Now().UnixMilli()
+ ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "response deny")
if response.Data.Advice != nil {
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
@@ -498,7 +499,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
reqParams.Add(k, v)
}
reqParams.Add("Signature", signature)
- err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback)
+ err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback, config.timeout)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
proxywasm.ResumeHttpResponse()
diff --git a/plugins/wasm-go/extensions/ai-statistics/go.sum b/plugins/wasm-go/extensions/ai-statistics/go.sum
index 6b1c2c3cd7..b4ab172fe2 100644
--- a/plugins/wasm-go/extensions/ai-statistics/go.sum
+++ b/plugins/wasm-go/extensions/ai-statistics/go.sum
@@ -3,15 +3,13 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
-github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
-github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
+github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU=
github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
-github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
-github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
+github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94=
github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
diff --git a/plugins/wasm-go/extensions/ai-statistics/main.go b/plugins/wasm-go/extensions/ai-statistics/main.go
index 14fcc4d2ab..363f59194e 100644
--- a/plugins/wasm-go/extensions/ai-statistics/main.go
+++ b/plugins/wasm-go/extensions/ai-statistics/main.go
@@ -5,7 +5,6 @@ import (
"encoding/json"
"errors"
"fmt"
- "strconv"
"strings"
"time"
@@ -28,14 +27,15 @@ func main() {
}
const (
- // Trace span prefix
- TracePrefix = "trace_span_tag."
// Context consts
StatisticsRequestStartTime = "ai-statistics-request-start-time"
StatisticsFirstTokenTime = "ai-statistics-first-token-time"
CtxGeneralAtrribute = "attributes"
CtxLogAtrribute = "logAttributes"
CtxStreamingBodyBuffer = "streamingBodyBuffer"
+ RouteName = "route"
+ ClusterName = "cluster"
+ APIName = "api"
// Source Type
FixedValue = "fixed_value"
@@ -46,12 +46,14 @@ const (
ResponseBody = "response_body"
// Inner metric & log attributes name
- Model = "model"
- InputToken = "input_token"
- OutputToken = "output_token"
- LLMFirstTokenDuration = "llm_first_token_duration"
- LLMServiceDuration = "llm_service_duration"
- LLMDurationCount = "llm_duration_count"
+ Model = "model"
+ InputToken = "input_token"
+ OutputToken = "output_token"
+ LLMFirstTokenDuration = "llm_first_token_duration"
+ LLMServiceDuration = "llm_service_duration"
+ LLMDurationCount = "llm_duration_count"
+ LLMStreamDurationCount = "llm_stream_duration_count"
+ ResponseType = "response_type"
// Extract Rule
RuleFirst = "first"
@@ -91,6 +93,19 @@ func getRouteName() (string, error) {
}
}
+func getAPIName() (string, error) {
+ if raw, err := proxywasm.GetProperty([]string{"route_name"}); err != nil {
+ return "-", err
+ } else {
+ parts := strings.Split(string(raw), "@")
+ if len(parts) != 5 {
+ return "-", errors.New("not api type")
+ } else {
+ return strings.Join(parts[:3], "@"), nil
+ }
+ }
+}
+
func getClusterName() (string, error) {
if raw, err := proxywasm.GetProperty([]string{"cluster_name"}); err != nil {
return "-", err
@@ -133,8 +148,15 @@ func parseConfig(configJson gjson.Result, config *AIStatisticsConfig, log wrappe
}
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) types.Action {
- ctx.SetContext(CtxGeneralAtrribute, map[string]string{})
- ctx.SetContext(CtxLogAtrribute, map[string]string{})
+ route, _ := getRouteName()
+ cluster, _ := getClusterName()
+ api, api_error := getAPIName()
+ if api_error == nil {
+ route = api
+ }
+ ctx.SetContext(RouteName, route)
+ ctx.SetContext(ClusterName, cluster)
+ ctx.SetUserAttribute(APIName, api)
ctx.SetContext(StatisticsRequestStartTime, time.Now().UnixMilli())
// Set user defined log & span attributes which type is fixed_value
@@ -149,6 +171,9 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, lo
func onHttpRequestBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log wrapper.Log) types.Action {
// Set user defined log & span attributes.
setAttributeBySource(ctx, config, RequestBody, body, log)
+
+ // Write log
+ ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
return types.ActionContinue
}
@@ -177,6 +202,8 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat
ctx.SetContext(CtxStreamingBodyBuffer, streamingBodyBuffer)
}
+ ctx.SetUserAttribute(ResponseType, "stream")
+
// Get requestStartTime from http context
requestStartTime, ok := ctx.GetContext(StatisticsRequestStartTime).(int64)
if !ok {
@@ -188,28 +215,19 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat
if ctx.GetContext(StatisticsFirstTokenTime) == nil {
firstTokenTime := time.Now().UnixMilli()
ctx.SetContext(StatisticsFirstTokenTime, firstTokenTime)
- attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string)
- attributes[LLMFirstTokenDuration] = fmt.Sprint(firstTokenTime - requestStartTime)
- ctx.SetContext(CtxGeneralAtrribute, attributes)
+ ctx.SetUserAttribute(LLMFirstTokenDuration, firstTokenTime-requestStartTime)
}
// Set information about this request
-
if model, inputToken, outputToken, ok := getUsage(data); ok {
- attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string)
- // Record Log Attributes
- attributes[Model] = model
- attributes[InputToken] = fmt.Sprint(inputToken)
- attributes[OutputToken] = fmt.Sprint(outputToken)
- // Set attributes to http context
- ctx.SetContext(CtxGeneralAtrribute, attributes)
+ ctx.SetUserAttribute(Model, model)
+ ctx.SetUserAttribute(InputToken, inputToken)
+ ctx.SetUserAttribute(OutputToken, outputToken)
}
// If the end of the stream is reached, record metrics/logs/spans.
if endOfStream {
responseEndTime := time.Now().UnixMilli()
- attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string)
- attributes[LLMServiceDuration] = fmt.Sprint(responseEndTime - requestStartTime)
- ctx.SetContext(CtxGeneralAtrribute, attributes)
+ ctx.SetUserAttribute(LLMServiceDuration, responseEndTime-requestStartTime)
// Set user defined log & span attributes.
if config.shouldBufferStreamingBody {
@@ -220,11 +238,8 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat
setAttributeBySource(ctx, config, ResponseStreamingBody, streamingBodyBuffer, log)
}
- // Write inner filter states which can be used by other plugins such as ai-token-ratelimit
- writeFilterStates(ctx, log)
-
// Write log
- writeLog(ctx, log)
+ ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
// Write metrics
writeMetric(ctx, config, log)
@@ -233,33 +248,26 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat
}
func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log wrapper.Log) types.Action {
- // Get attributes from http context
- attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string)
-
// Get requestStartTime from http context
requestStartTime, _ := ctx.GetContext(StatisticsRequestStartTime).(int64)
responseEndTime := time.Now().UnixMilli()
- attributes[LLMServiceDuration] = fmt.Sprint(responseEndTime - requestStartTime)
+ ctx.SetUserAttribute(LLMServiceDuration, responseEndTime-requestStartTime)
+
+ ctx.SetUserAttribute(ResponseType, "normal")
// Set information about this request
- model, inputToken, outputToken, ok := getUsage(body)
- if ok {
- attributes[Model] = model
- attributes[InputToken] = fmt.Sprint(inputToken)
- attributes[OutputToken] = fmt.Sprint(outputToken)
- // Update attributes
- ctx.SetContext(CtxGeneralAtrribute, attributes)
+ if model, inputToken, outputToken, ok := getUsage(body); ok {
+ ctx.SetUserAttribute(Model, model)
+ ctx.SetUserAttribute(InputToken, inputToken)
+ ctx.SetUserAttribute(OutputToken, outputToken)
}
// Set user defined log & span attributes.
setAttributeBySource(ctx, config, ResponseBody, body, log)
- // Write inner filter states which can be used by other plugins such as ai-token-ratelimit
- writeFilterStates(ctx, log)
-
// Write log
- writeLog(ctx, log)
+ ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
// Write metrics
writeMetric(ctx, config, log)
@@ -294,67 +302,49 @@ func getUsage(data []byte) (model string, inputTokenUsage int64, outputTokenUsag
// fetches the tracing span value from the specified source.
func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, source string, body []byte, log wrapper.Log) {
- attributes, ok := ctx.GetContext(CtxGeneralAtrribute).(map[string]string)
- if !ok {
- log.Error("failed to get attributes from http context")
- return
- }
for _, attribute := range config.attributes {
+ var key string
+ var value interface{}
if source == attribute.ValueSource {
+ key = attribute.Key
switch source {
case FixedValue:
- log.Debugf("[attribute] source type: %s, key: %s, value: %s", source, attribute.Key, attribute.Value)
- attributes[attribute.Key] = attribute.Value
+ value = attribute.Value
case RequestHeader:
- if value, err := proxywasm.GetHttpRequestHeader(attribute.Value); err == nil {
- log.Debugf("[attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
- attributes[attribute.Key] = value
- }
+ value, _ = proxywasm.GetHttpRequestHeader(attribute.Value)
case RequestBody:
- raw := gjson.GetBytes(body, attribute.Value).Raw
- var value string
- if len(raw) > 2 {
- value = raw[1 : len(raw)-1]
- }
- log.Debugf("[attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
- attributes[attribute.Key] = value
+ value = gjson.GetBytes(body, attribute.Value).Value()
case ResponseHeader:
- if value, err := proxywasm.GetHttpResponseHeader(attribute.Value); err == nil {
- log.Debugf("[log attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
- attributes[attribute.Key] = value
- }
+ value, _ = proxywasm.GetHttpResponseHeader(attribute.Value)
case ResponseStreamingBody:
- value := extractStreamingBodyByJsonPath(body, attribute.Value, attribute.Rule, log)
- log.Debugf("[log attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
- attributes[attribute.Key] = value
+ value = extractStreamingBodyByJsonPath(body, attribute.Value, attribute.Rule, log)
case ResponseBody:
- value := gjson.GetBytes(body, attribute.Value).Raw
- if len(value) > 2 && value[0] == '"' && value[len(value)-1] == '"' {
- value = value[1 : len(value)-1]
- }
- log.Debugf("[log attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
- attributes[attribute.Key] = value
+ value = gjson.GetBytes(body, attribute.Value).Value()
default:
}
- }
- if attribute.ApplyToLog {
- setLogAttribute(ctx, attribute.Key, attributes[attribute.Key], log)
- }
- if attribute.ApplyToSpan {
- setSpanAttribute(attribute.Key, attributes[attribute.Key], log)
+ log.Debugf("[attribute] source type: %s, key: %s, value: %+v", source, key, value)
+ if attribute.ApplyToLog {
+ ctx.SetUserAttribute(key, value)
+ }
+ // for metrics
+ if key == Model || key == InputToken || key == OutputToken {
+ ctx.SetContext(key, value)
+ }
+ if attribute.ApplyToSpan {
+ setSpanAttribute(key, value, log)
+ }
}
}
- ctx.SetContext(CtxGeneralAtrribute, attributes)
}
-func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, log wrapper.Log) string {
+func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, log wrapper.Log) interface{} {
chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n"))
- var value string
+ var value interface{}
if rule == RuleFirst {
for _, chunk := range chunks {
jsonObj := gjson.GetBytes(chunk, jsonPath)
if jsonObj.Exists() {
- value = jsonObj.String()
+ value = jsonObj.Value()
break
}
}
@@ -362,140 +352,116 @@ func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, l
for _, chunk := range chunks {
jsonObj := gjson.GetBytes(chunk, jsonPath)
if jsonObj.Exists() {
- value = jsonObj.String()
+ value = jsonObj.Value()
}
}
} else if rule == RuleAppend {
// extract llm response
+ var strValue string
for _, chunk := range chunks {
- raw := gjson.GetBytes(chunk, jsonPath).Raw
- if len(raw) > 2 && raw[0] == '"' && raw[len(raw)-1] == '"' {
- value += raw[1 : len(raw)-1]
+ jsonObj := gjson.GetBytes(chunk, jsonPath)
+ if jsonObj.Exists() {
+ strValue += jsonObj.String()
}
}
+ value = strValue
} else {
log.Errorf("unsupported rule type: %s", rule)
}
return value
}
-func setFilterState(key, value string, log wrapper.Log) {
- if value != "" {
- if e := proxywasm.SetProperty([]string{key}, []byte(fmt.Sprint(value))); e != nil {
- log.Errorf("failed to set %s in filter state: %v", key, e)
- }
- } else {
- log.Debugf("failed to write filter state [%s], because it's value is empty")
- }
-}
-
// Set the tracing span with value.
-func setSpanAttribute(key, value string, log wrapper.Log) {
+func setSpanAttribute(key string, value interface{}, log wrapper.Log) {
if value != "" {
- traceSpanTag := TracePrefix + key
- if e := proxywasm.SetProperty([]string{traceSpanTag}, []byte(value)); e != nil {
- log.Errorf("failed to set %s in filter state: %v", traceSpanTag, e)
+ traceSpanTag := wrapper.TraceSpanTagPrefix + key
+ if e := proxywasm.SetProperty([]string{traceSpanTag}, []byte(fmt.Sprint(value))); e != nil {
+ log.Warnf("failed to set %s in filter state: %v", traceSpanTag, e)
}
} else {
log.Debugf("failed to write span attribute [%s], because it's value is empty")
}
}
-// fetches the tracing span value from the specified source.
-func setLogAttribute(ctx wrapper.HttpContext, key string, value interface{}, log wrapper.Log) {
- logAttributes, ok := ctx.GetContext(CtxLogAtrribute).(map[string]string)
+func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) {
+ // Generate usage metrics
+ var ok bool
+ var route, cluster, model string
+ var inputToken, outputToken uint64
+ route, ok = ctx.GetContext(RouteName).(string)
if !ok {
- log.Error("failed to get logAttributes from http context")
+ log.Warnf("RouteName typd assert failed, skip metric record")
return
}
- logAttributes[key] = fmt.Sprint(value)
- ctx.SetContext(CtxLogAtrribute, logAttributes)
-}
-
-func writeFilterStates(ctx wrapper.HttpContext, log wrapper.Log) {
- attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string)
- setFilterState(Model, attributes[Model], log)
- setFilterState(InputToken, attributes[InputToken], log)
- setFilterState(OutputToken, attributes[OutputToken], log)
-}
-
-func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) {
- attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string)
- route, _ := getRouteName()
- cluster, _ := getClusterName()
- model, ok := attributes["model"]
+ cluster, ok = ctx.GetContext(ClusterName).(string)
if !ok {
- log.Errorf("Get model failed")
+ log.Warnf("ClusterName typd assert failed, skip metric record")
return
}
- if inputToken, ok := attributes[InputToken]; ok {
- inputTokenUint64, err := strconv.ParseUint(inputToken, 10, 0)
- if err != nil || inputTokenUint64 == 0 {
- log.Errorf("inputToken convert failed, value is %d, err msg is [%v]", inputTokenUint64, err)
- return
- }
- config.incrementCounter(generateMetricName(route, cluster, model, InputToken), inputTokenUint64)
+ if ctx.GetUserAttribute(Model) == nil || ctx.GetUserAttribute(InputToken) == nil || ctx.GetUserAttribute(OutputToken) == nil {
+ log.Warnf("get usage information failed, skip metric record")
+ return
}
- if outputToken, ok := attributes[OutputToken]; ok {
- outputTokenUint64, err := strconv.ParseUint(outputToken, 10, 0)
- if err != nil || outputTokenUint64 == 0 {
- log.Errorf("outputToken convert failed, value is %d, err msg is [%v]", outputTokenUint64, err)
- return
- }
- config.incrementCounter(generateMetricName(route, cluster, model, OutputToken), outputTokenUint64)
+ model, ok = ctx.GetUserAttribute(Model).(string)
+ if !ok {
+ log.Warnf("Model typd assert failed, skip metric record")
+ return
}
- if llmFirstTokenDuration, ok := attributes[LLMFirstTokenDuration]; ok {
- llmFirstTokenDurationUint64, err := strconv.ParseUint(llmFirstTokenDuration, 10, 0)
- if err != nil || llmFirstTokenDurationUint64 == 0 {
- log.Errorf("llmFirstTokenDuration convert failed, value is %d, err msg is [%v]", llmFirstTokenDurationUint64, err)
+ inputToken, ok = convertToUInt(ctx.GetUserAttribute(InputToken))
+ if !ok {
+ log.Warnf("InputToken typd assert failed, skip metric record")
+ return
+ }
+ outputToken, ok = convertToUInt(ctx.GetUserAttribute(OutputToken))
+ if !ok {
+ log.Warnf("OutputToken typd assert failed, skip metric record")
+ return
+ }
+ if inputToken == 0 || outputToken == 0 {
+ log.Warnf("inputToken and outputToken cannot equal to 0, skip metric record")
+ return
+ }
+ config.incrementCounter(generateMetricName(route, cluster, model, InputToken), inputToken)
+ config.incrementCounter(generateMetricName(route, cluster, model, OutputToken), outputToken)
+
+ // Generate duration metrics
+ var llmFirstTokenDuration, llmServiceDuration uint64
+ // Is stream response
+ if ctx.GetUserAttribute(LLMFirstTokenDuration) != nil {
+ llmFirstTokenDuration, ok = convertToUInt(ctx.GetUserAttribute(LLMFirstTokenDuration))
+ if !ok {
+ log.Warnf("LLMFirstTokenDuration typd assert failed")
return
}
- config.incrementCounter(generateMetricName(route, cluster, model, LLMFirstTokenDuration), llmFirstTokenDurationUint64)
+ config.incrementCounter(generateMetricName(route, cluster, model, LLMFirstTokenDuration), llmFirstTokenDuration)
+ config.incrementCounter(generateMetricName(route, cluster, model, LLMStreamDurationCount), 1)
}
- if llmServiceDuration, ok := attributes[LLMServiceDuration]; ok {
- llmServiceDurationUint64, err := strconv.ParseUint(llmServiceDuration, 10, 0)
- if err != nil || llmServiceDurationUint64 == 0 {
- log.Errorf("llmServiceDuration convert failed, value is %d, err msg is [%v]", llmServiceDurationUint64, err)
+ if ctx.GetUserAttribute(LLMServiceDuration) != nil {
+ llmServiceDuration, ok = convertToUInt(ctx.GetUserAttribute(LLMServiceDuration))
+ if !ok {
+ log.Warnf("LLMServiceDuration typd assert failed")
return
}
- config.incrementCounter(generateMetricName(route, cluster, model, LLMServiceDuration), llmServiceDurationUint64)
+ config.incrementCounter(generateMetricName(route, cluster, model, LLMServiceDuration), llmServiceDuration)
+ config.incrementCounter(generateMetricName(route, cluster, model, LLMDurationCount), 1)
}
- config.incrementCounter(generateMetricName(route, cluster, model, LLMDurationCount), 1)
}
-func writeLog(ctx wrapper.HttpContext, log wrapper.Log) {
- attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string)
- logAttributes, _ := ctx.GetContext(CtxLogAtrribute).(map[string]string)
- // Set inner log fields
- if attributes[Model] != "" {
- logAttributes[Model] = attributes[Model]
- }
- if attributes[InputToken] != "" {
- logAttributes[InputToken] = attributes[InputToken]
- }
- if attributes[OutputToken] != "" {
- logAttributes[OutputToken] = attributes[OutputToken]
- }
- if attributes[LLMFirstTokenDuration] != "" {
- logAttributes[LLMFirstTokenDuration] = attributes[LLMFirstTokenDuration]
- }
- if attributes[LLMServiceDuration] != "" {
- logAttributes[LLMServiceDuration] = attributes[LLMServiceDuration]
- }
- // Traverse log fields
- items := []string{}
- for k, v := range logAttributes {
- items = append(items, fmt.Sprintf(`"%s":"%s"`, k, v))
- }
- aiLogField := fmt.Sprintf(`{%s}`, strings.Join(items, ","))
- // log.Infof("ai request json log: %s", aiLogField)
- jsonMap := map[string]string{
- "ai_log": aiLogField,
- }
- serialized, _ := json.Marshal(jsonMap)
- jsonLogRaw := gjson.GetBytes(serialized, "ai_log").Raw
- jsonLog := jsonLogRaw[1 : len(jsonLogRaw)-1]
- if err := proxywasm.SetProperty([]string{"ai_log"}, []byte(jsonLog)); err != nil {
- log.Errorf("failed to set ai_log in filter state: %v", err)
+func convertToUInt(val interface{}) (uint64, bool) {
+ switch v := val.(type) {
+ case float32:
+ return uint64(v), true
+ case float64:
+ return uint64(v), true
+ case int32:
+ return uint64(v), true
+ case int64:
+ return uint64(v), true
+ case uint32:
+ return uint64(v), true
+ case uint64:
+ return v, true
+ default:
+ return 0, false
}
}
diff --git a/plugins/wasm-go/extensions/ai-token-ratelimit/go.sum b/plugins/wasm-go/extensions/ai-token-ratelimit/go.sum
index 4bc7bb7527..7b8c22894a 100644
--- a/plugins/wasm-go/extensions/ai-token-ratelimit/go.sum
+++ b/plugins/wasm-go/extensions/ai-token-ratelimit/go.sum
@@ -5,8 +5,7 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
-github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
-github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
+github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU=
github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
@@ -14,8 +13,7 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tetratelabs/wazero v1.7.1 h1:QtSfd6KLc41DIMpDYlJdoMc6k7QTN246DM2+n2Y/Dx8=
github.com/tetratelabs/wazero v1.7.1/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
-github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
-github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
+github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94=
github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
diff --git a/plugins/wasm-go/extensions/ai-token-ratelimit/main.go b/plugins/wasm-go/extensions/ai-token-ratelimit/main.go
index afe463a12d..6877ae5c23 100644
--- a/plugins/wasm-go/extensions/ai-token-ratelimit/main.go
+++ b/plugins/wasm-go/extensions/ai-token-ratelimit/main.go
@@ -15,6 +15,7 @@
package main
import (
+ "bytes"
"fmt"
"net"
"net/url"
@@ -61,9 +62,9 @@ const (
ConsumerHeader string = "x-mse-consumer" // LimitByConsumer从该request header获取consumer的名字
CookieHeader string = "cookie"
- RateLimitLimitHeader string = "X-RateLimit-Limit" // 限制的总请求数
- RateLimitRemainingHeader string = "X-RateLimit-Remaining" // 剩余还可以发送的请求数
- RateLimitResetHeader string = "X-RateLimit-Reset" // 限流重置时间(触发限流时返回)
+ RateLimitLimitHeader string = "X-TokenRateLimit-Limit" // 限制的总请求数
+ RateLimitRemainingHeader string = "X-TokenRateLimit-Remaining" // 剩余还可以发送的请求数
+ RateLimitResetHeader string = "X-TokenRateLimit-Reset" // 限流重置时间(触发限流时返回)
)
type LimitContext struct {
@@ -124,6 +125,8 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitCon
}
if context.remaining < 0 {
// 触发限流
+ ctx.SetUserAttribute("token_ratelimit_status", "limited")
+ ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
rejected(config, context)
} else {
proxywasm.ResumeHttpRequest()
@@ -137,39 +140,49 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitCon
}
func onHttpStreamingBody(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, data []byte, endOfStream bool, log wrapper.Log) []byte {
- if !endOfStream {
- return data
+ var inputToken, outputToken int64
+ if inputToken, outputToken, ok := getUsage(data); ok {
+ ctx.SetContext("input_token", inputToken)
+ ctx.SetContext("output_token", outputToken)
}
- inputTokenStr, err := proxywasm.GetProperty([]string{"filter_state", "wasm.input_token"})
- if err != nil {
- return data
- }
- outputTokenStr, err := proxywasm.GetProperty([]string{"filter_state", "wasm.output_token"})
- if err != nil {
- return data
- }
- inputToken, err := strconv.Atoi(string(inputTokenStr))
- if err != nil {
- return data
- }
- outputToken, err := strconv.Atoi(string(outputTokenStr))
- if err != nil {
- return data
- }
- limitRedisContext, ok := ctx.GetContext(LimitRedisContextKey).(LimitRedisContext)
- if !ok {
- return data
+ if endOfStream {
+ if ctx.GetContext("input_token") == nil || ctx.GetContext("output_token") == nil {
+ return data
+ }
+ inputToken = ctx.GetContext("input_token").(int64)
+ outputToken = ctx.GetContext("output_token").(int64)
+ limitRedisContext, ok := ctx.GetContext(LimitRedisContextKey).(LimitRedisContext)
+ if !ok {
+ return data
+ }
+ keys := []interface{}{limitRedisContext.key}
+ args := []interface{}{limitRedisContext.count, limitRedisContext.window, inputToken + outputToken}
+ err := config.redisClient.Eval(ResponsePhaseFixedWindowScript, 1, keys, args, nil)
+ if err != nil {
+ log.Errorf("redis call failed: %v", err)
+ }
}
- keys := []interface{}{limitRedisContext.key}
- args := []interface{}{limitRedisContext.count, limitRedisContext.window, inputToken + outputToken}
+ return data
+}
- err = config.redisClient.Eval(ResponsePhaseFixedWindowScript, 1, keys, args, nil)
- if err != nil {
- log.Errorf("redis call failed: %v", err)
- return data
- } else {
- return data
+func getUsage(data []byte) (inputTokenUsage int64, outputTokenUsage int64, ok bool) {
+ chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n"))
+ for _, chunk := range chunks {
+ // the feature strings are used to identify the usage data, like:
+ // {"model":"gpt2","usage":{"prompt_tokens":1,"completion_tokens":1}}
+ if !bytes.Contains(chunk, []byte("prompt_tokens")) || !bytes.Contains(chunk, []byte("completion_tokens")) {
+ continue
+ }
+ inputTokenObj := gjson.GetBytes(chunk, "usage.prompt_tokens")
+ outputTokenObj := gjson.GetBytes(chunk, "usage.completion_tokens")
+ if inputTokenObj.Exists() && outputTokenObj.Exists() {
+ inputTokenUsage = inputTokenObj.Int()
+ outputTokenUsage = outputTokenObj.Int()
+ ok = true
+ return
+ }
}
+ return
}
func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRuleItem, log wrapper.Log) (string, *LimitRuleItem, *LimitConfigItem) {
diff --git a/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go b/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go
index be9144adfc..8b342d57b5 100644
--- a/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go
+++ b/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go
@@ -45,6 +45,8 @@ type HttpContext interface {
GetStringContext(key, defaultValue string) string
GetUserAttribute(key string) interface{}
SetUserAttribute(key string, value interface{})
+ SetUserAttributeMap(kvmap map[string]interface{})
+ GetUserAttributeMap() map[string]interface{}
// You can call this function to set custom log
WriteUserAttributeToLog() error
// You can call this function to set custom log with your specific key
@@ -403,6 +405,14 @@ func (ctx *CommonHttpCtx[PluginConfig]) GetUserAttribute(key string) interface{}
return ctx.userAttribute[key]
}
+func (ctx *CommonHttpCtx[PluginConfig]) SetUserAttributeMap(kvmap map[string]interface{}) {
+ ctx.userAttribute = kvmap
+}
+
+func (ctx *CommonHttpCtx[PluginConfig]) GetUserAttributeMap() map[string]interface{} {
+ return ctx.userAttribute
+}
+
func (ctx *CommonHttpCtx[PluginConfig]) WriteUserAttributeToLog() error {
return ctx.WriteUserAttributeToLogWithKey(CustomLogKey)
}
diff --git a/plugins/wasm-go/pkg/wrapper/redis_wrapper.go b/plugins/wasm-go/pkg/wrapper/redis_wrapper.go
index c619c3e191..f4b42e67e7 100644
--- a/plugins/wasm-go/pkg/wrapper/redis_wrapper.go
+++ b/plugins/wasm-go/pkg/wrapper/redis_wrapper.go
@@ -17,6 +17,7 @@ package wrapper
import (
"bytes"
"encoding/base64"
+ "errors"
"fmt"
"io"
@@ -28,7 +29,7 @@ import (
type RedisResponseCallback func(response resp.Value)
type RedisClient interface {
- Init(username, password string, timeout int64) error
+ Init(username, password string, timeout int64, opts ...optionFunc) error
// with this function, you can call redis as if you are using redis-cli
Command(cmds []interface{}, callback RedisResponseCallback) error
Eval(script string, numkeys int, keys, args []interface{}, callback RedisResponseCallback) error
@@ -103,15 +104,31 @@ type RedisClient interface {
}
type RedisClusterClient[C Cluster] struct {
- cluster C
+ cluster C
+ ready bool
+ checkReadyFunc func() error
+ option redisOption
}
-func NewRedisClusterClient[C Cluster](cluster C) *RedisClusterClient[C] {
- return &RedisClusterClient[C]{cluster: cluster}
+type redisOption struct {
+ dataBase int
}
-func RedisInit(cluster Cluster, username, password string, timeout uint32) error {
- return proxywasm.RedisInit(cluster.ClusterName(), username, password, timeout)
+type optionFunc func(*redisOption)
+
+func WithDataBase(dataBase int) optionFunc {
+ return func(o *redisOption) {
+ o.dataBase = dataBase
+ }
+}
+
+func NewRedisClusterClient[C Cluster](cluster C) *RedisClusterClient[C] {
+ return &RedisClusterClient[C]{
+ cluster: cluster,
+ checkReadyFunc: func() error {
+ return errors.New("redis client is not ready, please call Init() first")
+ },
+ }
}
func RedisCall(cluster Cluster, respQuery []byte, callback RedisResponseCallback) error {
@@ -165,19 +182,46 @@ func respString(args []interface{}) []byte {
return buf.Bytes()
}
-func (c RedisClusterClient[C]) Init(username, password string, timeout int64) error {
- err := RedisInit(c.cluster, username, password, uint32(timeout))
+func (c *RedisClusterClient[C]) Init(username, password string, timeout int64, opts ...optionFunc) error {
+ for _, opt := range opts {
+ opt(&c.option)
+ }
+ clusterName := c.cluster.ClusterName()
+ if c.option.dataBase != 0 {
+ clusterName = fmt.Sprintf("%s?db=%d", clusterName, c.option.dataBase)
+ }
+ err := proxywasm.RedisInit(clusterName, username, password, uint32(timeout))
if err != nil {
- proxywasm.LogCriticalf("failed to init redis: %v", err)
+ c.checkReadyFunc = func() error {
+ if c.ready {
+ return nil
+ }
+ initErr := proxywasm.RedisInit(clusterName, username, password, uint32(timeout))
+ if initErr != nil {
+ return initErr
+ }
+ c.ready = true
+ return nil
+ }
+ proxywasm.LogWarnf("failed to init redis: %v, will retry after", err)
+ return nil
}
- return err
+ c.checkReadyFunc = func() error { return nil }
+ c.ready = true
+ return nil
}
-func (c RedisClusterClient[C]) Command(cmds []interface{}, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) Command(cmds []interface{}, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
return RedisCall(c.cluster, respString(cmds), callback)
}
-func (c RedisClusterClient[C]) Eval(script string, numkeys int, keys, args []interface{}, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) Eval(script string, numkeys int, keys, args []interface{}, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
params := make([]interface{}, 0)
params = append(params, "eval")
params = append(params, script)
@@ -188,21 +232,30 @@ func (c RedisClusterClient[C]) Eval(script string, numkeys int, keys, args []int
}
// Key
-func (c RedisClusterClient[C]) Del(key string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) Del(key string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "del")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) Exists(key string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) Exists(key string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "exists")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) Expire(key string, ttl int, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) Expire(key string, ttl int, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "expire")
args = append(args, key)
@@ -210,7 +263,10 @@ func (c RedisClusterClient[C]) Expire(key string, ttl int, callback RedisRespons
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) Persist(key string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) Persist(key string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "persist")
args = append(args, key)
@@ -218,14 +274,20 @@ func (c RedisClusterClient[C]) Persist(key string, callback RedisResponseCallbac
}
// String
-func (c RedisClusterClient[C]) Get(key string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) Get(key string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "get")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) Set(key string, value interface{}, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) Set(key string, value interface{}, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "set")
args = append(args, key)
@@ -233,7 +295,10 @@ func (c RedisClusterClient[C]) Set(key string, value interface{}, callback Redis
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) SetEx(key string, value interface{}, ttl int, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) SetEx(key string, value interface{}, ttl int, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "set")
args = append(args, key)
@@ -243,7 +308,10 @@ func (c RedisClusterClient[C]) SetEx(key string, value interface{}, ttl int, cal
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) MGet(keys []string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) MGet(keys []string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "mget")
for _, k := range keys {
@@ -252,7 +320,10 @@ func (c RedisClusterClient[C]) MGet(keys []string, callback RedisResponseCallbac
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) MSet(kvMap map[string]interface{}, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) MSet(kvMap map[string]interface{}, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "mset")
for k, v := range kvMap {
@@ -262,21 +333,30 @@ func (c RedisClusterClient[C]) MSet(kvMap map[string]interface{}, callback Redis
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) Incr(key string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) Incr(key string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "incr")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) Decr(key string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) Decr(key string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "decr")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) IncrBy(key string, delta int, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) IncrBy(key string, delta int, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "incrby")
args = append(args, key)
@@ -284,7 +364,10 @@ func (c RedisClusterClient[C]) IncrBy(key string, delta int, callback RedisRespo
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) DecrBy(key string, delta int, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) DecrBy(key string, delta int, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "decrby")
args = append(args, key)
@@ -293,14 +376,20 @@ func (c RedisClusterClient[C]) DecrBy(key string, delta int, callback RedisRespo
}
// List
-func (c RedisClusterClient[C]) LLen(key string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) LLen(key string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "llen")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) RPush(key string, vals []interface{}, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) RPush(key string, vals []interface{}, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "rpush")
args = append(args, key)
@@ -310,14 +399,20 @@ func (c RedisClusterClient[C]) RPush(key string, vals []interface{}, callback Re
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) RPop(key string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) RPop(key string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "rpop")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) LPush(key string, vals []interface{}, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) LPush(key string, vals []interface{}, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "lpush")
args = append(args, key)
@@ -327,14 +422,20 @@ func (c RedisClusterClient[C]) LPush(key string, vals []interface{}, callback Re
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) LPop(key string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) LPop(key string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "lpop")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) LIndex(key string, index int, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) LIndex(key string, index int, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "lindex")
args = append(args, key)
@@ -342,7 +443,10 @@ func (c RedisClusterClient[C]) LIndex(key string, index int, callback RedisRespo
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) LRange(key string, start, stop int, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) LRange(key string, start, stop int, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "lrange")
args = append(args, key)
@@ -351,7 +455,10 @@ func (c RedisClusterClient[C]) LRange(key string, start, stop int, callback Redi
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) LRem(key string, count int, value interface{}, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) LRem(key string, count int, value interface{}, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "lrem")
args = append(args, key)
@@ -360,7 +467,10 @@ func (c RedisClusterClient[C]) LRem(key string, count int, value interface{}, ca
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) LInsertBefore(key string, pivot, value interface{}, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) LInsertBefore(key string, pivot, value interface{}, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "linsert")
args = append(args, key)
@@ -370,7 +480,10 @@ func (c RedisClusterClient[C]) LInsertBefore(key string, pivot, value interface{
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) LInsertAfter(key string, pivot, value interface{}, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) LInsertAfter(key string, pivot, value interface{}, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "linsert")
args = append(args, key)
@@ -381,7 +494,10 @@ func (c RedisClusterClient[C]) LInsertAfter(key string, pivot, value interface{}
}
// Hash
-func (c RedisClusterClient[C]) HExists(key, field string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) HExists(key, field string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "hexists")
args = append(args, key)
@@ -389,7 +505,10 @@ func (c RedisClusterClient[C]) HExists(key, field string, callback RedisResponse
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) HDel(key string, fields []string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) HDel(key string, fields []string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "hdel")
args = append(args, key)
@@ -399,14 +518,20 @@ func (c RedisClusterClient[C]) HDel(key string, fields []string, callback RedisR
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) HLen(key string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) HLen(key string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "hlen")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) HGet(key, field string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) HGet(key, field string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "hget")
args = append(args, key)
@@ -414,7 +539,10 @@ func (c RedisClusterClient[C]) HGet(key, field string, callback RedisResponseCal
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) HSet(key, field string, value interface{}, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) HSet(key, field string, value interface{}, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "hset")
args = append(args, key)
@@ -423,7 +551,10 @@ func (c RedisClusterClient[C]) HSet(key, field string, value interface{}, callba
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) HMGet(key string, fields []string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) HMGet(key string, fields []string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "hmget")
args = append(args, key)
@@ -433,7 +564,10 @@ func (c RedisClusterClient[C]) HMGet(key string, fields []string, callback Redis
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) HMSet(key string, kvMap map[string]interface{}, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) HMSet(key string, kvMap map[string]interface{}, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "hmset")
args = append(args, key)
@@ -444,28 +578,40 @@ func (c RedisClusterClient[C]) HMSet(key string, kvMap map[string]interface{}, c
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) HKeys(key string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) HKeys(key string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "hkeys")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) HVals(key string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) HVals(key string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "hvals")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) HGetAll(key string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) HGetAll(key string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "hgetall")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) HIncrBy(key, field string, delta int, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) HIncrBy(key, field string, delta int, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "hincrby")
args = append(args, key)
@@ -474,7 +620,10 @@ func (c RedisClusterClient[C]) HIncrBy(key, field string, delta int, callback Re
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) HIncrByFloat(key, field string, delta float64, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) HIncrByFloat(key, field string, delta float64, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "hincrbyfloat")
args = append(args, key)
@@ -484,14 +633,20 @@ func (c RedisClusterClient[C]) HIncrByFloat(key, field string, delta float64, ca
}
// Set
-func (c RedisClusterClient[C]) SCard(key string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) SCard(key string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "scard")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) SAdd(key string, vals []interface{}, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) SAdd(key string, vals []interface{}, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "sadd")
args = append(args, key)
@@ -501,7 +656,10 @@ func (c RedisClusterClient[C]) SAdd(key string, vals []interface{}, callback Red
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) SRem(key string, vals []interface{}, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) SRem(key string, vals []interface{}, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "srem")
args = append(args, key)
@@ -511,7 +669,10 @@ func (c RedisClusterClient[C]) SRem(key string, vals []interface{}, callback Red
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) SIsMember(key string, value interface{}, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) SIsMember(key string, value interface{}, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "sismember")
args = append(args, key)
@@ -519,14 +680,20 @@ func (c RedisClusterClient[C]) SIsMember(key string, value interface{}, callback
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) SMembers(key string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) SMembers(key string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "smembers")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) SDiff(key1, key2 string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) SDiff(key1, key2 string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "sdiff")
args = append(args, key1)
@@ -534,7 +701,10 @@ func (c RedisClusterClient[C]) SDiff(key1, key2 string, callback RedisResponseCa
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) SDiffStore(destination, key1, key2 string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) SDiffStore(destination, key1, key2 string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "sdiffstore")
args = append(args, destination)
@@ -543,7 +713,10 @@ func (c RedisClusterClient[C]) SDiffStore(destination, key1, key2 string, callba
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) SInter(key1, key2 string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) SInter(key1, key2 string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "sinter")
args = append(args, key1)
@@ -551,7 +724,10 @@ func (c RedisClusterClient[C]) SInter(key1, key2 string, callback RedisResponseC
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) SInterStore(destination, key1, key2 string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) SInterStore(destination, key1, key2 string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "sinterstore")
args = append(args, destination)
@@ -560,7 +736,10 @@ func (c RedisClusterClient[C]) SInterStore(destination, key1, key2 string, callb
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) SUnion(key1, key2 string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) SUnion(key1, key2 string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "sunion")
args = append(args, key1)
@@ -568,7 +747,10 @@ func (c RedisClusterClient[C]) SUnion(key1, key2 string, callback RedisResponseC
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) SUnionStore(destination, key1, key2 string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) SUnionStore(destination, key1, key2 string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "sunionstore")
args = append(args, destination)
@@ -578,14 +760,20 @@ func (c RedisClusterClient[C]) SUnionStore(destination, key1, key2 string, callb
}
// ZSet
-func (c RedisClusterClient[C]) ZCard(key string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) ZCard(key string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "zcard")
args = append(args, key)
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) ZAdd(key string, msMap map[string]interface{}, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) ZAdd(key string, msMap map[string]interface{}, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "zadd")
args = append(args, key)
@@ -596,7 +784,10 @@ func (c RedisClusterClient[C]) ZAdd(key string, msMap map[string]interface{}, ca
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) ZCount(key string, min interface{}, max interface{}, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) ZCount(key string, min interface{}, max interface{}, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "zcount")
args = append(args, key)
@@ -605,7 +796,10 @@ func (c RedisClusterClient[C]) ZCount(key string, min interface{}, max interface
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) ZIncrBy(key string, member string, delta interface{}, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) ZIncrBy(key string, member string, delta interface{}, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "zincrby")
args = append(args, key)
@@ -614,7 +808,10 @@ func (c RedisClusterClient[C]) ZIncrBy(key string, member string, delta interfac
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) ZScore(key, member string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) ZScore(key, member string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "zscore")
args = append(args, key)
@@ -622,7 +819,10 @@ func (c RedisClusterClient[C]) ZScore(key, member string, callback RedisResponse
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) ZRank(key, member string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) ZRank(key, member string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "zrank")
args = append(args, key)
@@ -630,7 +830,10 @@ func (c RedisClusterClient[C]) ZRank(key, member string, callback RedisResponseC
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) ZRevRank(key, member string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) ZRevRank(key, member string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "zrevrank")
args = append(args, key)
@@ -638,7 +841,10 @@ func (c RedisClusterClient[C]) ZRevRank(key, member string, callback RedisRespon
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) ZRem(key string, members []string, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) ZRem(key string, members []string, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "zrem")
args = append(args, key)
@@ -648,7 +854,10 @@ func (c RedisClusterClient[C]) ZRem(key string, members []string, callback Redis
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) ZRange(key string, start, stop int, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) ZRange(key string, start, stop int, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "zrange")
args = append(args, key)
@@ -657,7 +866,10 @@ func (c RedisClusterClient[C]) ZRange(key string, start, stop int, callback Redi
return RedisCall(c.cluster, respString(args), callback)
}
-func (c RedisClusterClient[C]) ZRevRange(key string, start, stop int, callback RedisResponseCallback) error {
+func (c *RedisClusterClient[C]) ZRevRange(key string, start, stop int, callback RedisResponseCallback) error {
+ if err := c.checkReadyFunc(); err != nil {
+ return err
+ }
args := make([]interface{}, 0)
args = append(args, "zrevrange")
args = append(args, key)