Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
laurenceisla committed Dec 31, 2024
1 parent 4881ae4 commit cd39da3
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 21 deletions.
13 changes: 6 additions & 7 deletions src/PostgREST/Query/QueryBuilder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ getJoinSelects (Node ReadPlan{relSelect} _) =
JsonEmbed{rsSelName, rsEmbedMode = JsonArray} ->
Just $ "COALESCE( " <> aggAlias <> "." <> aggAlias <> ", '[]') AS " <> pgFmtIdent rsSelName
Spread{rsSpreadSel, rsAggAlias} ->
Just $ intercalateSnippet ", " (pgFmtSpreadSelectItem False rsAggAlias mempty <$> rsSpreadSel)
Just $ intercalateSnippet ", " (pgFmtSpreadSelectItem rsAggAlias <$> rsSpreadSel)

getJoins :: ReadPlanTree -> [SQL.Snippet]
getJoins (Node _ []) = []
Expand All @@ -99,22 +99,21 @@ getJoin fld node@(Node ReadPlan{order, relJoinType, relSpread} _) =
(if relJoinType == Just JTInner then "INNER" else "LEFT") <> " JOIN LATERAL ( " <> sub <> " ) AS " <> al <> " ON " <> cond
subquery = readPlanToQuery node
aggAlias = pgFmtIdent $ rsAggAlias fld
selectJsonArray = "SELECT json_agg(" <> aggAlias <> ")::jsonb AS " <> aggAlias
wrapSubqAlias = " FROM (" <> subquery <> " ) AS " <> aggAlias
selectSubqAgg = "SELECT json_agg(" <> aggAlias <> ")::jsonb AS " <> aggAlias
fromSubqAgg = " FROM (" <> subquery <> " ) AS " <> aggAlias
joinCondition = if relJoinType == Just JTInner then aggAlias <> " IS NOT NULL" else "TRUE"
in
case fld of
JsonEmbed{rsEmbedMode = JsonObject} ->
correlatedSubquery subquery aggAlias "TRUE"
Spread{rsSpreadSel, rsAggAlias} ->
if relSpread == Just ToManySpread then
let
selection = selectJsonArray <> (if null rsSpreadSel then mempty else ", ") <> intercalateSnippet ", " (pgFmtSpreadSelectItem True rsAggAlias order <$> rsSpreadSel)
in correlatedSubquery (selection <> wrapSubqAlias) aggAlias joinCondition
let selSpread = selectSubqAgg <> (if null rsSpreadSel then mempty else ", ") <> intercalateSnippet ", " (pgFmtSpreadJoinSelectItem rsAggAlias order <$> rsSpreadSel)
in correlatedSubquery (selSpread <> fromSubqAgg) aggAlias joinCondition
else
correlatedSubquery subquery aggAlias "TRUE"
JsonEmbed{rsEmbedMode = JsonArray} ->
correlatedSubquery (selectJsonArray <> wrapSubqAlias) aggAlias joinCondition
correlatedSubquery (selectSubqAgg <> fromSubqAgg) aggAlias joinCondition

mutatePlanToQuery :: MutatePlan -> SQL.Snippet
mutatePlanToQuery (Insert mainQi iCols body onConflict putConditions returnings _ applyDefaults) =
Expand Down
29 changes: 16 additions & 13 deletions src/PostgREST/Query/SqlFragment.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ module PostgREST.Query.SqlFragment
, pgFmtOrderTerm
, pgFmtSelectItem
, pgFmtSpreadSelectItem
, pgFmtSpreadJoinSelectItem
, fromJsonBodyF
, responseHeadersF
, responseStatusF
Expand Down Expand Up @@ -271,14 +272,9 @@ pgFmtSelectItem :: QualifiedIdentifier -> CoercibleSelectField -> SQL.Snippet
pgFmtSelectItem table CoercibleSelectField{csField=fld, csAggFunction=agg, csAggCast=aggCast, csCast=cast, csAlias=alias} =
pgFmtApplyAggregate agg aggCast (pgFmtApplyCast cast (pgFmtTableCoerce table fld)) <> pgFmtAs alias

pgFmtSpreadSelectItem :: Bool -> Alias -> [CoercibleOrderTerm] -> SpreadSelectField -> SQL.Snippet
pgFmtSpreadSelectItem applyToManySpr aggAlias order SpreadSelectField{ssSelName, ssSelAggFunction, ssSelAggCast, ssSelAlias}
| applyToManySpr = pgFmtApplyToManySpreadAgg ssSelAggFunction ssSelAggCast aggAlias order fullSelName <> " AS " <> pgFmtIdent (fromMaybe ssSelName ssSelAlias)
| otherwise = pgFmtApplyAggregate ssSelAggFunction ssSelAggCast fullSelName <> pgFmtAs ssSelAlias
where
fullSelName = case ssSelName of
"*" -> pgFmtIdent aggAlias <> ".*"
_ -> pgFmtIdent aggAlias <> "." <> pgFmtIdent ssSelName
pgFmtSpreadSelectItem :: Alias -> SpreadSelectField -> SQL.Snippet
pgFmtSpreadSelectItem aggAlias SpreadSelectField{ssSelName, ssSelAggFunction, ssSelAggCast, ssSelAlias} =
pgFmtApplyAggregate ssSelAggFunction ssSelAggCast (pgFmtFullSelName aggAlias ssSelName) <> pgFmtAs ssSelAlias

pgFmtApplyAggregate :: Maybe AggregateFunction -> Maybe Cast -> SQL.Snippet -> SQL.Snippet
pgFmtApplyAggregate Nothing _ snippet = snippet
Expand All @@ -290,11 +286,13 @@ pgFmtApplyAggregate (Just agg) aggCast snippet =
convertAggFunction = SQL.sql . BS.map toUpper . BS.pack . show
aggregatedSnippet = convertAggFunction agg <> "(" <> snippet <> ")"

pgFmtApplyToManySpreadAgg :: Maybe AggregateFunction -> Maybe Cast -> Alias -> [CoercibleOrderTerm] -> SQL.Snippet -> SQL.Snippet
pgFmtApplyToManySpreadAgg Nothing aggCast relAggAlias order snippet =
"COALESCE(json_agg(" <> pgFmtApplyCast aggCast snippet <> orderF (QualifiedIdentifier "" relAggAlias) order <> "),'[]')::jsonb"
pgFmtApplyToManySpreadAgg agg aggCast _ _ snippet =
pgFmtApplyAggregate agg aggCast snippet
pgFmtSpreadJoinSelectItem :: Alias -> [CoercibleOrderTerm] -> SpreadSelectField -> SQL.Snippet
pgFmtSpreadJoinSelectItem aggAlias order SpreadSelectField{ssSelName, ssSelAlias} =
"COALESCE(json_agg(" <> fmtField <> " " <> fmtOrder <> "),'[]')::jsonb" <> " AS " <> fmtAlias
where
fmtField = pgFmtFullSelName aggAlias ssSelName
fmtOrder = orderF (QualifiedIdentifier "" aggAlias) order
fmtAlias = pgFmtIdent (fromMaybe ssSelName ssSelAlias)

pgFmtApplyCast :: Maybe Cast -> SQL.Snippet -> SQL.Snippet
pgFmtApplyCast Nothing snippet = snippet
Expand All @@ -303,6 +301,11 @@ pgFmtApplyCast Nothing snippet = snippet
-- Not quoting should be fine, we validate the input on Parsers.
pgFmtApplyCast (Just cast) snippet = "CAST( " <> snippet <> " AS " <> SQL.sql (encodeUtf8 cast) <> " )"

pgFmtFullSelName :: Alias -> FieldName -> SQL.Snippet
pgFmtFullSelName aggAlias fieldName = case fieldName of
"*" -> pgFmtIdent aggAlias <> ".*"
_ -> pgFmtIdent aggAlias <> "." <> pgFmtIdent fieldName

-- TODO: At this stage there shouldn't be a Maybe since ApiRequest should ensure that an INSERT/UPDATE has a body
fromJsonBodyF :: Maybe LBS.ByteString -> [CoercibleField] -> Bool -> Bool -> Bool -> SQL.Snippet
fromJsonBodyF body fields includeSelect includeLimitOne includeDefaults =
Expand Down
13 changes: 12 additions & 1 deletion test/spec/Feature/Query/AggregateFunctionsSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ disallowed =
{ matchStatus = 400
, matchHeaders = [matchContentTypeJson] }

it "prevents the use of aggregates on spread embeds" $
it "prevents the use of aggregates on to-one spread embeds" $
get "/project_invoices?select=...projects(id.count())" `shouldRespondWith`
[json|{
"hint":null,
Expand All @@ -522,3 +522,14 @@ disallowed =
}|]
{ matchStatus = 400
, matchHeaders = [matchContentTypeJson] }

it "prevents the use of aggregates on to-many spread embeds" $
get "/factories?select=...processes(id.count())" `shouldRespondWith`
[json|{
"hint":null,
"details":null,
"code":"PGRST123",
"message":"Use of aggregate functions is not allowed"
}|]
{ matchStatus = 400
, matchHeaders = [matchContentTypeJson] }

0 comments on commit cd39da3

Please sign in to comment.