Skip to content

Commit

Permalink
Merge pull request #410 from go-jet/row-exp
Browse files Browse the repository at this point in the history
Add support for ROW expressions and VALUES statement
  • Loading branch information
go-jet authored Oct 17, 2024
2 parents 58a386a + 8d112f7 commit 369c657
Show file tree
Hide file tree
Showing 53 changed files with 1,532 additions and 176 deletions.
8 changes: 8 additions & 0 deletions internal/jet/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type Dialect interface {
ArgumentPlaceholder() QueryPlaceholderFunc
IsReservedWord(name string) bool
SerializeOrderBy() func(expression Expression, ascending, nullsFirst *bool) SerializerFunc
ValuesDefaultColumnName(index int) string
}

// SerializerFunc func
Expand All @@ -35,6 +36,7 @@ type DialectParams struct {
ArgumentPlaceholder QueryPlaceholderFunc
ReservedWords []string
SerializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc
ValuesDefaultColumnName func(index int) string
}

// NewDialect creates new dialect with params
Expand All @@ -49,6 +51,7 @@ func NewDialect(params DialectParams) Dialect {
argumentPlaceholder: params.ArgumentPlaceholder,
reservedWords: arrayOfStringsToMapOfStrings(params.ReservedWords),
serializeOrderBy: params.SerializeOrderBy,
valuesDefaultColumnName: params.ValuesDefaultColumnName,
}
}

Expand All @@ -62,6 +65,7 @@ type dialectImpl struct {
argumentPlaceholder QueryPlaceholderFunc
reservedWords map[string]bool
serializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc
valuesDefaultColumnName func(index int) string
}

func (d *dialectImpl) Name() string {
Expand Down Expand Up @@ -107,6 +111,10 @@ func (d *dialectImpl) SerializeOrderBy() func(expression Expression, ascending,
return d.serializeOrderBy
}

func (d *dialectImpl) ValuesDefaultColumnName(index int) string {
return d.valuesDefaultColumnName(index)
}

func arrayOfStringsToMapOfStrings(arr []string) map[string]bool {
ret := map[string]bool{}
for _, elem := range arr {
Expand Down
17 changes: 4 additions & 13 deletions internal/jet/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ func (e *ExpressionInterfaceImpl) IS_NOT_NULL() BoolExpression {

// IN checks if this expressions matches any in expressions list
func (e *ExpressionInterfaceImpl) IN(expressions ...Expression) BoolExpression {
return newBinaryBoolOperatorExpression(e.Parent, WRAP(expressions...), "IN")
return newBinaryBoolOperatorExpression(e.Parent, wrap(expressions...), "IN")
}

// NOT_IN checks if this expressions is different of all expressions in expressions list
func (e *ExpressionInterfaceImpl) NOT_IN(expressions ...Expression) BoolExpression {
return newBinaryBoolOperatorExpression(e.Parent, WRAP(expressions...), "NOT IN")
return newBinaryBoolOperatorExpression(e.Parent, wrap(expressions...), "NOT IN")
}

// AS the temporary alias name to assign to the expression
Expand Down Expand Up @@ -316,15 +316,6 @@ func (s *complexExpression) serialize(statement StatementType, out *SQLBuilder,
}
}

type skipParenthesisWrap struct {
Expression
}

func skipWrap(expression Expression) Expression {
return &skipParenthesisWrap{expression}
}

// since the expression is a function parameter, there is no need to wrap it in parentheses
func (s *skipParenthesisWrap) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
s.Expression.serialize(statement, out, append(options, NoWrap)...)
func wrap(expressions ...Expression) Expression {
return NewFunc("", expressions, nil)
}
7 changes: 1 addition & 6 deletions internal/jet/func_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,6 @@ func OR(expressions ...BoolExpression) BoolExpression {
return newBoolExpressionListOperator("OR", expressions...)
}

// ROW function is used to create a tuple value that consists of a set of expressions or column values.
func ROW(expressions ...Expression) Expression {
return NewFunc("ROW", expressions, nil)
}

// ------------------ Mathematical functions ---------------//

// ABSf calculates absolute value from float expression
Expand Down Expand Up @@ -711,7 +706,7 @@ func (p parametersSerializer) serialize(statement StatementType, out *SQLBuilder
if _, isStatement := expression.(Statement); isStatement {
expression.serialize(statement, out, options...)
} else {
skipWrap(expression).serialize(statement, out, options...)
expression.serialize(statement, out, append(options, NoWrap, Ident)...)
}
}
}
Expand Down
26 changes: 0 additions & 26 deletions internal/jet/literal_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -374,32 +374,6 @@ func (n *starLiteral) serialize(statement StatementType, out *SQLBuilder, option

//---------------------------------------------------//

type wrap struct {
ExpressionInterfaceImpl
expressions []Expression
}

func (n *wrap) serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("(")

if len(n.expressions) == 1 {
options = append(options, NoWrap, Ident)
}
serializeExpressionList(statementType, n.expressions, ", ", out, options...)

out.WriteString(")")
}

// WRAP wraps list of expressions with brackets - ( expression1, expression2, ... )
func WRAP(expression ...Expression) Expression {
wrap := &wrap{expressions: expression}
wrap.ExpressionInterfaceImpl.Parent = wrap

return wrap
}

//---------------------------------------------------//

type rawExpression struct {
ExpressionInterfaceImpl

Expand Down
7 changes: 6 additions & 1 deletion internal/jet/order_set_aggregate_functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,12 @@ type orderSetAggregateFuncExpression struct {

func (p *orderSetAggregateFuncExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString(p.name)
WRAP(p.fraction).serialize(statement, out, FallTrough(options)...)

if p.fraction != nil {
wrap(p.fraction).serialize(statement, out, FallTrough(options)...)
} else {
wrap().serialize(statement, out, FallTrough(options)...)
}
out.WriteString("WITHIN GROUP")
p.orderBy.serialize(statement, out)
}
2 changes: 1 addition & 1 deletion internal/jet/projection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ AVG(table1.col_int) AS "avg",
table2.col3 AS "col3",
table2.col4 AS "col4"`)

subQueryProjections := projectionList.fromImpl(NewSelectTable(nil, "subQuery"))
subQueryProjections := projectionList.fromImpl(NewSelectTable(nil, "subQuery", nil))

assertProjectionSerialize(t, subQueryProjections,
`"subQuery"."table1.col3" AS "table1.col3",
Expand Down
2 changes: 1 addition & 1 deletion internal/jet/raw_statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ type rawStatementImpl struct {
}

// RawStatement creates new sql statements from raw query and optional map of named arguments
func RawStatement(dialect Dialect, rawQuery string, namedArgument ...map[string]interface{}) Statement {
func RawStatement(dialect Dialect, rawQuery string, namedArgument ...map[string]interface{}) SerializerStatement {
newRawStatement := rawStatementImpl{
serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{
dialect: dialect,
Expand Down
102 changes: 102 additions & 0 deletions internal/jet/row_expression.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package jet

// RowExpression interface
type RowExpression interface {
Expression
HasProjections

EQ(rhs RowExpression) BoolExpression
NOT_EQ(rhs RowExpression) BoolExpression
IS_DISTINCT_FROM(rhs RowExpression) BoolExpression
IS_NOT_DISTINCT_FROM(rhs RowExpression) BoolExpression

LT(rhs RowExpression) BoolExpression
LT_EQ(rhs RowExpression) BoolExpression
GT(rhs RowExpression) BoolExpression
GT_EQ(rhs RowExpression) BoolExpression
}

type rowInterfaceImpl struct {
parent Expression
dialect Dialect
elemCount int
}

func (n *rowInterfaceImpl) EQ(rhs RowExpression) BoolExpression {
return Eq(n.parent, rhs)
}

func (n *rowInterfaceImpl) NOT_EQ(rhs RowExpression) BoolExpression {
return NotEq(n.parent, rhs)
}

func (n *rowInterfaceImpl) IS_DISTINCT_FROM(rhs RowExpression) BoolExpression {
return IsDistinctFrom(n.parent, rhs)
}

func (n *rowInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs RowExpression) BoolExpression {
return IsNotDistinctFrom(n.parent, rhs)
}

func (n *rowInterfaceImpl) GT(rhs RowExpression) BoolExpression {
return Gt(n.parent, rhs)
}

func (n *rowInterfaceImpl) GT_EQ(rhs RowExpression) BoolExpression {
return GtEq(n.parent, rhs)
}

func (n *rowInterfaceImpl) LT(rhs RowExpression) BoolExpression {
return Lt(n.parent, rhs)
}

func (n *rowInterfaceImpl) LT_EQ(rhs RowExpression) BoolExpression {
return LtEq(n.parent, rhs)
}

func (n *rowInterfaceImpl) projections() ProjectionList {
var ret ProjectionList

for i := 0; i < n.elemCount; i++ {
rowColumn := NewColumnImpl(n.dialect.ValuesDefaultColumnName(i), "", nil)
ret = append(ret, &rowColumn)
}

return ret
}

// ---------------------------------------------------//
type rowExpressionWrapper struct {
rowInterfaceImpl
Expression
}

func newRowExpression(name string, dialect Dialect, expressions ...Expression) RowExpression {
ret := &rowExpressionWrapper{}
ret.rowInterfaceImpl.parent = ret

ret.Expression = NewFunc(name, expressions, ret)
ret.dialect = dialect
ret.elemCount = len(expressions)

return ret
}

// ROW function is used to create a tuple value that consists of a set of expressions or column values.
func ROW(dialect Dialect, expressions ...Expression) RowExpression {
return newRowExpression("ROW", dialect, expressions...)
}

// WRAP creates row expressions without ROW keyword `( expression1, expression2, ... )`.
func WRAP(dialect Dialect, expressions ...Expression) RowExpression {
return newRowExpression("", dialect, expressions...)
}

// RowExp serves as a wrapper for an arbitrary expression, treating it as a row expression.
// This enables the Go compiler to interpret any expression as a row expression
// Note: This does not modify the generated SQL builder output by adding a SQL CAST operation.
func RowExp(expression Expression) RowExpression {
rowExpressionWrap := rowExpressionWrapper{Expression: expression}
rowExpressionWrap.rowInterfaceImpl.parent = &rowExpressionWrap
return &rowExpressionWrap
}
28 changes: 22 additions & 6 deletions internal/jet/select_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,21 @@ type SelectTable interface {
}

type selectTableImpl struct {
Statement SerializerHasProjections
alias string
Statement SerializerHasProjections
alias string
columnAliases []ColumnExpression
}

// NewSelectTable func
func NewSelectTable(selectStmt SerializerHasProjections, alias string) selectTableImpl {
func NewSelectTable(selectStmt SerializerHasProjections, alias string, columnAliases []ColumnExpression) selectTableImpl {
selectTable := selectTableImpl{
Statement: selectStmt,
alias: alias,
Statement: selectStmt,
alias: alias,
columnAliases: columnAliases,
}

for _, column := range selectTable.columnAliases {
column.setSubQuery(selectTable)
}

return selectTable
Expand All @@ -31,6 +37,10 @@ func (s selectTableImpl) Alias() string {
}

func (s selectTableImpl) AllColumns() ProjectionList {
if len(s.columnAliases) > 0 {
return ColumnListToProjectionList(s.columnAliases)
}

projectionList := s.projections().fromImpl(s)
return projectionList.(ProjectionList)
}
Expand All @@ -40,6 +50,12 @@ func (s selectTableImpl) serialize(statement StatementType, out *SQLBuilder, opt

out.WriteString("AS")
out.WriteIdentifier(s.alias)

if len(s.columnAliases) > 0 {
out.WriteByte('(')
SerializeColumnExpressionNames(s.columnAliases, out)
out.WriteByte(')')
}
}

// --------------------------------------
Expand All @@ -50,7 +66,7 @@ type lateralImpl struct {

// NewLateral creates new lateral expression from select statement with alias
func NewLateral(selectStmt SerializerStatement, alias string) SelectTable {
return lateralImpl{selectTableImpl: NewSelectTable(selectStmt, alias)}
return lateralImpl{selectTableImpl: NewSelectTable(selectStmt, alias, nil)}
}

func (s lateralImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
Expand Down
35 changes: 35 additions & 0 deletions internal/jet/values.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package jet

// Values hold a set of one or more rows
type Values []RowExpression

func (v Values) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteByte('(')
out.IncreaseIdent(5)

out.NewLine()
out.WriteString("VALUES")

for rowIndex, row := range v {
if rowIndex > 0 {
out.WriteString(",")
out.NewLine()
} else {
out.IncreaseIdent(7)
}

row.serialize(statement, out, options...)
}
out.DecreaseIdent(7)
out.DecreaseIdent(5)
out.NewLine()
out.WriteByte(')')
}

func (v Values) projections() ProjectionList {
if len(v) == 0 {
return nil
}

return v[0].projections()
}
11 changes: 1 addition & 10 deletions internal/jet/with_statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ type CommonTableExpression struct {
// CTE creates new named CommonTableExpression
func CTE(name string, columns ...ColumnExpression) CommonTableExpression {
cte := CommonTableExpression{
selectTableImpl: NewSelectTable(nil, name),
selectTableImpl: NewSelectTable(nil, name, columns),
Columns: columns,
}

Expand Down Expand Up @@ -99,12 +99,3 @@ func (c CommonTableExpression) serialize(statement StatementType, out *SQLBuilde
out.WriteIdentifier(c.alias)
}
}

// AllColumns returns list of all projections in the CTE
func (c CommonTableExpression) AllColumns() ProjectionList {
if len(c.Columns) > 0 {
return ColumnListToProjectionList(c.Columns)
}

return c.selectTableImpl.AllColumns()
}
Loading

0 comments on commit 369c657

Please sign in to comment.