Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: apply all function settings as transaction-scoped settings #3142

Merged
merged 1 commit into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).

- #2887, Add Preference `max-affected` to limit affected resources - @taimoorzaeem
- #3171, Add an ability to dump config via admin API - @skywriter
- #3061, Apply all function settings as transaction-scoped settings - @taimoorzaeem

### Fixed

Expand Down
18 changes: 9 additions & 9 deletions src/PostgREST/App.hs
Original file line number Diff line number Diff line change
Expand Up @@ -170,43 +170,43 @@ handleRequest AuthResult{..} conf appState authenticated prepared pgVer apiReq@A
case (iAction, iTarget) of
(ActionRead headersOnly, TargetIdent identifier) -> do
(planTime', wrPlan) <- withTiming $ liftEither $ Plan.wrappedReadPlan identifier conf sCache apiReq
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.wrTxMode wrPlan) $ Query.readQuery wrPlan conf apiReq
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl mempty (Plan.wrTxMode wrPlan) $ Query.readQuery wrPlan conf apiReq
(respTime', pgrst) <- withTiming $ liftEither $ Response.readResponse wrPlan headersOnly identifier apiReq resultSet
return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst

(ActionMutate MutationCreate, TargetIdent identifier) -> do
(planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationCreate apiReq identifier conf sCache
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.createQuery mrPlan apiReq conf
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl mempty (Plan.mrTxMode mrPlan) $ Query.createQuery mrPlan apiReq conf
(respTime', pgrst) <- withTiming $ liftEither $ Response.createResponse identifier mrPlan apiReq resultSet
return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst

(ActionMutate MutationUpdate, TargetIdent identifier) -> do
(planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationUpdate apiReq identifier conf sCache
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.updateQuery mrPlan apiReq conf
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl mempty (Plan.mrTxMode mrPlan) $ Query.updateQuery mrPlan apiReq conf
(respTime', pgrst) <- withTiming $ liftEither $ Response.updateResponse mrPlan apiReq resultSet
return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst

(ActionMutate MutationSingleUpsert, TargetIdent identifier) -> do
(planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationSingleUpsert apiReq identifier conf sCache
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.singleUpsertQuery mrPlan apiReq conf
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl mempty (Plan.mrTxMode mrPlan) $ Query.singleUpsertQuery mrPlan apiReq conf
(respTime', pgrst) <- withTiming $ liftEither $ Response.singleUpsertResponse mrPlan apiReq resultSet
return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst

(ActionMutate MutationDelete, TargetIdent identifier) -> do
(planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationDelete apiReq identifier conf sCache
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.deleteQuery mrPlan apiReq conf
(txTime', resultSet) <- withTiming $ runQuery roleIsoLvl mempty (Plan.mrTxMode mrPlan) $ Query.deleteQuery mrPlan apiReq conf
(respTime', pgrst) <- withTiming $ liftEither $ Response.deleteResponse mrPlan apiReq resultSet
return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst

(ActionInvoke invMethod, TargetProc identifier _) -> do
(planTime', cPlan) <- withTiming $ liftEither $ Plan.callReadPlan identifier conf sCache apiReq invMethod
(txTime', resultSet) <- withTiming $ runQuery (fromMaybe roleIsoLvl $ pdIsoLvl (Plan.crProc cPlan)) (pdTimeout $ Plan.crProc cPlan) (Plan.crTxMode cPlan) $ Query.invokeQuery (Plan.crProc cPlan) cPlan apiReq conf pgVer
(txTime', resultSet) <- withTiming $ runQuery (fromMaybe roleIsoLvl $ pdIsoLvl (Plan.crProc cPlan)) (pdFuncSettings $ Plan.crProc cPlan) (Plan.crTxMode cPlan) $ Query.invokeQuery (Plan.crProc cPlan) cPlan apiReq conf pgVer
(respTime', pgrst) <- withTiming $ liftEither $ Response.invokeResponse cPlan invMethod (Plan.crProc cPlan) apiReq resultSet
return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst

(ActionInspect headersOnly, TargetDefaultSpec tSchema) -> do
(planTime', iPlan) <- withTiming $ liftEither $ Plan.inspectPlan apiReq
(txTime', oaiResult) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.ipTxmode iPlan) $ Query.openApiQuery sCache pgVer conf tSchema
(txTime', oaiResult) <- withTiming $ runQuery roleIsoLvl mempty (Plan.ipTxmode iPlan) $ Query.openApiQuery sCache pgVer conf tSchema
(respTime', pgrst) <- withTiming $ liftEither $ Response.openApiResponse (T.decodeUtf8 prettyVersion, docsVersion) headersOnly oaiResult conf sCache iSchema iNegotiatedByProfile
return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst

Expand All @@ -230,9 +230,9 @@ handleRequest AuthResult{..} conf appState authenticated prepared pgVer apiReq@A
where
roleSettings = fromMaybe mempty (HM.lookup authRole $ configRoleSettings conf)
roleIsoLvl = HM.findWithDefault SQL.ReadCommitted authRole $ configRoleIsoLvl conf
runQuery isoLvl timeout mode query =
runQuery isoLvl funcSets mode query =
runDbHandler appState conf isoLvl mode authenticated prepared $ do
Query.setPgLocals conf authClaims authRole (HM.toList roleSettings) apiReq timeout
Query.setPgLocals conf authClaims authRole (HM.toList roleSettings) funcSets apiReq
Query.runPreReq conf
query

Expand Down
8 changes: 4 additions & 4 deletions src/PostgREST/Query.hs
Original file line number Diff line number Diff line change
Expand Up @@ -247,12 +247,12 @@ optionalRollback AppConfig{..} ApiRequest{iPreferences=Preferences{..}} = do

-- | Set transaction scoped settings
setPgLocals :: AppConfig -> KM.KeyMap JSON.Value -> BS.ByteString -> [(ByteString, ByteString)] ->
ApiRequest -> Maybe Text -> DbHandler ()
setPgLocals AppConfig{..} claims role roleSettings ApiRequest{..} tout = lift $
[(Text,Text)] -> ApiRequest -> DbHandler ()
setPgLocals AppConfig{..} claims role roleSettings funcSettings ApiRequest{..} = lift $
SQL.statement mempty $ SQL.dynamicallyParameterized
-- To ensure `GRANT SET ON PARAMETER <superuser_setting> TO authenticator` works, the role settings must be set before the impersonated role.
-- Otherwise the GRANT SET would have to be applied to the impersonated role. See https://github.com/PostgREST/postgrest/issues/3045
("select " <> intercalateSnippet ", " (searchPathSql : roleSettingsSql ++ roleSql ++ claimsSql ++ [methodSql, pathSql] ++ headersSql ++ cookiesSql ++ timezoneSql ++ timeoutSql ++ appSettingsSql))
("select " <> intercalateSnippet ", " (searchPathSql : roleSettingsSql ++ roleSql ++ claimsSql ++ [methodSql, pathSql] ++ headersSql ++ cookiesSql ++ timezoneSql ++ funcSettingsSql ++ appSettingsSql))
HD.noResult configDbPreparedStatements
where
methodSql = setConfigWithConstantName ("request.method", iMethod)
Expand All @@ -264,7 +264,7 @@ setPgLocals AppConfig{..} claims role roleSettings ApiRequest{..} tout = lift $
roleSettingsSql = setConfigWithDynamicName <$> roleSettings
appSettingsSql = setConfigWithDynamicName <$> (join bimap toUtf8 <$> configAppSettings)
timezoneSql = maybe mempty (\(PreferTimezone tz) -> [setConfigWithConstantName ("timezone", tz)]) $ preferTimezone iPreferences
timeoutSql = maybe mempty ((\t -> [setConfigWithConstantName ("statement_timeout", t)]) . encodeUtf8) tout
funcSettingsSql = setConfigWithDynamicName <$> (join bimap toUtf8 <$> funcSettings)
searchPathSql =
let schemas = escapeIdentList (iSchema : configDbExtraSearchPath) in
setConfigWithConstantName ("search_path", schemas)
Expand Down
15 changes: 11 additions & 4 deletions src/PostgREST/SchemaCache.hs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ import qualified PostgREST.MediaType as MediaType

import Protolude


data SchemaCache = SchemaCache
{ dbTables :: TablesMap
, dbRelationships :: RelationshipsMap
Expand Down Expand Up @@ -297,7 +296,7 @@ decodeFuncs =
<*> (parseVolatility <$> column HD.char)
<*> column HD.bool
<*> nullableColumn (toIsolationLevel <$> HD.text)
<*> nullableColumn HD.text
<*> compositeArrayColumn ((,) <$> compositeField HD.text <*> compositeField HD.text) -- function setting

addKey :: Routine -> (QualifiedIdentifier, Routine)
addKey pd = (QualifiedIdentifier (pdSchema pd) (pdName pd), pd)
Expand Down Expand Up @@ -432,7 +431,7 @@ funcsSqlQuery pgVer = [q|
p.provolatile,
p.provariadic > 0 as hasvariadic,
lower((regexp_split_to_array((regexp_split_to_array(iso_config, '='))[2], ','))[1]) AS transaction_isolation_level,
lower((regexp_split_to_array((regexp_split_to_array(timeout_config, '='))[2], ','))[1]) AS statement_timeout
coalesce(func_settings.kvs, '{}') as kvs
FROM pg_proc p
LEFT JOIN arguments a ON a.oid = p.oid
JOIN pg_namespace pn ON pn.oid = p.pronamespace
Expand All @@ -442,7 +441,15 @@ funcsSqlQuery pgVer = [q|
LEFT JOIN pg_class comp ON comp.oid = t.typrelid
LEFT JOIN pg_description as d ON d.objoid = p.oid
LEFT JOIN LATERAL unnest(proconfig) iso_config ON iso_config like 'default_transaction_isolation%'
LEFT JOIN LATERAL unnest(proconfig) timeout_config ON timeout_config like 'statement_timeout%'
Copy link
Member

@steve-chavez steve-chavez Feb 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use this query here:

  LEFT JOIN LATERAL (
    SELECT
      array_agg(row(
        substr(setting, 1, strpos(setting, '=') - 1),
        lower(substr(setting, strpos(setting, '=') + 1))
      )) as kvs
    FROM unnest(proconfig) setting
    WHERE setting not like 'default_transaction_isolation%'
  ) func_settings ON TRUE

This will output the settings like:

-[ RECORD 14 ]--------------+--------------------------------------------------------------
proc_schema                 | public
proc_name                   | four_sec_timeout
kvs                         | {"(statement_timeout,4s)","(log_min_duration_sample,5s)"}

LEFT JOIN LATERAL (
SELECT
array_agg(row(
substr(setting, 1, strpos(setting, '=') - 1),
lower(substr(setting, strpos(setting, '=') + 1))
)) as kvs
FROM unnest(proconfig) setting
WHERE setting not LIKE 'default_transaction_isolation%'
) func_settings ON TRUE
WHERE t.oid <> 'trigger'::regtype AND COALESCE(a.callable, true)
|] <> (if pgVer >= pgVersion110 then "AND prokind = 'f'" else "AND NOT (proisagg OR proiswindow)")

Expand Down
43 changes: 23 additions & 20 deletions src/PostgREST/SchemaCache/Routine.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ module PostgREST.SchemaCache.Routine
, Routine(..)
, RoutineParam(..)
, FuncVolatility(..)
, FuncSettings
, RoutineMap
, RetType(..)
, funcReturnsScalar
Expand Down Expand Up @@ -49,30 +50,32 @@ data FuncVolatility
| Immutable
deriving (Eq, Show, Ord, Generic, JSON.ToJSON)

type FuncSettings = [(Text,Text)]

data Routine = Function
{ pdSchema :: Schema
, pdName :: Text
, pdDescription :: Maybe Text
, pdParams :: [RoutineParam]
, pdReturnType :: RetType
, pdVolatility :: FuncVolatility
, pdHasVariadic :: Bool
, pdIsoLvl :: Maybe SQL.IsolationLevel
, pdTimeout :: Maybe Text
{ pdSchema :: Schema
, pdName :: Text
, pdDescription :: Maybe Text
, pdParams :: [RoutineParam]
, pdReturnType :: RetType
, pdVolatility :: FuncVolatility
, pdHasVariadic :: Bool
, pdIsoLvl :: Maybe SQL.IsolationLevel
, pdFuncSettings :: FuncSettings
}
deriving (Eq, Show, Generic)
-- need to define JSON manually bc SQL.IsolationLevel doesn't have a JSON instance(and we can't define one for that type without getting a compiler error)
instance JSON.ToJSON Routine where
toJSON (Function sch nam desc params ret vol hasVar _ tout) = JSON.object
toJSON (Function sch nam desc params ret vol hasVar _ sets) = JSON.object
[
"pdSchema" .= sch
, "pdName" .= nam
, "pdDescription" .= desc
, "pdParams" .= JSON.toJSON params
, "pdReturnType" .= JSON.toJSON ret
, "pdVolatility" .= JSON.toJSON vol
, "pdHasVariadic" .= JSON.toJSON hasVar
, "pdTimeout" .= tout
"pdSchema" .= sch
, "pdName" .= nam
, "pdDescription" .= desc
, "pdParams" .= JSON.toJSON params
, "pdReturnType" .= JSON.toJSON ret
, "pdVolatility" .= JSON.toJSON vol
, "pdHasVariadic" .= JSON.toJSON hasVar
, "pdFuncSettings" .= JSON.toJSON sets
]

data RoutineParam = RoutineParam
Expand All @@ -86,10 +89,10 @@ data RoutineParam = RoutineParam

-- Order by least number of params in the case of overloaded functions
instance Ord Routine where
Function schema1 name1 des1 prms1 rt1 vol1 hasVar1 iso1 tout1 `compare` Function schema2 name2 des2 prms2 rt2 vol2 hasVar2 iso2 tout2
Function schema1 name1 des1 prms1 rt1 vol1 hasVar1 iso1 sets1 `compare` Function schema2 name2 des2 prms2 rt2 vol2 hasVar2 iso2 sets2
| schema1 == schema2 && name1 == name2 && length prms1 < length prms2 = LT
| schema2 == schema2 && name1 == name2 && length prms1 > length prms2 = GT
| otherwise = (schema1, name1, des1, prms1, rt1, vol1, hasVar1, iso1, tout1) `compare` (schema2, name2, des2, prms2, rt2, vol2, hasVar2, iso2, tout2)
| otherwise = (schema1, name1, des1, prms1, rt1, vol1, hasVar1, iso1, sets1) `compare` (schema2, name2, des2, prms2, rt2, vol2, hasVar2, iso2, sets2)

-- | A map of all procs, all of which can be overloaded(one entry will have more than one Routine).
-- | It uses a HashMap for a faster lookup.
Expand Down
11 changes: 11 additions & 0 deletions test/io/fixtures.sql
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,14 @@ $$ language sql set statement_timeout = '4s';
create function get_postgres_version() returns int as $$
select current_setting('server_version_num')::int;
$$ language sql;

create or replace function work_mem_test() returns text as $$
select current_setting('work_mem',false);
$$ language sql set work_mem = '6000';

create or replace function multiple_func_settings_test() returns setof record as $$
select current_setting('work_mem',false) as work_mem,
current_setting('statement_timeout',false) as statement_timeout;
$$ language sql
set work_mem = '5000'
set statement_timeout = '10s';
62 changes: 40 additions & 22 deletions test/io/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1337,28 +1337,6 @@ def test_no_preflight_request_with_CORS_config_should_not_return_header(defaulte
assert "Access-Control-Allow-Origin" not in response.headers


def test_fail_with_3_sec_statement_and_1_sec_statement_timeout(defaultenv):
"statement that takes three seconds to execute should fail with one second timeout"

with run(env=defaultenv) as postgrest:
response = postgrest.session.post("/rpc/one_sec_timeout")

assert response.status_code == 500
assert (
response.text
== '{"code":"57014","details":null,"hint":null,"message":"canceling statement due to statement timeout"}'
)


def test_passes_with_3_sec_statement_and_4_sec_statement_timeout(defaultenv):
"statement that takes three seconds to execute should succeed with four second timeout"

with run(env=defaultenv) as postgrest:
response = postgrest.session.post("/rpc/four_sec_timeout")

assert response.status_code == 204


@pytest.mark.parametrize("level", ["crit", "error", "warn", "info"])
def test_db_error_logging_to_stderr(level, defaultenv, metapostgrest):
"verify that DB errors are logged to stderr"
Expand All @@ -1385,3 +1363,43 @@ def test_db_error_logging_to_stderr(level, defaultenv, metapostgrest):
else:
assert " 500 " in output[0]
assert "canceling statement due to statement timeout" in output[1]


def test_function_setting_statement_timeout_fails(defaultenv):
"statement that takes three seconds to execute should fail with one second timeout"

with run(env=defaultenv) as postgrest:
response = postgrest.session.post("/rpc/one_sec_timeout")

assert response.status_code == 500
assert (
response.text
== '{"code":"57014","details":null,"hint":null,"message":"canceling statement due to statement timeout"}'
)


def test_function_setting_statement_timeout_passes(defaultenv):
"statement that takes three seconds to execute should succeed with four second timeout"

with run(env=defaultenv) as postgrest:
response = postgrest.session.post("/rpc/four_sec_timeout")

assert response.status_code == 204


def test_function_setting_work_mem(defaultenv):
"check function setting work_mem is applied"

with run(env=defaultenv) as postgrest:
response = postgrest.session.post("/rpc/work_mem_test")

assert response.text == '"6000kB"'


def test_multiple_func_settings(defaultenv):
"check multiple function settings are applied"

with run(env=defaultenv) as postgrest:
response = postgrest.session.post("/rpc/multiple_func_settings_test")

assert response.text == '[{"work_mem":"5000kB","statement_timeout":"10s"}]'
Loading