diff --git a/internal/jet/column_assigment.go b/internal/jet/column_assigment.go index 440f3eb8..c8884331 100644 --- a/internal/jet/column_assigment.go +++ b/internal/jet/column_assigment.go @@ -11,6 +11,13 @@ type columnAssigmentImpl struct { expression Expression } +func NewColumnAssignment(serializer ColumnSerializer, expression Expression) ColumnAssigment { + return &columnAssigmentImpl{ + column: serializer, + expression: expression, + } +} + func (a columnAssigmentImpl) isColumnAssigment() {} func (a columnAssigmentImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { diff --git a/postgres/columns.go b/postgres/columns.go index 819da380..a70c234b 100644 --- a/postgres/columns.go +++ b/postgres/columns.go @@ -109,13 +109,20 @@ type ColumnInterval interface { jet.Column From(subQuery SelectTable) ColumnInterval + SET(intervalExp IntervalExpression) ColumnAssigment } +//------------------------------------------------------// + type intervalColumnImpl struct { jet.ColumnExpressionImpl intervalInterfaceImpl } +func (i *intervalColumnImpl) SET(intervalExp IntervalExpression) ColumnAssigment { + return jet.NewColumnAssignment(i, intervalExp) +} + func (i *intervalColumnImpl) From(subQuery SelectTable) ColumnInterval { newIntervalColumn := IntervalColumn(i.Name()) jet.SetTableName(newIntervalColumn, i.TableName()) diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 2f1be14d..9f45bd11 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -3,6 +3,7 @@ package postgres import ( "database/sql" "github.com/go-jet/jet/v2/internal/utils/ptr" + "github.com/stretchr/testify/assert" "testing" "time" @@ -931,6 +932,43 @@ func TestTimeExpression(t *testing.T) { require.NoError(t, err) } +func TestIntervalUpsert(t *testing.T) { + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + stmt := SELECT(Employee.AllColumns).FROM(Employee). + WHERE(Employee.EmployeeID.EQ(Int(1))) + + //Validate initial dataset + var windy model.Employee + err := stmt.Query(db, &windy) + assert.Equal(t, windy.EmployeeID, int32(1)) + assert.Equal(t, windy.FirstName, "Windy") + assert.Equal(t, windy.LastName, "Hays") + assert.Equal(t, *windy.PtoAccrual, "22:00:00") + assert.Nil(t, err) + windy.PtoAccrual = ptr.Of("3h") + //Update data + updateStmt := Employee.UPDATE(Employee.PtoAccrual).SET( + Employee.PtoAccrual.SET(INTERVAL(3, HOUR)), + ).WHERE(Employee.EmployeeID.EQ(Int(1))).RETURNING(Employee.AllColumns) + + err = updateStmt.Query(db, &windy) + err = stmt.Query(db, &windy) + assert.Nil(t, err) + assert.Equal(t, *windy.PtoAccrual, "03:00:00") + //Upsert dataset with a different value + windy.PtoAccrual = ptr.Of("5h") + insertStmt := Employee.INSERT(Employee.AllColumns). + MODEL(&windy). + ON_CONFLICT(Employee.EmployeeID). + DO_UPDATE(SET( + Employee.PtoAccrual.SET(Employee.EXCLUDED.PtoAccrual), + )).RETURNING(Employee.AllColumns) + err = insertStmt.Query(db, &windy) + assert.Nil(t, err) + assert.Equal(t, *windy.PtoAccrual, "05:00:00") + }) +} + func TestInterval(t *testing.T) { skipForCockroachDB(t) diff --git a/tests/postgres/insert_test.go b/tests/postgres/insert_test.go index a0790918..7edefd47 100644 --- a/tests/postgres/insert_test.go +++ b/tests/postgres/insert_test.go @@ -93,9 +93,9 @@ func TestInsertOnConflict(t *testing.T) { ON_CONFLICT().DO_NOTHING() testutils.AssertStatementSql(t, stmt, ` -INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id) -VALUES ($1, $2, $3, $4, $5), - ($6, $7, $8, $9, $10) +INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual) +VALUES ($1, $2, $3, $4, $5, $6), + ($7, $8, $9, $10, $11, $12) ON CONFLICT DO NOTHING; `) testutils.AssertExecAndRollback(t, stmt, db, 1) @@ -111,9 +111,9 @@ ON CONFLICT DO NOTHING; ON_CONFLICT(Employee.EmployeeID).DO_NOTHING() testutils.AssertStatementSql(t, stmt, ` -INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id) -VALUES ($1, $2, $3, $4, $5), - ($6, $7, $8, $9, $10) +INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual) +VALUES ($1, $2, $3, $4, $5, $6), + ($7, $8, $9, $10, $11, $12) ON CONFLICT (employee_id) DO NOTHING; `) testutils.AssertExecAndRollback(t, stmt, db, 1) @@ -130,9 +130,9 @@ ON CONFLICT (employee_id) DO NOTHING; ON_CONFLICT().ON_CONSTRAINT("employee_pkey").DO_NOTHING() testutils.AssertStatementSql(t, stmt, ` -INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id) -VALUES ($1, $2, $3, $4, $5), - ($6, $7, $8, $9, $10) +INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual) +VALUES ($1, $2, $3, $4, $5, $6), + ($7, $8, $9, $10, $11, $12) ON CONFLICT ON CONSTRAINT employee_pkey DO NOTHING; `) testutils.AssertExecAndRollback(t, stmt, db, 1) @@ -234,8 +234,8 @@ ON CONFLICT (id) WHERE (id * 2) > 10 DO UPDATE ON_CONFLICT().DO_UPDATE(nil) testutils.AssertStatementSql(t, stmt, ` -INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id) -VALUES ($1, $2, $3, $4, $5); +INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual) +VALUES ($1, $2, $3, $4, $5, $6); `) testutils.AssertExecAndRollback(t, stmt, db, 1) requireLogged(t, stmt) diff --git a/tests/postgres/sample_test.go b/tests/postgres/sample_test.go index a1d4c2dd..f2526317 100644 --- a/tests/postgres/sample_test.go +++ b/tests/postgres/sample_test.go @@ -331,11 +331,13 @@ SELECT employee.employee_id AS "employee.employee_id", employee.last_name AS "employee.last_name", employee.employment_date AS "employee.employment_date", employee.manager_id AS "employee.manager_id", + employee.pto_accrual AS "employee.pto_accrual", manager.employee_id AS "manager.employee_id", manager.first_name AS "manager.first_name", manager.last_name AS "manager.last_name", manager.employment_date AS "manager.employment_date", - manager.manager_id AS "manager.manager_id" + manager.manager_id AS "manager.manager_id", + manager.pto_accrual AS "manager.pto_accrual" FROM test_sample.employee LEFT JOIN test_sample.employee AS manager ON (manager.employee_id = employee.manager_id) ORDER BY employee.employee_id; @@ -370,6 +372,7 @@ ORDER BY employee.employee_id; LastName: "Hays", EmploymentDate: testutils.TimestampWithTimeZone("1999-01-08 04:05:06.1 +0100 CET", 1), ManagerID: nil, + PtoAccrual: ptr.Of("22:00:00"), }) require.True(t, dest[0].Manager == nil) diff --git a/tests/testdata b/tests/testdata index 1e9247e3..6a397747 160000 --- a/tests/testdata +++ b/tests/testdata @@ -1 +1 @@ -Subproject commit 1e9247e333babd5172cf162e38518d993f5f3df4 +Subproject commit 6a397747d310938b41d3950d68009578180d3dd5