Skip to content

Commit

Permalink
feat: apply all function settings as transaction-scoped settings
Browse files Browse the repository at this point in the history
  • Loading branch information
taimoorzaeem authored and steve-chavez committed Feb 9, 2024
1 parent 4958472 commit f9ee1f7
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 59 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).

- #2887, Add Preference `max-affected` to limit affected resources - @taimoorzaeem
- #3171, Add an ability to dump config via admin API - @skywriter
- #3061, Apply all function settings as transaction-scoped settings - @taimoorzaeem

### Fixed

Expand Down
18 changes: 9 additions & 9 deletions src/PostgREST/App.hs
Original file line number Diff line number Diff line change
Expand Up @@ -170,43 +170,43 @@ handleRequest AuthResult{..} conf appState authenticated prepared pgVer apiReq@A
case (iAction, iTarget) of
(ActionRead headersOnly, TargetIdent identifier) -> do
(planTime', wrPlan) <- withTiming $ liftEither $ Plan.wrappedReadPlan identifier conf sCache apiReq
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.wrTxMode wrPlan) $ Query.readQuery wrPlan conf apiReq
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl mempty (Plan.wrTxMode wrPlan) $ Query.readQuery wrPlan conf apiReq
(respTime', pgrst) <- withTiming $ liftEither $ Response.readResponse wrPlan headersOnly identifier apiReq resultSet
return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst

(ActionMutate MutationCreate, TargetIdent identifier) -> do
(planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationCreate apiReq identifier conf sCache
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.createQuery mrPlan apiReq conf
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl mempty (Plan.mrTxMode mrPlan) $ Query.createQuery mrPlan apiReq conf
(respTime', pgrst) <- withTiming $ liftEither $ Response.createResponse identifier mrPlan apiReq resultSet
return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst

(ActionMutate MutationUpdate, TargetIdent identifier) -> do
(planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationUpdate apiReq identifier conf sCache
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.updateQuery mrPlan apiReq conf
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl mempty (Plan.mrTxMode mrPlan) $ Query.updateQuery mrPlan apiReq conf
(respTime', pgrst) <- withTiming $ liftEither $ Response.updateResponse mrPlan apiReq resultSet
return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst

(ActionMutate MutationSingleUpsert, TargetIdent identifier) -> do
(planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationSingleUpsert apiReq identifier conf sCache
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.singleUpsertQuery mrPlan apiReq conf
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl mempty (Plan.mrTxMode mrPlan) $ Query.singleUpsertQuery mrPlan apiReq conf
(respTime', pgrst) <- withTiming $ liftEither $ Response.singleUpsertResponse mrPlan apiReq resultSet
return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst

(ActionMutate MutationDelete, TargetIdent identifier) -> do
(planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationDelete apiReq identifier conf sCache
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.deleteQuery mrPlan apiReq conf
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl mempty (Plan.mrTxMode mrPlan) $ Query.deleteQuery mrPlan apiReq conf
(respTime', pgrst) <- withTiming $ liftEither $ Response.deleteResponse mrPlan apiReq resultSet
return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst

(ActionInvoke invMethod, TargetProc identifier _) -> do
(planTime', cPlan) <- withTiming $ liftEither $ Plan.callReadPlan identifier conf sCache apiReq invMethod
(txTime', resultSet) <- withTiming $ runQuery (fromMaybe roleIsoLvl $ pdIsoLvl (Plan.crProc cPlan)) (pdTimeout $ Plan.crProc cPlan) (Plan.crTxMode cPlan) $ Query.invokeQuery (Plan.crProc cPlan) cPlan apiReq conf pgVer
(txTime', resultSet) <- withTiming $ runQuery (fromMaybe roleIsoLvl $ pdIsoLvl (Plan.crProc cPlan)) (pdFuncSettings $ Plan.crProc cPlan) (Plan.crTxMode cPlan) $ Query.invokeQuery (Plan.crProc cPlan) cPlan apiReq conf pgVer
(respTime', pgrst) <- withTiming $ liftEither $ Response.invokeResponse cPlan invMethod (Plan.crProc cPlan) apiReq resultSet
return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst

(ActionInspect headersOnly, TargetDefaultSpec tSchema) -> do
(planTime', iPlan) <- withTiming $ liftEither $ Plan.inspectPlan apiReq
(txTime', oaiResult) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.ipTxmode iPlan) $ Query.openApiQuery sCache pgVer conf tSchema
(txTime', oaiResult) <- withTiming $ runQuery roleIsoLvl mempty (Plan.ipTxmode iPlan) $ Query.openApiQuery sCache pgVer conf tSchema
(respTime', pgrst) <- withTiming $ liftEither $ Response.openApiResponse (T.decodeUtf8 prettyVersion, docsVersion) headersOnly oaiResult conf sCache iSchema iNegotiatedByProfile
return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst

Expand All @@ -230,9 +230,9 @@ handleRequest AuthResult{..} conf appState authenticated prepared pgVer apiReq@A
where
roleSettings = fromMaybe mempty (HM.lookup authRole $ configRoleSettings conf)
roleIsoLvl = HM.findWithDefault SQL.ReadCommitted authRole $ configRoleIsoLvl conf
runQuery isoLvl timeout mode query =
runQuery isoLvl funcSets mode query =
runDbHandler appState conf isoLvl mode authenticated prepared $ do
Query.setPgLocals conf authClaims authRole (HM.toList roleSettings) apiReq timeout
Query.setPgLocals conf authClaims authRole (HM.toList roleSettings) funcSets apiReq
Query.runPreReq conf
query

Expand Down
8 changes: 4 additions & 4 deletions src/PostgREST/Query.hs
Original file line number Diff line number Diff line change
Expand Up @@ -247,12 +247,12 @@ optionalRollback AppConfig{..} ApiRequest{iPreferences=Preferences{..}} = do

-- | Set transaction scoped settings
setPgLocals :: AppConfig -> KM.KeyMap JSON.Value -> BS.ByteString -> [(ByteString, ByteString)] ->
ApiRequest -> Maybe Text -> DbHandler ()
setPgLocals AppConfig{..} claims role roleSettings ApiRequest{..} tout = lift $
[(Text,Text)] -> ApiRequest -> DbHandler ()
setPgLocals AppConfig{..} claims role roleSettings funcSettings ApiRequest{..} = lift $
SQL.statement mempty $ SQL.dynamicallyParameterized
-- To ensure `GRANT SET ON PARAMETER <superuser_setting> TO authenticator` works, the role settings must be set before the impersonated role.
-- Otherwise the GRANT SET would have to be applied to the impersonated role. See https://github.com/PostgREST/postgrest/issues/3045
("select " <> intercalateSnippet ", " (searchPathSql : roleSettingsSql ++ roleSql ++ claimsSql ++ [methodSql, pathSql] ++ headersSql ++ cookiesSql ++ timezoneSql ++ timeoutSql ++ appSettingsSql))
("select " <> intercalateSnippet ", " (searchPathSql : roleSettingsSql ++ roleSql ++ claimsSql ++ [methodSql, pathSql] ++ headersSql ++ cookiesSql ++ timezoneSql ++ funcSettingsSql ++ appSettingsSql))
HD.noResult configDbPreparedStatements
where
methodSql = setConfigWithConstantName ("request.method", iMethod)
Expand All @@ -264,7 +264,7 @@ setPgLocals AppConfig{..} claims role roleSettings ApiRequest{..} tout = lift $
roleSettingsSql = setConfigWithDynamicName <$> roleSettings
appSettingsSql = setConfigWithDynamicName <$> (join bimap toUtf8 <$> configAppSettings)
timezoneSql = maybe mempty (\(PreferTimezone tz) -> [setConfigWithConstantName ("timezone", tz)]) $ preferTimezone iPreferences
timeoutSql = maybe mempty ((\t -> [setConfigWithConstantName ("statement_timeout", t)]) . encodeUtf8) tout
funcSettingsSql = setConfigWithDynamicName <$> (join bimap toUtf8 <$> funcSettings)
searchPathSql =
let schemas = escapeIdentList (iSchema : configDbExtraSearchPath) in
setConfigWithConstantName ("search_path", schemas)
Expand Down
15 changes: 11 additions & 4 deletions src/PostgREST/SchemaCache.hs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ import qualified PostgREST.MediaType as MediaType

import Protolude


data SchemaCache = SchemaCache
{ dbTables :: TablesMap
, dbRelationships :: RelationshipsMap
Expand Down Expand Up @@ -297,7 +296,7 @@ decodeFuncs =
<*> (parseVolatility <$> column HD.char)
<*> column HD.bool
<*> nullableColumn (toIsolationLevel <$> HD.text)
<*> nullableColumn HD.text
<*> compositeArrayColumn ((,) <$> compositeField HD.text <*> compositeField HD.text) -- function setting

addKey :: Routine -> (QualifiedIdentifier, Routine)
addKey pd = (QualifiedIdentifier (pdSchema pd) (pdName pd), pd)
Expand Down Expand Up @@ -432,7 +431,7 @@ funcsSqlQuery pgVer = [q|
p.provolatile,
p.provariadic > 0 as hasvariadic,
lower((regexp_split_to_array((regexp_split_to_array(iso_config, '='))[2], ','))[1]) AS transaction_isolation_level,
lower((regexp_split_to_array((regexp_split_to_array(timeout_config, '='))[2], ','))[1]) AS statement_timeout
coalesce(func_settings.kvs, '{}') as kvs
FROM pg_proc p
LEFT JOIN arguments a ON a.oid = p.oid
JOIN pg_namespace pn ON pn.oid = p.pronamespace
Expand All @@ -442,7 +441,15 @@ funcsSqlQuery pgVer = [q|
LEFT JOIN pg_class comp ON comp.oid = t.typrelid
LEFT JOIN pg_description as d ON d.objoid = p.oid
LEFT JOIN LATERAL unnest(proconfig) iso_config ON iso_config like 'default_transaction_isolation%'
LEFT JOIN LATERAL unnest(proconfig) timeout_config ON timeout_config like 'statement_timeout%'
LEFT JOIN LATERAL (
SELECT
array_agg(row(
substr(setting, 1, strpos(setting, '=') - 1),
lower(substr(setting, strpos(setting, '=') + 1))
)) as kvs
FROM unnest(proconfig) setting
WHERE setting not LIKE 'default_transaction_isolation%'
) func_settings ON TRUE
WHERE t.oid <> 'trigger'::regtype AND COALESCE(a.callable, true)
|] <> (if pgVer >= pgVersion110 then "AND prokind = 'f'" else "AND NOT (proisagg OR proiswindow)")

Expand Down
43 changes: 23 additions & 20 deletions src/PostgREST/SchemaCache/Routine.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ module PostgREST.SchemaCache.Routine
, Routine(..)
, RoutineParam(..)
, FuncVolatility(..)
, FuncSettings
, RoutineMap
, RetType(..)
, funcReturnsScalar
Expand Down Expand Up @@ -49,30 +50,32 @@ data FuncVolatility
| Immutable
deriving (Eq, Show, Ord, Generic, JSON.ToJSON)

type FuncSettings = [(Text,Text)]

data Routine = Function
{ pdSchema :: Schema
, pdName :: Text
, pdDescription :: Maybe Text
, pdParams :: [RoutineParam]
, pdReturnType :: RetType
, pdVolatility :: FuncVolatility
, pdHasVariadic :: Bool
, pdIsoLvl :: Maybe SQL.IsolationLevel
, pdTimeout :: Maybe Text
{ pdSchema :: Schema
, pdName :: Text
, pdDescription :: Maybe Text
, pdParams :: [RoutineParam]
, pdReturnType :: RetType
, pdVolatility :: FuncVolatility
, pdHasVariadic :: Bool
, pdIsoLvl :: Maybe SQL.IsolationLevel
, pdFuncSettings :: FuncSettings
}
deriving (Eq, Show, Generic)
-- need to define JSON manually bc SQL.IsolationLevel doesn't have a JSON instance(and we can't define one for that type without getting a compiler error)
instance JSON.ToJSON Routine where
toJSON (Function sch nam desc params ret vol hasVar _ tout) = JSON.object
toJSON (Function sch nam desc params ret vol hasVar _ sets) = JSON.object
[
"pdSchema" .= sch
, "pdName" .= nam
, "pdDescription" .= desc
, "pdParams" .= JSON.toJSON params
, "pdReturnType" .= JSON.toJSON ret
, "pdVolatility" .= JSON.toJSON vol
, "pdHasVariadic" .= JSON.toJSON hasVar
, "pdTimeout" .= tout
"pdSchema" .= sch
, "pdName" .= nam
, "pdDescription" .= desc
, "pdParams" .= JSON.toJSON params
, "pdReturnType" .= JSON.toJSON ret
, "pdVolatility" .= JSON.toJSON vol
, "pdHasVariadic" .= JSON.toJSON hasVar
, "pdFuncSettings" .= JSON.toJSON sets
]

data RoutineParam = RoutineParam
Expand All @@ -86,10 +89,10 @@ data RoutineParam = RoutineParam

-- Order by least number of params in the case of overloaded functions
instance Ord Routine where
Function schema1 name1 des1 prms1 rt1 vol1 hasVar1 iso1 tout1 `compare` Function schema2 name2 des2 prms2 rt2 vol2 hasVar2 iso2 tout2
Function schema1 name1 des1 prms1 rt1 vol1 hasVar1 iso1 sets1 `compare` Function schema2 name2 des2 prms2 rt2 vol2 hasVar2 iso2 sets2
| schema1 == schema2 && name1 == name2 && length prms1 < length prms2 = LT
| schema2 == schema2 && name1 == name2 && length prms1 > length prms2 = GT
| otherwise = (schema1, name1, des1, prms1, rt1, vol1, hasVar1, iso1, tout1) `compare` (schema2, name2, des2, prms2, rt2, vol2, hasVar2, iso2, tout2)
| otherwise = (schema1, name1, des1, prms1, rt1, vol1, hasVar1, iso1, sets1) `compare` (schema2, name2, des2, prms2, rt2, vol2, hasVar2, iso2, sets2)

-- | A map of all procs, all of which can be overloaded(one entry will have more than one Routine).
-- | It uses a HashMap for a faster lookup.
Expand Down
11 changes: 11 additions & 0 deletions test/io/fixtures.sql
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,14 @@ $$ language sql set statement_timeout = '4s';
create function get_postgres_version() returns int as $$
select current_setting('server_version_num')::int;
$$ language sql;

create or replace function work_mem_test() returns text as $$
select current_setting('work_mem',false);
$$ language sql set work_mem = '6000';

create or replace function multiple_func_settings_test() returns setof record as $$
select current_setting('work_mem',false) as work_mem,
current_setting('statement_timeout',false) as statement_timeout;
$$ language sql
set work_mem = '5000'
set statement_timeout = '10s';
62 changes: 40 additions & 22 deletions test/io/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1337,28 +1337,6 @@ def test_no_preflight_request_with_CORS_config_should_not_return_header(defaulte
assert "Access-Control-Allow-Origin" not in response.headers


def test_fail_with_3_sec_statement_and_1_sec_statement_timeout(defaultenv):
"statement that takes three seconds to execute should fail with one second timeout"

with run(env=defaultenv) as postgrest:
response = postgrest.session.post("/rpc/one_sec_timeout")

assert response.status_code == 500
assert (
response.text
== '{"code":"57014","details":null,"hint":null,"message":"canceling statement due to statement timeout"}'
)


def test_passes_with_3_sec_statement_and_4_sec_statement_timeout(defaultenv):
"statement that takes three seconds to execute should succeed with four second timeout"

with run(env=defaultenv) as postgrest:
response = postgrest.session.post("/rpc/four_sec_timeout")

assert response.status_code == 204


@pytest.mark.parametrize("level", ["crit", "error", "warn", "info"])
def test_db_error_logging_to_stderr(level, defaultenv, metapostgrest):
"verify that DB errors are logged to stderr"
Expand All @@ -1385,3 +1363,43 @@ def test_db_error_logging_to_stderr(level, defaultenv, metapostgrest):
else:
assert " 500 " in output[0]
assert "canceling statement due to statement timeout" in output[1]


def test_function_setting_statement_timeout_fails(defaultenv):
"statement that takes three seconds to execute should fail with one second timeout"

with run(env=defaultenv) as postgrest:
response = postgrest.session.post("/rpc/one_sec_timeout")

assert response.status_code == 500
assert (
response.text
== '{"code":"57014","details":null,"hint":null,"message":"canceling statement due to statement timeout"}'
)


def test_function_setting_statement_timeout_passes(defaultenv):
"statement that takes three seconds to execute should succeed with four second timeout"

with run(env=defaultenv) as postgrest:
response = postgrest.session.post("/rpc/four_sec_timeout")

assert response.status_code == 204


def test_function_setting_work_mem(defaultenv):
"check function setting work_mem is applied"

with run(env=defaultenv) as postgrest:
response = postgrest.session.post("/rpc/work_mem_test")

assert response.text == '"6000kB"'


def test_multiple_func_settings(defaultenv):
"check multiple function settings are applied"

with run(env=defaultenv) as postgrest:
response = postgrest.session.post("/rpc/multiple_func_settings_test")

assert response.text == '[{"work_mem":"5000kB","statement_timeout":"10s"}]'

0 comments on commit f9ee1f7

Please sign in to comment.