Skip to content

Commit

Permalink
Add within group clause support for aggregate function builder (#1024)
Browse files Browse the repository at this point in the history
Co-authored-by: Ivashkin Olexiy <[email protected]>
Co-authored-by: Dev K0te <[email protected]>
Co-authored-by: igalklebanov <[email protected]>
  • Loading branch information
3 people committed Jan 9, 2025
1 parent e31e221 commit a619971
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 8 deletions.
8 changes: 6 additions & 2 deletions src/operation-node/aggregate-function-node.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ export interface AggregateFunctionNode extends OperationNode {
readonly aggregated: readonly OperationNode[]
readonly distinct?: boolean
readonly orderBy?: OrderByNode
readonly withinGroup?: OrderByNode
readonly filter?: WhereNode
readonly over?: OverNode
}
Expand Down Expand Up @@ -46,11 +47,14 @@ export const AggregateFunctionNode = freeze({
cloneWithOrderBy(
aggregateFunctionNode: AggregateFunctionNode,
orderItems: ReadonlyArray<OrderByItemNode>,
withinGroup = false,
): AggregateFunctionNode {
const prop = withinGroup ? 'withinGroup' : 'orderBy'

return freeze({
...aggregateFunctionNode,
orderBy: aggregateFunctionNode.orderBy
? OrderByNode.cloneWithItems(aggregateFunctionNode.orderBy, orderItems)
[prop]: aggregateFunctionNode[prop]
? OrderByNode.cloneWithItems(aggregateFunctionNode[prop], orderItems)
: OrderByNode.create(orderItems),
})
},
Expand Down
3 changes: 2 additions & 1 deletion src/operation-node/operation-node-transformer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -904,11 +904,12 @@ export class OperationNodeTransformer {
): AggregateFunctionNode {
return requireAllProps({
kind: 'AggregateFunctionNode',
func: node.func,
aggregated: this.transformNodeList(node.aggregated),
distinct: node.distinct,
orderBy: this.transformNode(node.orderBy),
withinGroup: this.transformNode(node.withinGroup),
filter: this.transformNode(node.filter),
func: node.func,
over: this.transformNode(node.over),
})
}
Expand Down
47 changes: 44 additions & 3 deletions src/query-builder/aggregate-function-builder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import {
} from '../expression/expression.js'
import {
ReferenceExpression,
StringReference,
SimpleReferenceExpression,
} from '../parser/reference-parser.js'
import {
ComparisonOperatorExpression,
Expand All @@ -21,7 +21,6 @@ import {
} from '../parser/binary-operation-parser.js'
import { SqlBool } from '../util/type-utils.js'
import { ExpressionOrFactory } from '../parser/expression-parser.js'
import { DynamicReferenceBuilder } from '../dynamic/dynamic-reference-builder.js'
import {
OrderByDirectionExpression,
parseOrderBy,
Expand Down Expand Up @@ -125,7 +124,7 @@ export class AggregateFunctionBuilder<DB, TB extends keyof DB, O = unknown>
* inner join "pet" ON "pet"."owner_id" = "person"."id"
* ```
*/
orderBy<OE extends StringReference<DB, TB> | DynamicReferenceBuilder<any>>(
orderBy<OE extends SimpleReferenceExpression<DB, TB>>(
orderBy: OE,
direction?: OrderByDirectionExpression,
): AggregateFunctionBuilder<DB, TB, O> {
Expand All @@ -138,6 +137,48 @@ export class AggregateFunctionBuilder<DB, TB extends keyof DB, O = unknown>
})
}

/**
* Adds a `withing group` clause with a nested `order by` clause after the function.
*
* This is only supported by some dialects like PostgreSQL or MS SQL Server.
*
* ### Examples
*
* Most frequent person name:
*
* ```ts
* const result = await db
* .selectFrom('person')
* .select((eb) => [
* eb.fn
* .agg<string>('mode')
* .withinGroupOrderBy('person.first_name')
* .as('most_frequent_name')
* ])
* .executeTakeFirstOrThrow()
* ```
*
* The generated SQL (PostgreSQL):
*
* ```sql
* select mode() within group (order by "person"."first_name") as "most_frequent_name"
* from "person"
* ```
*/
withinGroupOrderBy<OE extends SimpleReferenceExpression<DB, TB>>(
orderBy: OE,
direction?: OrderByDirectionExpression,
): AggregateFunctionBuilder<DB, TB, O> {
return new AggregateFunctionBuilder({
...this.#props,
aggregateFunctionNode: AggregateFunctionNode.cloneWithOrderBy(
this.#props.aggregateFunctionNode,
parseOrderBy([orderBy, direction]),
true,
),
})
}

/**
* Adds a `filter` clause with a nested `where` clause after the function.
*
Expand Down
6 changes: 6 additions & 0 deletions src/query-compiler/default-query-compiler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1418,6 +1418,12 @@ export class DefaultQueryCompiler

this.append(')')

if (node.withinGroup) {
this.append(' within group (')
this.visitNode(node.withinGroup)
this.append(')')
}

if (node.filter) {
this.append(' filter(')
this.visitNode(node.filter)
Expand Down
42 changes: 40 additions & 2 deletions test/node/src/aggregate-function.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
SimpleReferenceExpression,
ReferenceExpression,
sql,
expressionBuilder,
} from '../../../'
import {
Database,
Expand Down Expand Up @@ -1108,8 +1109,12 @@ for (const dialect of DIALECTS) {
await query.execute()
})

describe(`should execute order-sensitive aggregate functions`, () => {
if (dialect === 'postgres' || dialect === 'mysql' || dialect === 'sqlite') {
describe('should execute order-sensitive aggregate functions', () => {
if (
dialect === 'postgres' ||
dialect === 'mysql' ||
dialect === 'sqlite'
) {
const isMySql = dialect === 'mysql'
const funcName = isMySql ? 'group_concat' : 'string_agg'
const funcArgs: Array<ReferenceExpression<Database, 'person'>> = [
Expand Down Expand Up @@ -1157,6 +1162,39 @@ for (const dialect of DIALECTS) {
await query.execute()
})
}

if (dialect === 'postgres' || dialect === 'mssql') {
it(`should execute a query with within group (order by column) in select clause`, async () => {
const query = ctx.db.selectFrom('toy').select((eb) =>
eb.fn
.agg('percentile_cont', [sql.lit(0.5)])
.withinGroupOrderBy('toy.price')
.$call((ab) => (dialect === 'mssql' ? ab.over() : ab))
.as('median_price'),
)

testSql(query, dialect, {
postgres: {
sql: [
`select percentile_cont(0.5) within group (order by "toy"."price") as "median_price"`,
`from "toy"`,
],
parameters: [],
},
mysql: NOT_SUPPORTED,
mssql: {
sql: [
`select percentile_cont(0.5) within group (order by "toy"."price") over() as "median_price"`,
`from "toy"`,
],
parameters: [],
},
sqlite: NOT_SUPPORTED,
})

await query.execute()
})
}
})
})
}
Expand Down

0 comments on commit a619971

Please sign in to comment.