Skip to content

Commit

Permalink
Fix apache#11276 + little style refactor
Browse files Browse the repository at this point in the history
Features:
- config param "key_claim_name" (default = "key"), so for example one could use "iss" to check the validity of the JWT;

Style:
- 2 blank lines between functions;
- 1 blank like before "else" and "elseif";
- jwt -> JWT;
- Capitalized logs and response messages;
- Added description for each schema configuration parameter;
  • Loading branch information
mikyll committed May 23, 2024
1 parent 0468d78 commit e393d0c
Showing 1 changed file with 83 additions and 31 deletions.
114 changes: 83 additions & 31 deletions apisix/plugins/jwt-auth.lua
Original file line number Diff line number Diff line change
Expand Up @@ -36,41 +36,72 @@ local schema = {
type = "object",
properties = {
header = {
description = "The name of the HTTP header where the JWT token is expected to be "
.. "found.",
type = "string",
default = "authorization"
},
query = {
description = "The name of the query parameter where the JWT token is expected to be "
.. "found.",
type = "string",
default = "jwt"
},
cookie = {
description = "The name of the cookie where the JWT token is expected to be found.",
type = "string",
default = "jwt"
},
hide_credentials = {
description = "If true, the plugin will remove the JWT token from the header, query, "
.. "or cookie after extracting it to avoid sending it to the upstream "
.. "service.",
type = "boolean",
default = false
},
-- New config parameter
key_claim_name = {
description = "The name of the claim in the JWT token that contains the user key.",
type = "string",
default = "key"
}
},
}


local consumer_schema = {
type = "object",
-- can't use additionalProperties with dependencies
properties = {
key = {type = "string"},
secret = {type = "string"},
key = {
description = "A unique key used to identify the Consumer (corresponds to "
.. "conf.key_claim_name in JWT).",
type = "string"
},
secret = {
description = "The encryption key used for signing and verifying the JWT token. "
.. "If unspecified, auto generated in the background.",
type = "string"
},
algorithm = {
description = "The encryption algorithm used for signing and verifying the JWT token.",
type = "string",
enum = {"HS256", "HS512", "RS256", "ES256"},
default = "HS256"
},
exp = {type = "integer", minimum = 1, default = 86400},
exp = {
description = "The expiration time of the JWT token in seconds.",
type = "integer", minimum = 1,
default = 86400
},
base64_secret = {
description = "Indicates if the secret is base64 encoded.",
type = "boolean",
default = false
},
lifetime_grace_period = {
description = "The grace period in seconds for the JWT token lifetime validation. "
.. "Allows a buffer time for token expiration.",
type = "integer",
minimum = 0,
default = 0
Expand All @@ -89,8 +120,16 @@ local consumer_schema = {
},
{
properties = {
public_key = {type = "string"},
private_key= {type = "string"},
public_key = {
description = "The public key used for verifying the JWT token when "
.. "using RS256 or ES256 algorithms.",
type = "string"
},
private_key= {
description = "The private key used for signing the JWT token when "
.. "using RS256 or ES256 algorithms.",
type = "string"
},
algorithm = {
enum = {"RS256", "ES256"},
},
Expand All @@ -116,11 +155,12 @@ local _M = {


function _M.check_schema(conf, schema_type)
core.log.info("input conf: ", core.json.delay_encode(conf))
core.log.info("Input conf: ", core.json.delay_encode(conf))

local ok, err
if schema_type == core.schema.TYPE_CONSUMER then
ok, err = core.schema.check(consumer_schema, conf)

else
return core.schema.check(schema, conf)
end
Expand All @@ -131,6 +171,7 @@ function _M.check_schema(conf, schema_type)

if conf.algorithm ~= "RS256" and conf.algorithm ~= "ES256" and not conf.secret then
conf.secret = ngx_encode_base64(resty_random.bytes(32, true))

elseif conf.base64_secret then
if ngx_decode_base64(conf.secret) == nil then
return false, "base64_secret required but the secret is not in base64 format"
Expand All @@ -151,20 +192,21 @@ function _M.check_schema(conf, schema_type)
return true
end


local function remove_specified_cookie(src, key)
local cookie_key_pattern = "([a-zA-Z0-9-_]*)"
local cookie_val_pattern = "([a-zA-Z0-9-._]*)"
local t = new_tab(1, 0)

local it, err = ngx_re_gmatch(src, cookie_key_pattern .. "=" .. cookie_val_pattern, "jo")
if not it then
core.log.error("match origins failed: ", err)
core.log.error("Match origins failed: ", err)
return src
end
while true do
local m, err = it()
if err then
core.log.error("iterate origins failed: ", err)
core.log.error("Iterate origins failed: ", err)
return src
end
if not m then
Expand All @@ -178,6 +220,7 @@ local function remove_specified_cookie(src, key)
return table_concat(t, "; ")
end


local function fetch_jwt_token(conf, ctx)
local token = core.request.header(ctx, conf.header)
if token then
Expand Down Expand Up @@ -220,6 +263,7 @@ local function fetch_jwt_token(conf, ctx)
return val
end


local function get_secret(conf)
local secret = conf.secret

Expand All @@ -237,10 +281,13 @@ local function get_rsa_or_ecdsa_keypair(conf)

if public_key and private_key then
return public_key, private_key

elseif public_key and not private_key then
return nil, nil, "missing private key"

elseif not public_key and private_key then
return nil, nil, "missing public key"

else
return nil, nil, "public and private keys are missing"
end
Expand All @@ -264,8 +311,8 @@ end
local function sign_jwt_with_HS(key, consumer, payload)
local auth_secret, err = get_secret(consumer.auth_conf)
if not auth_secret then
core.log.error("failed to sign jwt, err: ", err)
core.response.exit(503, "failed to sign jwt")
core.log.error("Failed to sign JWT, err: ", err)
core.response.exit(503, "Failed to sign JWT")
end
local ok, jwt_token = pcall(jwt.sign, _M,
auth_secret,
Expand All @@ -278,8 +325,8 @@ local function sign_jwt_with_HS(key, consumer, payload)
}
)
if not ok then
core.log.warn("failed to sign jwt, err: ", jwt_token.reason)
core.response.exit(500, "failed to sign jwt")
core.log.warn("Failed to sign JWT, err: ", jwt_token.reason)
core.response.exit(500, "Failed to sign JWT")
end
return jwt_token
end
Expand All @@ -290,8 +337,8 @@ local function sign_jwt_with_RS256_ES256(key, consumer, payload)
consumer.auth_conf
)
if not public_key then
core.log.error("failed to sign jwt, err: ", err)
core.response.exit(503, "failed to sign jwt")
core.log.error("Failed to sign JWT, err: ", err)
core.response.exit(503, "Failed to sign JWT")
end

local ok, jwt_token = pcall(jwt.sign, _M,
Expand All @@ -308,12 +355,13 @@ local function sign_jwt_with_RS256_ES256(key, consumer, payload)
}
)
if not ok then
core.log.warn("failed to sign jwt, err: ", jwt_token.reason)
core.response.exit(500, "failed to sign jwt")
core.log.warn("Failed to sign JWT, err: ", jwt_token.reason)
core.response.exit(500, "Failed to sign JWT")
end
return jwt_token
end


-- introducing method_only flag (returns respective signing method) to save http API calls.
local function algorithm_handler(consumer, method_only)
if not consumer.auth_conf.algorithm or consumer.auth_conf.algorithm == "HS256"
Expand All @@ -323,6 +371,7 @@ local function algorithm_handler(consumer, method_only)
end

return get_secret(consumer.auth_conf)

elseif consumer.auth_conf.algorithm == "RS256" or consumer.auth_conf.algorithm == "ES256" then
if method_only then
return sign_jwt_with_RS256_ES256
Expand All @@ -333,24 +382,26 @@ local function algorithm_handler(consumer, method_only)
end
end


function _M.rewrite(conf, ctx)
-- fetch token and hide credentials if necessary
local jwt_token, err = fetch_jwt_token(conf, ctx)
if not jwt_token then
core.log.info("failed to fetch JWT token: ", err)
core.log.info("Failed to fetch JWT token: ", err)
return 401, {message = "Missing JWT token in request"}
end

local jwt_obj = jwt:load_jwt(jwt_token)
core.log.info("jwt object: ", core.json.delay_encode(jwt_obj))
core.log.info("JWT object: ", core.json.delay_encode(jwt_obj))
if not jwt_obj.valid then
core.log.warn("JWT token invalid: ", jwt_obj.reason)
return 401, {message = "JWT token invalid"}
end

local user_key = jwt_obj.payload and jwt_obj.payload.key
local key_claim_name = conf.key_claim_name
local user_key = jwt_obj.payload and jwt_obj.payload[key_claim_name]
if not user_key then
return 401, {message = "missing user key in JWT token"}
return 401, {message = "Missing " .. key_claim_name .. " claim in JWT token"}
end

local consumer_conf = consumer_mod.plugin(plugin_name)
Expand All @@ -362,28 +413,29 @@ function _M.rewrite(conf, ctx)

local consumer = consumers[user_key]
if not consumer then
return 401, {message = "Invalid user key in JWT token"}
-- This means that there's a mismatch between the JWT key claim and the Consumer key field
return 401, {message = "Invalid user " .. key_claim_name .. " in JWT token"}
end
core.log.info("consumer: ", core.json.delay_encode(consumer))
core.log.info("Consumer: ", core.json.delay_encode(consumer))

local auth_secret, err = algorithm_handler(consumer)
if not auth_secret then
core.log.error("failed to retrieve secrets, err: ", err)
return 503, {message = "failed to verify jwt"}
core.log.error("Failed to retrieve secrets, err: ", err)
return 503, {message = "Failed to verify JWT"}
end
local claim_specs = jwt:get_default_validation_options(jwt_obj)
claim_specs.lifetime_grace_period = consumer.auth_conf.lifetime_grace_period

jwt_obj = jwt:verify_jwt_obj(auth_secret, jwt_obj, claim_specs)
core.log.info("jwt object: ", core.json.delay_encode(jwt_obj))
core.log.info("JWT object: ", core.json.delay_encode(jwt_obj))

if not jwt_obj.verified then
core.log.warn("failed to verify jwt: ", jwt_obj.reason)
return 401, {message = "failed to verify jwt"}
core.log.warn("Failed to verify JWT: ", jwt_obj.reason)
return 401, {message = "Failed to verify JWT"}
end

consumer_mod.attach_consumer(ctx, consumer, consumer_conf)
core.log.info("hit jwt-auth rewrite")
core.log.info("Hit jwt-auth rewrite")
end


Expand All @@ -406,13 +458,13 @@ local function gen_token()

local consumers = consumer_mod.consumers_kv(plugin_name, consumer_conf, "key")

core.log.info("consumers: ", core.json.delay_encode(consumers))
core.log.info("Consumers: ", core.json.delay_encode(consumers))
local consumer = consumers[key]
if not consumer then
return core.response.exit(404)
end

core.log.info("consumer: ", core.json.delay_encode(consumer))
core.log.info("Consumer: ", core.json.delay_encode(consumer))

local sign_handler = algorithm_handler(consumer, true)
local jwt_token = sign_handler(key, consumer, payload)
Expand All @@ -435,4 +487,4 @@ function _M.api()
end


return _M
return _M

0 comments on commit e393d0c

Please sign in to comment.