-
Notifications
You must be signed in to change notification settings - Fork 2.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4c87264
commit bf90fc2
Showing
4 changed files
with
581 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,225 @@ | ||
-- ai-proxy plugin: proxies client requests to an LLM provider, translating
-- request and response bodies between the client's and the provider's formats.
local cjson = require("cjson.safe")
local core = require("apisix.core")
local schema = require("apisix.plugins.ai-proxy.schema")

-- localize frequently-used globals (hot-path lookups)
local ngx_req = ngx.req
local ngx = ngx
local fmt = string.format

local plugin_name = "ai-proxy"

-- plugin descriptor consumed by the APISIX plugin loader
local _M = {
    version = 0.5,
    priority = 1002,
    name = plugin_name,
    schema = schema,
}
--- Validate a user-supplied plugin configuration against the plugin schema.
-- @tparam table conf the plugin configuration to validate
-- @treturn boolean ok, plus an error message when validation fails
function _M.check_schema(conf)
    local ok, err = core.schema.check(schema, conf)
    return ok, err
end
|
||
|
||
-- static messages
-- NOTE(review): ERROR__NOT_SET is not referenced anywhere in this chunk;
-- presumably reserved for streaming (SSE) error responses -- confirm before removing.
local ERROR__NOT_SET = 'data: {"error": true, "message": "empty or unsupported transformer response"}'

-- formats_compatible is a map of formats that are compatible with each other.
-- Outer key: detected request format; inner keys: route types it may serve.
local formats_compatible = {
    ["llm/v1/chat"] = {
        ["llm/v1/chat"] = true,
    },
    ["llm/v1/completions"] = {
        ["llm/v1/completions"] = true,
    },
}
|
||
|
||
|
||
-- identify_request determines the format of the request.
-- It returns the format, or nil and an error message.
-- @tparam table request The request to identify
-- @treturn[1] string The format of the request
-- @treturn[2] nil
-- @treturn[2] string An error message if unidentified, or matching multiple formats
-- Example chat request:
-- { "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] }
local function identify_request(request)
    -- primitive request format determination
    local formats = {}

    -- chat format: a non-empty "messages" array
    if type(request.messages) == "table" and #request.messages > 0 then
        table.insert(formats, "llm/v1/chat")
    end

    -- completions format: a "prompt" string
    if type(request.prompt) == "string" then
        table.insert(formats, "llm/v1/completions")
    end

    if formats[2] then
        return nil, "request matches multiple LLM request formats"
    end

    -- BUGFIX: leftover "dibag" debug logging of the full request body removed
    -- (leaked request contents to the error log on every unrecognised request).
    local format = formats[1]
    if not format or not formats_compatible[format] then
        return nil, "request format not recognised"
    end

    return format
end
|
||
|
||
--- Check whether the request body's format can be served by the configured
-- route type.  A "preserve" route accepts any body untouched.
-- @tparam table request the decoded request body
-- @tparam string route_type the configured route type
-- @treturn[1] boolean true when compatible
-- @treturn[2] false|nil plus an explanatory message otherwise
local function is_compatible(request, route_type)
    -- "preserve" proxies the body verbatim, so no format check is needed
    if route_type == "preserve" then
        return true
    end

    local format, err = identify_request(request)
    if not format then
        return nil, err
    end

    if not formats_compatible[format][route_type] then
        return false,
            fmt("[%s] message format is not compatible with [%s] route type", format, route_type)
    end

    return true
end
|
||
-- Buffer the upstream response body and, unless route_type is "preserve",
-- translate it from the provider's format.  The final body is stashed in
-- ctx.plugin.buffered_response_body for body_filter to emit; on transform
-- failure the response becomes a 500 with a JSON error object.
-- NOTE(review): this is invoked from header_filter below, but
-- core.response.hold_body_chunk reads response body chunks -- confirm the
-- body is actually available at that phase.
local function transform_body(conf, ctx)
    local route_type = conf.route_type
    local ai_driver = require("apisix.plugins.ai-proxy.drivers." .. conf.model.provider)

    -- Note: below even if we are told not to do response transform, we still need to do
    -- get the body for analytics
    local response_body = core.response.hold_body_chunk(ctx)

    local err

    if not response_body then
        err = "no response body found when transforming response"
    elseif route_type ~= "preserve" then
        -- translate the provider's response into the configured route_type format
        response_body, err = ai_driver.from_format(response_body, conf.model, route_type)

        if err then
            core.log.error("issue when transforming the response body for analytics: ", err)
        end
    end

    if err then
        -- surface the failure to the client as a JSON error payload
        ngx.status = 500
        response_body = cjson.encode({ error = { message = err } })
    end

    ctx.plugin.buffered_response_body = response_body
end
|
||
|
||
--- header_filter phase: buffer/transform the upstream response body and
-- drop the now-stale Content-Length so the (possibly resized) buffered
-- body can be emitted from body_filter.
function _M.header_filter(conf, ctx)
    -- only act on 200 in first release - pass the unmodified response all the way through if any failure
    if ngx.status ~= 200 then
        return
    end

    local ai_driver = require("apisix.plugins.ai-proxy.drivers." .. conf.model.provider)
    -- ai_driver.post_request(conf)

    transform_body(conf, ctx)
    -- BUGFIX: the transformed body's length differs from upstream's, so the
    -- *response* Content-Length header must be cleared.  The original code
    -- called core.request.set_header, which clears the (already-sent)
    -- request header and leaves the response with a mismatched length.
    ngx.header.content_length = nil
end
|
||
--- body_filter phase: replace the streamed upstream chunks with the body
-- buffered by header_filter, emitted as a single final chunk.
function _M.body_filter(conf, ctx)
    if ngx.status ~= 200 then
        return
    end

    local body = ctx.plugin.buffered_response_body

    -- emit the buffered body as the one and only chunk (eof = true)
    ngx.arg[1] = body
    ngx.arg[2] = true

    -- release the buffer so subsequent filter invocations emit nothing
    ctx.plugin.buffered_response_body = nil
end
|
||
--- access phase: decode and validate the request body, resolve the model
-- name, translate the body into the provider's own format, and configure
-- the upstream request for the AI service.
-- @tparam table conf plugin configuration (a per-route copy; conf.model.name
--   may be filled in here from the request body)
-- @tparam table ctx request context
-- @treturn[1] nil on success (request continues to the upstream)
-- @treturn[2] number,string an HTTP status code and error message on failure
function _M.access(conf, ctx)
    -- BUGFIX: removed leftover debug code that dumped ngx.ctx to a local
    -- file named "dibag" on every request (and crashed when cwd was not
    -- writable), plus the associated "dibag" log lines below.
    local route_type = conf.route_type

    local content_type = core.request.header(ctx, "Content-Type") or "application/json"
    -- this may be a large file upload, so we have to proxy it directly
    local multipart = content_type == "multipart/form-data"

    local request_body = core.request.get_body() -- TODO: max size
    -- BUGFIX: get_body() can return nil (no body / body on disk); the old
    -- code then crashed calling :gsub on nil, producing a 500.
    if not request_body then
        return 400, "content-type header does not match request body, or bad JSON formatting"
    end

    -- BUGFIX: the old code ran gsub('\\"' -> '"') over the raw body before
    -- decoding -- corrupting legitimately escaped quotes inside JSON strings --
    -- and treated gsub's second return value (the substitution *count*) as an
    -- error.  Decode the body as-is instead.
    local request_table = core.json.decode(request_body)
    if not request_table then
        return 400, "content-type header does not match request body, or bad JSON formatting"
    end

    -- copy the model name from the user request if the conf doesn't pin one
    if (not multipart) and (not conf.model.name) and (request_table.model) then
        if type(request_table.model) == "string" then
            conf.model.name = request_table.model
        end
    elseif multipart then
        conf.model.name = "NOT_SPECIFIED"
    end

    -- check that the user isn't trying to override the plugin conf model in the request body
    if type(request_table.model) == "string" and request_table.model ~= "" then
        if request_table.model ~= conf.model.name then
            return 400, "cannot use own model - must be: " .. conf.model.name
        end
    end

    -- model is stashed in the copied plugin conf, for consistency in transformation functions
    if not conf.model.name then
        return 400, "model parameter not found in request, nor in gateway configuration"
    end

    -- check the incoming format is the same as the configured LLM format
    local compatible, err = is_compatible(request_table, route_type)
    if not multipart and not compatible then
        return 400, err
    end

    local ai_driver = require("apisix.plugins.ai-proxy.drivers." .. conf.model.provider)

    -- transform the body into this provider/model's own format
    -- (new_content_type deliberately does not shadow the request content_type above)
    local parsed_request_body, new_content_type
    if route_type ~= "preserve" and (not multipart) then
        parsed_request_body, new_content_type, err =
            ai_driver.to_format(request_table, conf.model, route_type)
        if err then
            return 400, err
        end
    end

    -- execute pre-request hooks for "all" drivers before setting the new body
    local ok
    ok, err = ai_driver.pre_request(conf, parsed_request_body)
    if not ok then
        return 400, err
    end

    if route_type ~= "preserve" then
        ngx_req.set_body_data(core.json.encode(parsed_request_body))
        core.request.set_header(ctx, "Content-Type", new_content_type)
    end

    -- now re-configure the request for this operation type
    ok, err = ai_driver.configure_request(conf, ctx)
    if not ok then
        core.log.error("failed to configure request for AI service: ", err)
        return 500
    end
end
|
||
-- export the plugin module table
return _M
Oops, something went wrong.