diff --git a/CHANGELOG.md b/CHANGELOG.md index 5d6e371e37..f460f9cb76 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ This project adheres to [Semantic Versioning](http://semver.org/). - #1614, Add `db-pool-automatic-recovery` configuration to disable connection retrying - @taimoorzaeem - #2492, Allow full response control when raising exceptions - @taimoorzaeem, @laurenceisla - #2771, Add `Server-Timing` header with JWT duration - @taimoorzaeem + - #2698, Add config `jwt-cache-max-lifetime` and implement JWT caching - @taimoorzaeem ### Fixed diff --git a/postgrest.cabal b/postgrest.cabal index 290618a04d..dbd74ac0b6 100644 --- a/postgrest.cabal +++ b/postgrest.cabal @@ -80,8 +80,10 @@ library , auto-update >= 0.1.4 && < 0.2 , base64-bytestring >= 1 && < 1.3 , bytestring >= 0.10.8 && < 0.12 + , cache >= 0.1.3 && < 0.2.0 , case-insensitive >= 1.2 && < 1.3 , cassava >= 0.4.5 && < 0.6 + , clock >= 0.8.3 && < 0.9.0 , configurator-pg >= 0.2 && < 0.3 , containers >= 0.5.7 && < 0.7 , contravariant-extras >= 0.3.3 && < 0.4 diff --git a/src/PostgREST/AppState.hs b/src/PostgREST/AppState.hs index 89e3cb69d1..ed20481f60 100644 --- a/src/PostgREST/AppState.hs +++ b/src/PostgREST/AppState.hs @@ -4,6 +4,7 @@ module PostgREST.AppState ( AppState + , AuthResult(..) , destroy , getConfig , getSchemaCache @@ -12,6 +13,7 @@ module PostgREST.AppState , getPgVersion , getRetryNextIn , getTime + , getJwtCache , init , initWithPool , logWithZTime @@ -24,8 +26,11 @@ module PostgREST.AppState , runListener ) where +import qualified Data.Aeson as JSON +import qualified Data.Aeson.KeyMap as KM import qualified Data.ByteString.Char8 as BS import qualified Data.ByteString.Lazy as LBS +import qualified Data.Cache as C import Data.Either.Combinators (whenLeft) import qualified Data.Text.Encoding as T import Hasql.Connection (acquire) @@ -62,6 +67,11 @@ import PostgREST.SchemaCache.Identifiers (dumpQi) import Protolude +data AuthResult = AuthResult + { authClaims :: KM.KeyMap JSON.Value + , authRole :: BS.ByteString + } + data AppState = AppState -- | Database connection pool { statePool :: SQL.Pool @@ -87,6 +97,8 @@ data AppState = AppState , stateRetryNextIn :: IORef Int -- | Logs a pool error with a debounce , debounceLogAcquisitionTimeout :: IO () + -- | JWT Cache + , jwtCache :: C.Cache ByteString AuthResult } init :: AppConfig -> IO AppState @@ -108,6 +120,7 @@ initWithPool pool conf = do <*> myThreadId <*> newIORef 0 <*> pure (pure ()) + <*> C.newCache Nothing debLogTimeout <- @@ -188,6 +201,9 @@ putConfig = atomicWriteIORef . stateConf getTime :: AppState -> IO UTCTime getTime = stateGetTime +getJwtCache :: AppState -> C.Cache ByteString AuthResult +getJwtCache = jwtCache + -- | Log to stderr with local time logWithZTime :: AppState -> Text -> IO () logWithZTime appState txt = do diff --git a/src/PostgREST/Auth.hs b/src/PostgREST/Auth.hs index fb07daa600..35096e138e 100644 --- a/src/PostgREST/Auth.hs +++ b/src/PostgREST/Auth.hs @@ -26,6 +26,8 @@ import qualified Data.Aeson.KeyMap as KM import qualified Data.Aeson.Types as JSON import qualified Data.ByteString as BS import qualified Data.ByteString.Lazy.Char8 as LBS +import qualified Data.Cache as C +import qualified Data.Scientific as Sci import qualified Data.Vault.Lazy as Vault import qualified Data.Vector as V import qualified Network.HTTP.Types.Header as HTTP @@ -36,22 +38,20 @@ import Control.Lens (set) import Control.Monad.Except (liftEither) import Data.Either.Combinators (mapLeft) import Data.List (lookup) -import Data.Time.Clock (UTCTime) +import Data.Time.Clock (UTCTime, nominalDiffTimeToSeconds) +import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds) +import System.Clock (TimeSpec (..)) import System.IO.Unsafe (unsafePerformIO) import System.TimeIt (timeItT) -import PostgREST.AppState (AppState, getConfig, getTime) +import PostgREST.AppState (AppState, AuthResult (..), getConfig, + getJwtCache, getTime) import PostgREST.Config (AppConfig (..), JSPath, JSPathExp (..)) import PostgREST.Error (Error (..)) import Protolude -data AuthResult = AuthResult - { authClaims :: KM.KeyMap JSON.Value - , authRole :: BS.ByteString - } - -- | Receives the JWT secret and audience (from config) and a JWT and returns a -- JSON object of JWT claims. parseToken :: Monad m => @@ -107,16 +107,48 @@ middleware appState app req respond = do let token = fromMaybe "" $ Wai.extractBearerAuth =<< lookup HTTP.hAuthorization (Wai.requestHeaders req) parseJwt = runExceptT $ parseToken conf (LBS.fromStrict token) time >>= parseClaims conf - if configDbPlanEnabled conf - then do - (dur,authResult) <- timeItT parseJwt - let req' = req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur } - app req' respond - else do - authResult <- parseJwt - let req' = req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult } - app req' respond - +-- If DbPlanEnabled -> calculate JWT validation time +-- If JwtCacheMaxLifetime -> cache JWT validation result + req' <- case (configDbPlanEnabled conf, configJwtCacheMaxLifetime conf) of + (True, 0) -> do + (dur, authResult) <- timeItT parseJwt + return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur } + + (True, maxLifetime) -> do + (dur, authResult) <- timeItT $ getJWTFromCache appState token maxLifetime parseJwt time + return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur } + + (False, 0) -> do + authResult <- parseJwt + return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult } + + (False, maxLifetime) -> do + authResult <- getJWTFromCache appState token maxLifetime parseJwt time + return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult } + + app req' respond + +-- | Used to retrieve and insert JWT to JWT Cache +getJWTFromCache :: AppState -> ByteString -> Int -> IO (Either Error AuthResult) -> UTCTime -> IO (Either Error AuthResult) +getJWTFromCache appState token maxLifetime parseJwt utc = do + checkCache <- C.lookup (getJwtCache appState) token + authResult <- maybe parseJwt (pure . Right) checkCache + + case (authResult,checkCache) of + (Right res, Nothing) -> C.insert' (getJwtCache appState) (getTimeSpec res maxLifetime utc) token res + _ -> pure () + + return authResult + +-- Used to extract JWT exp claim and add to JWT Cache +getTimeSpec :: AuthResult -> Int -> UTCTime -> Maybe TimeSpec +getTimeSpec res maxLifetime utc = do + let expireJSON = KM.lookup "exp" (authClaims res) + utcToSecs = floor . nominalDiffTimeToSeconds . utcTimeToPOSIXSeconds + sciToInt = fromMaybe 0 . Sci.toBoundedInteger + case expireJSON of + Just (JSON.Number seconds) -> Just $ TimeSpec (sciToInt seconds - utcToSecs utc) 0 + _ -> Just $ TimeSpec (fromIntegral maxLifetime :: Int64) 0 authResultKey :: Vault.Key (Either Error AuthResult) authResultKey = unsafePerformIO Vault.newKey diff --git a/src/PostgREST/CLI.hs b/src/PostgREST/CLI.hs index 7f39d212bf..2bbd907278 100644 --- a/src/PostgREST/CLI.hs +++ b/src/PostgREST/CLI.hs @@ -162,7 +162,7 @@ exampleConfigFile = |## Time in seconds after which to recycle unused pool connections |# db-pool-max-idletime = 30 | - |## Allow autmatic database connection retrying + |## Allow automatic database connection retrying |# db-pool-automatic-recovery = true | |## Stored proc to exec immediately after auth @@ -205,6 +205,9 @@ exampleConfigFile = |# jwt-secret = "secret_with_at_least_32_characters" |jwt-secret-is-base64 = false | + |## Enables and set JWT Cache max lifetime, disables caching with 0 + |# jwt-cache-max-lifetime = 0 + | |## Logging level, the admitted values are: crit, error, warn and info. |log-level = "error" | diff --git a/src/PostgREST/Config.hs b/src/PostgREST/Config.hs index 4221ff59db..47f4e99abe 100644 --- a/src/PostgREST/Config.hs +++ b/src/PostgREST/Config.hs @@ -97,6 +97,7 @@ data AppConfig = AppConfig , configJwtRoleClaimKey :: JSPath , configJwtSecret :: Maybe BS.ByteString , configJwtSecretIsBase64 :: Bool + , configJwtCacheMaxLifetime :: Int , configLogLevel :: LogLevel , configOpenApiMode :: OpenAPIMode , configOpenApiSecurityActive :: Bool @@ -162,6 +163,7 @@ toText conf = ,("jwt-role-claim-key", q . T.intercalate mempty . fmap dumpJSPath . configJwtRoleClaimKey) ,("jwt-secret", q . T.decodeUtf8 . showJwtSecret) ,("jwt-secret-is-base64", T.toLower . show . configJwtSecretIsBase64) + ,("jwt-cache-max-lifetime", show . configJwtCacheMaxLifetime) ,("log-level", q . dumpLogLevel . configLogLevel) ,("openapi-mode", q . dumpOpenApiMode . configOpenApiMode) ,("openapi-security-active", T.toLower . show . configOpenApiSecurityActive) @@ -265,6 +267,7 @@ parser optPath env dbSettings roleSettings roleIsolationLvl = <*> (fromMaybe False <$> optWithAlias (optBool "jwt-secret-is-base64") (optBool "secret-is-base64")) + <*> (fromMaybe 0 <$> optInt "jwt-cache-max-lifetime") <*> parseLogLevel "log-level" <*> parseOpenAPIMode "openapi-mode" <*> (fromMaybe False <$> optBool "openapi-security-active") diff --git a/test/io/configs/expected/aliases.config b/test/io/configs/expected/aliases.config index 96e74e2ade..3e1df8855e 100644 --- a/test/io/configs/expected/aliases.config +++ b/test/io/configs/expected/aliases.config @@ -22,6 +22,7 @@ jwt-aud = "" jwt-role-claim-key = ".\"aliased\"" jwt-secret = "" jwt-secret-is-base64 = true +jwt-cache-max-lifetime = 0 log-level = "error" openapi-mode = "follow-privileges" openapi-security-active = false diff --git a/test/io/configs/expected/boolean-numeric.config b/test/io/configs/expected/boolean-numeric.config index 438cd35897..1b86f251f6 100644 --- a/test/io/configs/expected/boolean-numeric.config +++ b/test/io/configs/expected/boolean-numeric.config @@ -22,6 +22,7 @@ jwt-aud = "" jwt-role-claim-key = ".\"role\"" jwt-secret = "" jwt-secret-is-base64 = true +jwt-cache-max-lifetime = 0 log-level = "error" openapi-mode = "follow-privileges" openapi-security-active = false diff --git a/test/io/configs/expected/boolean-string.config b/test/io/configs/expected/boolean-string.config index 438cd35897..1b86f251f6 100644 --- a/test/io/configs/expected/boolean-string.config +++ b/test/io/configs/expected/boolean-string.config @@ -22,6 +22,7 @@ jwt-aud = "" jwt-role-claim-key = ".\"role\"" jwt-secret = "" jwt-secret-is-base64 = true +jwt-cache-max-lifetime = 0 log-level = "error" openapi-mode = "follow-privileges" openapi-security-active = false diff --git a/test/io/configs/expected/defaults.config b/test/io/configs/expected/defaults.config index 3561d74280..d4f6a9624e 100644 --- a/test/io/configs/expected/defaults.config +++ b/test/io/configs/expected/defaults.config @@ -22,6 +22,7 @@ jwt-aud = "" jwt-role-claim-key = ".\"role\"" jwt-secret = "" jwt-secret-is-base64 = false +jwt-cache-max-lifetime = 0 log-level = "error" openapi-mode = "follow-privileges" openapi-security-active = false diff --git a/test/io/configs/expected/no-defaults-with-db-other-authenticator.config b/test/io/configs/expected/no-defaults-with-db-other-authenticator.config index 226b1a7341..856a109a07 100644 --- a/test/io/configs/expected/no-defaults-with-db-other-authenticator.config +++ b/test/io/configs/expected/no-defaults-with-db-other-authenticator.config @@ -22,6 +22,7 @@ jwt-aud = "https://otherexample.org" jwt-role-claim-key = ".\"other\".\"pre_config_role\"" jwt-secret = "ODERREALLYREALLYREALLYREALLYVERYSAFE" jwt-secret-is-base64 = true +jwt-cache-max-lifetime = 86400 log-level = "info" openapi-mode = "disabled" openapi-security-active = false diff --git a/test/io/configs/expected/no-defaults-with-db.config b/test/io/configs/expected/no-defaults-with-db.config index d434cf7781..9cab547f6e 100644 --- a/test/io/configs/expected/no-defaults-with-db.config +++ b/test/io/configs/expected/no-defaults-with-db.config @@ -22,6 +22,7 @@ jwt-aud = "https://example.org" jwt-role-claim-key = ".\"a\".\"role\"" jwt-secret = "OVERRIDE=REALLY=REALLY=REALLY=REALLY=VERY=SAFE" jwt-secret-is-base64 = false +jwt-cache-max-lifetime = 86400 log-level = "info" openapi-mode = "ignore-privileges" openapi-security-active = true diff --git a/test/io/configs/expected/no-defaults.config b/test/io/configs/expected/no-defaults.config index a657ef851d..1e84858d06 100644 --- a/test/io/configs/expected/no-defaults.config +++ b/test/io/configs/expected/no-defaults.config @@ -22,6 +22,7 @@ jwt-aud = "https://postgrest.org" jwt-role-claim-key = ".\"user\"[0].\"real-role\"" jwt-secret = "c2VjdXJpdHl0aHJvdWdob2JzY3VyaXR5" jwt-secret-is-base64 = true +jwt-cache-max-lifetime = 86400 log-level = "info" openapi-mode = "ignore-privileges" openapi-security-active = true diff --git a/test/io/configs/expected/types.config b/test/io/configs/expected/types.config index bd0d91f4ed..40bda26d5c 100644 --- a/test/io/configs/expected/types.config +++ b/test/io/configs/expected/types.config @@ -22,6 +22,7 @@ jwt-aud = "" jwt-role-claim-key = ".\"role\"" jwt-secret = "" jwt-secret-is-base64 = false +jwt-cache-max-lifetime = 0 log-level = "error" openapi-mode = "follow-privileges" openapi-security-active = false diff --git a/test/io/configs/no-defaults-env.yaml b/test/io/configs/no-defaults-env.yaml index 158cd0e874..d3cd013ee4 100644 --- a/test/io/configs/no-defaults-env.yaml +++ b/test/io/configs/no-defaults-env.yaml @@ -24,6 +24,7 @@ PGRST_JWT_AUD: 'https://postgrest.org' PGRST_JWT_ROLE_CLAIM_KEY: '.user[0]."real-role"' PGRST_JWT_SECRET: c2VjdXJpdHl0aHJvdWdob2JzY3VyaXR5 PGRST_JWT_SECRET_IS_BASE64: true +PGRST_JWT_CACHE_MAX_LIFETIME: 86400 PGRST_LOG_LEVEL: info PGRST_OPENAPI_MODE: 'ignore-privileges' PGRST_OPENAPI_SECURITY_ACTIVE: true diff --git a/test/io/configs/no-defaults.config b/test/io/configs/no-defaults.config index 9c7419b16b..5488b065f9 100644 --- a/test/io/configs/no-defaults.config +++ b/test/io/configs/no-defaults.config @@ -22,6 +22,7 @@ jwt-aud = "https://postgrest.org" jwt-role-claim-key = ".user[0].\"real-role\"" jwt-secret = "c2VjdXJpdHl0aHJvdWdob2JzY3VyaXR5" jwt-secret-is-base64 = true +jwt-cache-max-lifetime = 86400 log-level = "info" openapi-mode = "ignore-privileges" openapi-security-active = true diff --git a/test/io/test_io.py b/test/io/test_io.py index 55a0712a23..4398da18a3 100644 --- a/test/io/test_io.py +++ b/test/io/test_io.py @@ -1,6 +1,6 @@ "Unit tests for Input/Ouput of PostgREST seen as a black box." -from datetime import datetime +from datetime import datetime, timedelta, timezone from operator import attrgetter import os import re @@ -1095,3 +1095,122 @@ def test_fail_with_automatic_recovery_disabled_and_terminated_using_query(defaul exitCode = wait_until_exit(postgrest) assert exitCode == 1 + + +def test_server_timing_jwt_should_decrease_on_subsequent_requests(defaultenv): + "assert that server-timing duration for JWT should decrease on subsequent requests" + + env = { + **defaultenv, + "PGRST_DB_PLAN_ENABLED": "true", + "PGRST_JWT_CACHE_MAX_LIFETIME": "86400", + "PGRST_JWT_SECRET": "@/dev/stdin", + "PGRST_DB_CONFIG": "false", + } + + headers = jwtauthheader( + { + "role": "postgrest_test_author", + "exp": int( + (datetime.now(timezone.utc) + timedelta(minutes=30)).timestamp() + ), + }, + SECRET, + ) + + with run(stdin=SECRET.encode(), env=env) as postgrest: + first_dur_text = postgrest.session.get( + "/authors_only", headers=headers + ).headers["Server-Timing"] + second_dur_text = postgrest.session.get( + "/authors_only", headers=headers + ).headers["Server-Timing"] + + first_dur = float(first_dur_text[8:]) # skip "jwt;dur=" + second_dur = float(second_dur_text[8:]) + + # their difference should be atleast 300, implying + # that JWT Caching is working as expected + assert (first_dur - second_dur) > 300.0 + + +# just added to complete code coverage +def test_jwt_caching_works_with_db_plan_disabled(defaultenv): + "assert that JWT caching words even when Server-Timing header is not returned" + + env = { + **defaultenv, + "PGRST_DB_PLAN_ENABLED": "false", + "PGRST_JWT_CACHE_MAX_LIFETIME": "86400", + "PGRST_JWT_SECRET": "@/dev/stdin", + "PGRST_DB_CONFIG": "false", + } + + headers = jwtauthheader({"role": "postgrest_test_author"}, SECRET) + + with run(stdin=SECRET.encode(), env=env) as postgrest: + first_request = postgrest.session.get("/authors_only", headers=headers) + second_request = postgrest.session.get("/authors_only", headers=headers) + + # in this case we don't get server-timing in response headers + # so we can't compare durations, we just check if request succeeds + assert first_request.status_code == 200 and second_request.status_code == 200 + + +def test_server_timing_jwt_should_not_decrease_when_caching_disabled(defaultenv): + "assert than jwt duration should not decrease when disabled" + + env = { + **defaultenv, + "PGRST_DB_PLAN_ENABLED": "true", + "PGRST_JWT_CACHE_MAX_LIFETIME": "0", # cache disabled + "PGRST_JWT_SECRET": "@/dev/stdin", + "PGRST_DB_CONFIG": "false", + } + + headers = jwtauthheader({"role": "postgrest_test_author"}, SECRET) + + with run(stdin=SECRET.encode(), env=env) as postgrest: + warmup_req = postgrest.session.get("/authors_only", headers=headers) + first_dur_text = postgrest.session.get( + "/authors_only", headers=headers + ).headers["Server-Timing"] + second_dur_text = postgrest.session.get( + "/authors_only", headers=headers + ).headers["Server-Timing"] + + first_dur = float(first_dur_text[8:]) # skip "jwt;dur=" + second_dur = float(second_dur_text[8:]) + + # their difference should be less than 100 + # implying that token is not cached + assert (first_dur - second_dur) < 100.0 + + +def test_jwt_cache_with_no_exp_claim(defaultenv): + "assert than jwt duration should decrease" + + env = { + **defaultenv, + "PGRST_DB_PLAN_ENABLED": "true", + "PGRST_JWT_CACHE_MAX_LIFETIME": "86400", + "PGRST_JWT_SECRET": "@/dev/stdin", + "PGRST_DB_CONFIG": "false", + } + + headers = jwtauthheader({"role": "postgrest_test_author"}, SECRET) # no exp + + with run(stdin=SECRET.encode(), env=env) as postgrest: + first_dur_text = postgrest.session.get( + "/authors_only", headers=headers + ).headers["Server-Timing"] + second_dur_text = postgrest.session.get( + "/authors_only", headers=headers + ).headers["Server-Timing"] + + first_dur = float(first_dur_text[8:]) # skip "jwt;dur=" + second_dur = float(second_dur_text[8:]) + + # their difference should be less than 100 + # implying that token is not cached + assert (first_dur - second_dur) > 300.0 diff --git a/test/memory/memory-tests.sh b/test/memory/memory-tests.sh index 3002d264fc..2fb00962a3 100755 --- a/test/memory/memory-tests.sh +++ b/test/memory/memory-tests.sh @@ -102,7 +102,7 @@ postJsonArrayTest(){ echo "Running memory usage tests.." -jsonKeyTest "1M" "POST" "/rpc/leak?columns=blob" "16M" +jsonKeyTest "1M" "POST" "/rpc/leak?columns=blob" "23M" jsonKeyTest "1M" "POST" "/leak?columns=blob" "16M" jsonKeyTest "1M" "PATCH" "/leak?id=eq.1&columns=blob" "16M" diff --git a/test/spec/SpecHelper.hs b/test/spec/SpecHelper.hs index 93b8dac147..7bee04a95b 100644 --- a/test/spec/SpecHelper.hs +++ b/test/spec/SpecHelper.hs @@ -118,6 +118,7 @@ baseCfg = let secret = Just $ encodeUtf8 "reallyreallyreallyreallyverysafe" in , configJwtRoleClaimKey = [JSPKey "role"] , configJwtSecret = secret , configJwtSecretIsBase64 = False + , configJwtCacheMaxLifetime = 0 , configLogLevel = LogCrit , configOpenApiMode = OAFollowPriv , configOpenApiSecurityActive = False