Skip to content

Commit

Permalink
feat: ai-proxy plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
shreemaan-abhishek committed Aug 15, 2024
1 parent 4c87264 commit bf90fc2
Show file tree
Hide file tree
Showing 4 changed files with 581 additions and 0 deletions.
1 change: 1 addition & 0 deletions apisix/cli/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ local _M = {
"proxy-mirror",
"proxy-rewrite",
"workflow",
"ai-proxy",
"api-breaker",
"limit-conn",
"limit-count",
Expand Down
225 changes: 225 additions & 0 deletions apisix/plugins/ai-proxy.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
local cjson = require("cjson.safe")
local core = require("apisix.core")
local schema = require("apisix.plugins.ai-proxy.schema")

-- localize globals used on the hot path (locals are register lookups,
-- globals are table lookups)
local ngx_req = ngx.req
local ngx = ngx
local fmt = string.format

local plugin_name = "ai-proxy"

-- Plugin descriptor table consumed by the APISIX plugin loader.
local _M = {
    version = 0.5,
    priority = 1002,      -- execution order relative to other plugins
    name = plugin_name,
    schema = schema,      -- config schema, validated in check_schema()
}


-- Validate a user-supplied plugin configuration against the plugin schema.
-- @tparam table conf the plugin configuration to validate
-- @treturn[1] boolean true when the configuration is valid
-- @treturn[2] boolean,string false plus an error message otherwise
function _M.check_schema(conf)
    local ok, err = core.schema.check(schema, conf)
    return ok, err
end


-- formats_compatible maps an identified request format to the set of
-- route_type values that may serve it.  Currently each format is only
-- compatible with itself; extend the inner tables to allow cross-format
-- routing.
-- (The unused ERROR__NOT_SET constant that used to live here was dead code
-- and has been removed.)
local formats_compatible = {
    ["llm/v1/chat"] = {
        ["llm/v1/chat"] = true,
    },
    ["llm/v1/completions"] = {
        ["llm/v1/completions"] = true,
    },
}



-- identify_request determines the format of the request.
-- It returns the format, or nil and an error message.
-- @tparam table request the decoded request body to identify
-- @treturn[1] string the format of the request, e.g. "llm/v1/chat"
-- @treturn[2] nil
-- @treturn[2] string an error message if unidentified, or matching multiple formats
--
-- Example chat request:
-- { "messages": [ { "role": "system", "content": "You are a mathematician" },
--                 { "role": "user", "content": "What is 1+1?"} ] }
local function identify_request(request)
    -- primitive request format determination
    local formats = {}

    -- a non-empty "messages" array marks a chat request
    if type(request.messages) == "table" and #request.messages > 0 then
        formats[#formats + 1] = "llm/v1/chat"
    end

    -- a string "prompt" marks a completions request
    if type(request.prompt) == "string" then
        formats[#formats + 1] = "llm/v1/completions"
    end

    if formats[2] then
        return nil, "request matches multiple LLM request formats"
    end

    -- indexing with `false` is safe in Lua: it simply yields nil
    if not formats_compatible[formats[1] or false] then
        return nil, "request format not recognised"
    end

    return formats[1]
end


-- Check whether a decoded request body can be served by the configured
-- route_type.  "preserve" accepts any request; otherwise the identified
-- request format must map to route_type in formats_compatible.
-- @tparam table request the decoded request body
-- @tparam string route_type the configured route type
-- @treturn[1] boolean true when compatible
-- @treturn[2] nil|false plus an error message otherwise
local function is_compatible(request, route_type)
    -- "preserve" routes forward the body untouched, so anything goes
    if route_type == "preserve" then
        return true
    end

    local format, err = identify_request(request)
    if not format then
        return nil, err
    end

    if not formats_compatible[format][route_type] then
        return false, fmt("[%s] message format is not compatible with [%s] route type",
                          format, route_type)
    end

    return true
end

-- Read the buffered upstream response body and, unless route_type is
-- "preserve", convert it from the provider's wire format via the driver's
-- from_format().  The result (or a JSON error payload on failure) is stashed
-- on ctx.plugin.buffered_response_body for body_filter to emit.
-- NOTE(review): this is called from header_filter, but
-- core.response.hold_body_chunk() reads response chunks that are normally
-- only available in the body_filter phase -- confirm this actually captures
-- the body here, or move the call into body_filter.
local function transform_body(conf, ctx)
    local route_type = conf.route_type
    -- resolve the provider-specific driver module, e.g. ...drivers.openai
    local ai_driver = require("apisix.plugins.ai-proxy.drivers." .. conf.model.provider)

    -- Note: below even if we are told not to do response transform, we still need to do
    -- get the body for analytics
    local response_body = core.response.hold_body_chunk(ctx)

    local err

    if not response_body then
        err = "no response body found when transforming response"
    elseif route_type ~= "preserve" then
        -- convert provider wire format back into the client-facing format
        response_body, err = ai_driver.from_format(response_body, conf.model, route_type)

        if err then
            core.log.error("issue when transforming the response body for analytics: ", err)
        end
    end

    if err then
        -- surface the failure to the client as a JSON error payload
        ngx.status = 500
        response_body = cjson.encode({ error = { message = err } })
    end

    ctx.plugin.buffered_response_body = response_body
end


-- header_filter: transform the buffered response body and drop the stale
-- Content-Length header, since the body emitted in body_filter differs in
-- size from the upstream body.
function _M.header_filter(conf, ctx)
    -- only act on 200 in first release - pass the unmodified response all the way through if any failure
    if ngx.status ~= 200 then
        return
    end

    local ai_driver = require("apisix.plugins.ai-proxy.drivers." .. conf.model.provider)
    -- ai_driver.post_request(conf)

    transform_body(conf, ctx)

    -- The previous code cleared the *request* header via
    -- core.request.set_header(), which cannot affect the response sent to
    -- the client.  Clear the response header instead, so the client does not
    -- wait for bytes that will never arrive.
    ngx.header["Content-Length"] = nil
end

-- body_filter: replace the upstream chunks with the transformed body
-- prepared by header_filter.
function _M.body_filter(conf, ctx)
    if ngx.status ~= 200 then
        return
    end

    local body = ctx.plugin.buffered_response_body
    -- body_filter runs once per chunk; once the buffered body has been
    -- emitted (or when header_filter never produced one) there is nothing
    -- left to do.  Without this guard a later invocation assigned nil to
    -- ngx.arg[1], silently truncating the response.
    if not body then
        return
    end

    ngx.arg[1] = body
    ngx.arg[2] = true  -- signal EOF so no further chunks are forwarded

    ctx.plugin.buffered_response_body = nil
end

-- access: validate and normalise the incoming request body, then transform
-- it into the configured provider's wire format and configure the upstream
-- request via the provider driver.
-- Returns (status, message) to short-circuit with an error response, or
-- nothing to continue proxying.
function _M.access(conf, ctx)
    local route_type = conf.route_type

    local content_type = core.request.header(ctx, "Content-Type") or "application/json"
    -- multipart may be a large file upload, so we have to proxy it directly
    local multipart = content_type == "multipart/form-data"

    local body = core.request.get_body() -- TODO: enforce a maximum body size
    if not body then
        return 400, "request body is required"
    end

    -- Decode the JSON body directly.  (The previous gsub-based unescaping
    -- hack corrupted valid JSON containing escaped quotes, and misread
    -- gsub's second return value -- the substitution count -- as an error.)
    local request_table = core.json.decode(body)
    if not request_table then
        return 400, "content-type header does not match request body, or bad JSON formatting"
    end

    -- copy the model name from the user request if the plugin conf left it unset
    -- NOTE(review): this mutates the plugin conf table; confirm conf is a
    -- per-request copy and not a shared cache entry, otherwise concurrent
    -- requests leak state into each other.
    if (not multipart) and (not conf.model.name) and (request_table.model) then
        if type(request_table.model) == "string" then
            conf.model.name = request_table.model
        end
    elseif multipart then
        conf.model.name = "NOT_SPECIFIED"
    end

    -- check that the user isn't trying to override the plugin conf model in the request body
    if type(request_table.model) == "string" and request_table.model ~= "" then
        if request_table.model ~= conf.model.name then
            return 400, "cannot use own model - must be: " .. conf.model.name
        end
    end

    -- model is stashed in the plugin conf, for consistency in transformation functions
    if not conf.model.name then
        return 400, "model parameter not found in request, nor in gateway configuration"
    end

    -- check the incoming format is the same as the configured LLM format
    local compatible, err = is_compatible(request_table, route_type)
    if not multipart and not compatible then
        return 400, err
    end

    -- resolve the provider-specific driver module, e.g. ...drivers.openai
    local ai_driver = require("apisix.plugins.ai-proxy.drivers." .. conf.model.provider)

    -- transform the body to the provider/model wire format
    local parsed_request_body, new_content_type
    if route_type ~= "preserve" and (not multipart) then
        parsed_request_body, new_content_type, err = ai_driver.to_format(request_table, conf.model, route_type)
        if err then
            return 400, err
        end
    end

    -- execute pre-request hooks for "all" drivers before setting the new body
    local ok
    ok, err = ai_driver.pre_request(conf, parsed_request_body)
    if not ok then
        return 400, err
    end

    if route_type ~= "preserve" then
        ngx_req.set_body_data(core.json.encode(parsed_request_body))
        core.request.set_header(ctx, "Content-Type", new_content_type)
    end

    -- now re-configure the request for this operation type
    ok, err = ai_driver.configure_request(conf, ctx)
    if not ok then
        core.log.error("failed to configure request for AI service: ", err)
        return 500
    end
end

return _M
Loading

0 comments on commit bf90fc2

Please sign in to comment.