diff --git a/CHANGELOG.md b/CHANGELOG.md index 788f126031..2f6862480d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ This project adheres to [Semantic Versioning](http://semver.org/). - #2887, Add Preference `max-affected` to limit affected resources - @taimoorzaeem - #3171, Add an ability to dump config via admin API - @skywriter + - #3061, Apply all function settings as transaction-scoped settings - @taimoorzaeem ### Fixed diff --git a/src/PostgREST/App.hs b/src/PostgREST/App.hs index 3c3b187707..2ac3f34b9b 100644 --- a/src/PostgREST/App.hs +++ b/src/PostgREST/App.hs @@ -170,43 +170,43 @@ handleRequest AuthResult{..} conf appState authenticated prepared pgVer apiReq@A case (iAction, iTarget) of (ActionRead headersOnly, TargetIdent identifier) -> do (planTime', wrPlan) <- withTiming $ liftEither $ Plan.wrappedReadPlan identifier conf sCache apiReq - (txTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.wrTxMode wrPlan) $ Query.readQuery wrPlan conf apiReq + (txTime', resultSet) <- withTiming $ runQuery roleIsoLvl mempty (Plan.wrTxMode wrPlan) $ Query.readQuery wrPlan conf apiReq (respTime', pgrst) <- withTiming $ liftEither $ Response.readResponse wrPlan headersOnly identifier apiReq resultSet return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst (ActionMutate MutationCreate, TargetIdent identifier) -> do (planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationCreate apiReq identifier conf sCache - (txTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.createQuery mrPlan apiReq conf + (txTime', resultSet) <- withTiming $ runQuery roleIsoLvl mempty (Plan.mrTxMode mrPlan) $ Query.createQuery mrPlan apiReq conf (respTime', pgrst) <- withTiming $ liftEither $ Response.createResponse identifier mrPlan apiReq resultSet return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst (ActionMutate MutationUpdate, TargetIdent identifier) -> do (planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationUpdate apiReq identifier conf sCache - (txTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.updateQuery mrPlan apiReq conf + (txTime', resultSet) <- withTiming $ runQuery roleIsoLvl mempty (Plan.mrTxMode mrPlan) $ Query.updateQuery mrPlan apiReq conf (respTime', pgrst) <- withTiming $ liftEither $ Response.updateResponse mrPlan apiReq resultSet return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst (ActionMutate MutationSingleUpsert, TargetIdent identifier) -> do (planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationSingleUpsert apiReq identifier conf sCache - (txTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.singleUpsertQuery mrPlan apiReq conf + (txTime', resultSet) <- withTiming $ runQuery roleIsoLvl mempty (Plan.mrTxMode mrPlan) $ Query.singleUpsertQuery mrPlan apiReq conf (respTime', pgrst) <- withTiming $ liftEither $ Response.singleUpsertResponse mrPlan apiReq resultSet return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst (ActionMutate MutationDelete, TargetIdent identifier) -> do (planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationDelete apiReq identifier conf sCache - (txTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.deleteQuery mrPlan apiReq conf + (txTime', resultSet) <- withTiming $ runQuery roleIsoLvl mempty (Plan.mrTxMode mrPlan) $ Query.deleteQuery mrPlan apiReq conf (respTime', pgrst) <- withTiming $ liftEither $ Response.deleteResponse mrPlan apiReq resultSet return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst (ActionInvoke invMethod, TargetProc identifier _) -> do (planTime', cPlan) <- withTiming $ liftEither $ Plan.callReadPlan identifier conf sCache apiReq invMethod - (txTime', resultSet) <- withTiming $ runQuery (fromMaybe roleIsoLvl $ pdIsoLvl (Plan.crProc cPlan)) (pdTimeout $ Plan.crProc cPlan) (Plan.crTxMode cPlan) $ Query.invokeQuery (Plan.crProc cPlan) cPlan apiReq conf pgVer + (txTime', resultSet) <- withTiming $ runQuery (fromMaybe roleIsoLvl $ pdIsoLvl (Plan.crProc cPlan)) (pdFuncSettings $ Plan.crProc cPlan) (Plan.crTxMode cPlan) $ Query.invokeQuery (Plan.crProc cPlan) cPlan apiReq conf pgVer (respTime', pgrst) <- withTiming $ liftEither $ Response.invokeResponse cPlan invMethod (Plan.crProc cPlan) apiReq resultSet return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst (ActionInspect headersOnly, TargetDefaultSpec tSchema) -> do (planTime', iPlan) <- withTiming $ liftEither $ Plan.inspectPlan apiReq - (txTime', oaiResult) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.ipTxmode iPlan) $ Query.openApiQuery sCache pgVer conf tSchema + (txTime', oaiResult) <- withTiming $ runQuery roleIsoLvl mempty (Plan.ipTxmode iPlan) $ Query.openApiQuery sCache pgVer conf tSchema (respTime', pgrst) <- withTiming $ liftEither $ Response.openApiResponse (T.decodeUtf8 prettyVersion, docsVersion) headersOnly oaiResult conf sCache iSchema iNegotiatedByProfile return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst @@ -230,9 +230,9 @@ handleRequest AuthResult{..} conf appState authenticated prepared pgVer apiReq@A where roleSettings = fromMaybe mempty (HM.lookup authRole $ configRoleSettings conf) roleIsoLvl = HM.findWithDefault SQL.ReadCommitted authRole $ configRoleIsoLvl conf - runQuery isoLvl timeout mode query = + runQuery isoLvl funcSets mode query = runDbHandler appState conf isoLvl mode authenticated prepared $ do - Query.setPgLocals conf authClaims authRole (HM.toList roleSettings) apiReq timeout + Query.setPgLocals conf authClaims authRole (HM.toList roleSettings) funcSets apiReq Query.runPreReq conf query diff --git a/src/PostgREST/Query.hs b/src/PostgREST/Query.hs index 7242dce140..ceb0d9716f 100644 --- a/src/PostgREST/Query.hs +++ b/src/PostgREST/Query.hs @@ -247,12 +247,12 @@ optionalRollback AppConfig{..} ApiRequest{iPreferences=Preferences{..}} = do -- | Set transaction scoped settings setPgLocals :: AppConfig -> KM.KeyMap JSON.Value -> BS.ByteString -> [(ByteString, ByteString)] -> - ApiRequest -> Maybe Text -> DbHandler () -setPgLocals AppConfig{..} claims role roleSettings ApiRequest{..} tout = lift $ + [(Text,Text)] -> ApiRequest -> DbHandler () +setPgLocals AppConfig{..} claims role roleSettings funcSettings ApiRequest{..} = lift $ SQL.statement mempty $ SQL.dynamicallyParameterized -- To ensure `GRANT SET ON PARAMETER TO authenticator` works, the role settings must be set before the impersonated role. -- Otherwise the GRANT SET would have to be applied to the impersonated role. See https://github.com/PostgREST/postgrest/issues/3045 - ("select " <> intercalateSnippet ", " (searchPathSql : roleSettingsSql ++ roleSql ++ claimsSql ++ [methodSql, pathSql] ++ headersSql ++ cookiesSql ++ timezoneSql ++ timeoutSql ++ appSettingsSql)) + ("select " <> intercalateSnippet ", " (searchPathSql : roleSettingsSql ++ roleSql ++ claimsSql ++ [methodSql, pathSql] ++ headersSql ++ cookiesSql ++ timezoneSql ++ funcSettingsSql ++ appSettingsSql)) HD.noResult configDbPreparedStatements where methodSql = setConfigWithConstantName ("request.method", iMethod) @@ -264,7 +264,7 @@ setPgLocals AppConfig{..} claims role roleSettings ApiRequest{..} tout = lift $ roleSettingsSql = setConfigWithDynamicName <$> roleSettings appSettingsSql = setConfigWithDynamicName <$> (join bimap toUtf8 <$> configAppSettings) timezoneSql = maybe mempty (\(PreferTimezone tz) -> [setConfigWithConstantName ("timezone", tz)]) $ preferTimezone iPreferences - timeoutSql = maybe mempty ((\t -> [setConfigWithConstantName ("statement_timeout", t)]) . encodeUtf8) tout + funcSettingsSql = setConfigWithDynamicName <$> (join bimap toUtf8 <$> funcSettings) searchPathSql = let schemas = escapeIdentList (iSchema : configDbExtraSearchPath) in setConfigWithConstantName ("search_path", schemas) diff --git a/src/PostgREST/SchemaCache.hs b/src/PostgREST/SchemaCache.hs index 65d0416d37..2bfe801a43 100644 --- a/src/PostgREST/SchemaCache.hs +++ b/src/PostgREST/SchemaCache.hs @@ -74,7 +74,6 @@ import qualified PostgREST.MediaType as MediaType import Protolude - data SchemaCache = SchemaCache { dbTables :: TablesMap , dbRelationships :: RelationshipsMap @@ -297,7 +296,7 @@ decodeFuncs = <*> (parseVolatility <$> column HD.char) <*> column HD.bool <*> nullableColumn (toIsolationLevel <$> HD.text) - <*> nullableColumn HD.text + <*> compositeArrayColumn ((,) <$> compositeField HD.text <*> compositeField HD.text) -- function setting addKey :: Routine -> (QualifiedIdentifier, Routine) addKey pd = (QualifiedIdentifier (pdSchema pd) (pdName pd), pd) @@ -432,7 +431,7 @@ funcsSqlQuery pgVer = [q| p.provolatile, p.provariadic > 0 as hasvariadic, lower((regexp_split_to_array((regexp_split_to_array(iso_config, '='))[2], ','))[1]) AS transaction_isolation_level, - lower((regexp_split_to_array((regexp_split_to_array(timeout_config, '='))[2], ','))[1]) AS statement_timeout + coalesce(func_settings.kvs, '{}') as kvs FROM pg_proc p LEFT JOIN arguments a ON a.oid = p.oid JOIN pg_namespace pn ON pn.oid = p.pronamespace @@ -442,7 +441,15 @@ funcsSqlQuery pgVer = [q| LEFT JOIN pg_class comp ON comp.oid = t.typrelid LEFT JOIN pg_description as d ON d.objoid = p.oid LEFT JOIN LATERAL unnest(proconfig) iso_config ON iso_config like 'default_transaction_isolation%' - LEFT JOIN LATERAL unnest(proconfig) timeout_config ON timeout_config like 'statement_timeout%' + LEFT JOIN LATERAL ( + SELECT + array_agg(row( + substr(setting, 1, strpos(setting, '=') - 1), + lower(substr(setting, strpos(setting, '=') + 1)) + )) as kvs + FROM unnest(proconfig) setting + WHERE setting not LIKE 'default_transaction_isolation%' + ) func_settings ON TRUE WHERE t.oid <> 'trigger'::regtype AND COALESCE(a.callable, true) |] <> (if pgVer >= pgVersion110 then "AND prokind = 'f'" else "AND NOT (proisagg OR proiswindow)") diff --git a/src/PostgREST/SchemaCache/Routine.hs b/src/PostgREST/SchemaCache/Routine.hs index e84d722b9a..c01cb4829f 100644 --- a/src/PostgREST/SchemaCache/Routine.hs +++ b/src/PostgREST/SchemaCache/Routine.hs @@ -6,6 +6,7 @@ module PostgREST.SchemaCache.Routine , Routine(..) , RoutineParam(..) , FuncVolatility(..) + , FuncSettings , RoutineMap , RetType(..) , funcReturnsScalar @@ -49,30 +50,32 @@ data FuncVolatility | Immutable deriving (Eq, Show, Ord, Generic, JSON.ToJSON) +type FuncSettings = [(Text,Text)] + data Routine = Function - { pdSchema :: Schema - , pdName :: Text - , pdDescription :: Maybe Text - , pdParams :: [RoutineParam] - , pdReturnType :: RetType - , pdVolatility :: FuncVolatility - , pdHasVariadic :: Bool - , pdIsoLvl :: Maybe SQL.IsolationLevel - , pdTimeout :: Maybe Text + { pdSchema :: Schema + , pdName :: Text + , pdDescription :: Maybe Text + , pdParams :: [RoutineParam] + , pdReturnType :: RetType + , pdVolatility :: FuncVolatility + , pdHasVariadic :: Bool + , pdIsoLvl :: Maybe SQL.IsolationLevel + , pdFuncSettings :: FuncSettings } deriving (Eq, Show, Generic) -- need to define JSON manually bc SQL.IsolationLevel doesn't have a JSON instance(and we can't define one for that type without getting a compiler error) instance JSON.ToJSON Routine where - toJSON (Function sch nam desc params ret vol hasVar _ tout) = JSON.object + toJSON (Function sch nam desc params ret vol hasVar _ sets) = JSON.object [ - "pdSchema" .= sch - , "pdName" .= nam - , "pdDescription" .= desc - , "pdParams" .= JSON.toJSON params - , "pdReturnType" .= JSON.toJSON ret - , "pdVolatility" .= JSON.toJSON vol - , "pdHasVariadic" .= JSON.toJSON hasVar - , "pdTimeout" .= tout + "pdSchema" .= sch + , "pdName" .= nam + , "pdDescription" .= desc + , "pdParams" .= JSON.toJSON params + , "pdReturnType" .= JSON.toJSON ret + , "pdVolatility" .= JSON.toJSON vol + , "pdHasVariadic" .= JSON.toJSON hasVar + , "pdFuncSettings" .= JSON.toJSON sets ] data RoutineParam = RoutineParam @@ -86,10 +89,10 @@ data RoutineParam = RoutineParam -- Order by least number of params in the case of overloaded functions instance Ord Routine where - Function schema1 name1 des1 prms1 rt1 vol1 hasVar1 iso1 tout1 `compare` Function schema2 name2 des2 prms2 rt2 vol2 hasVar2 iso2 tout2 + Function schema1 name1 des1 prms1 rt1 vol1 hasVar1 iso1 sets1 `compare` Function schema2 name2 des2 prms2 rt2 vol2 hasVar2 iso2 sets2 | schema1 == schema2 && name1 == name2 && length prms1 < length prms2 = LT | schema2 == schema2 && name1 == name2 && length prms1 > length prms2 = GT - | otherwise = (schema1, name1, des1, prms1, rt1, vol1, hasVar1, iso1, tout1) `compare` (schema2, name2, des2, prms2, rt2, vol2, hasVar2, iso2, tout2) + | otherwise = (schema1, name1, des1, prms1, rt1, vol1, hasVar1, iso1, sets1) `compare` (schema2, name2, des2, prms2, rt2, vol2, hasVar2, iso2, sets2) -- | A map of all procs, all of which can be overloaded(one entry will have more than one Routine). -- | It uses a HashMap for a faster lookup. diff --git a/test/io/fixtures.sql b/test/io/fixtures.sql index bd4593ab80..f6df21710b 100644 --- a/test/io/fixtures.sql +++ b/test/io/fixtures.sql @@ -198,3 +198,14 @@ $$ language sql set statement_timeout = '4s'; create function get_postgres_version() returns int as $$ select current_setting('server_version_num')::int; $$ language sql; + +create or replace function work_mem_test() returns text as $$ + select current_setting('work_mem',false); +$$ language sql set work_mem = '6000'; + +create or replace function multiple_func_settings_test() returns setof record as $$ + select current_setting('work_mem',false) as work_mem, + current_setting('statement_timeout',false) as statement_timeout; +$$ language sql +set work_mem = '5000' +set statement_timeout = '10s'; diff --git a/test/io/test_io.py b/test/io/test_io.py index c4ce5e1edd..aebbf83503 100644 --- a/test/io/test_io.py +++ b/test/io/test_io.py @@ -1337,28 +1337,6 @@ def test_no_preflight_request_with_CORS_config_should_not_return_header(defaulte assert "Access-Control-Allow-Origin" not in response.headers -def test_fail_with_3_sec_statement_and_1_sec_statement_timeout(defaultenv): - "statement that takes three seconds to execute should fail with one second timeout" - - with run(env=defaultenv) as postgrest: - response = postgrest.session.post("/rpc/one_sec_timeout") - - assert response.status_code == 500 - assert ( - response.text - == '{"code":"57014","details":null,"hint":null,"message":"canceling statement due to statement timeout"}' - ) - - -def test_passes_with_3_sec_statement_and_4_sec_statement_timeout(defaultenv): - "statement that takes three seconds to execute should succeed with four second timeout" - - with run(env=defaultenv) as postgrest: - response = postgrest.session.post("/rpc/four_sec_timeout") - - assert response.status_code == 204 - - @pytest.mark.parametrize("level", ["crit", "error", "warn", "info"]) def test_db_error_logging_to_stderr(level, defaultenv, metapostgrest): "verify that DB errors are logged to stderr" @@ -1385,3 +1363,43 @@ def test_db_error_logging_to_stderr(level, defaultenv, metapostgrest): else: assert " 500 " in output[0] assert "canceling statement due to statement timeout" in output[1] + + +def test_function_setting_statement_timeout_fails(defaultenv): + "statement that takes three seconds to execute should fail with one second timeout" + + with run(env=defaultenv) as postgrest: + response = postgrest.session.post("/rpc/one_sec_timeout") + + assert response.status_code == 500 + assert ( + response.text + == '{"code":"57014","details":null,"hint":null,"message":"canceling statement due to statement timeout"}' + ) + + +def test_function_setting_statement_timeout_passes(defaultenv): + "statement that takes three seconds to execute should succeed with four second timeout" + + with run(env=defaultenv) as postgrest: + response = postgrest.session.post("/rpc/four_sec_timeout") + + assert response.status_code == 204 + + +def test_function_setting_work_mem(defaultenv): + "check function setting work_mem is applied" + + with run(env=defaultenv) as postgrest: + response = postgrest.session.post("/rpc/work_mem_test") + + assert response.text == '"6000kB"' + + +def test_multiple_func_settings(defaultenv): + "check multiple function settings are applied" + + with run(env=defaultenv) as postgrest: + response = postgrest.session.post("/rpc/multiple_func_settings_test") + + assert response.text == '[{"work_mem":"5000kB","statement_timeout":"10s"}]'