diff --git a/pkg/sql2pgroll/alter_table.go b/pkg/sql2pgroll/alter_table.go index c124d829..31c26c30 100644 --- a/pkg/sql2pgroll/alter_table.go +++ b/pkg/sql2pgroll/alter_table.go @@ -46,6 +46,8 @@ func convertAlterTableStmt(stmt *pgq.AlterTableStmt) (migrations.Operations, err op, err = convertAlterTableSetColumnDefault(stmt, alterTableCmd) case pgq.AlterTableType_AT_DropConstraint: op, err = convertAlterTableDropConstraint(stmt, alterTableCmd) + case pgq.AlterTableType_AT_AddColumn: + op, err = convertAlterTableAddColumn(stmt, alterTableCmd) } if err != nil { @@ -198,20 +200,9 @@ func convertAlterTableAddForeignKeyConstraint(stmt *pgq.AlterTableStmt, constrai migs[column] = PlaceHolderSQL } - var onDelete migrations.ForeignKeyReferenceOnDelete - switch constraint.GetFkDelAction() { - case "a": - onDelete = migrations.ForeignKeyReferenceOnDeleteNOACTION - case "c": - onDelete = migrations.ForeignKeyReferenceOnDeleteCASCADE - case "r": - onDelete = migrations.ForeignKeyReferenceOnDeleteRESTRICT - case "d": - onDelete = migrations.ForeignKeyReferenceOnDeleteSETDEFAULT - case "n": - onDelete = migrations.ForeignKeyReferenceOnDeleteSETNULL - default: - return nil, fmt.Errorf("unknown delete action: %q", constraint.GetFkDelAction()) + onDelete, err := parseOnDeleteAction(constraint.GetFkDelAction()) + if err != nil { + return nil, fmt.Errorf("failed to parse on delete action: %w", err) } tableName := getQualifiedRelationName(stmt.Relation) @@ -232,6 +223,23 @@ func convertAlterTableAddForeignKeyConstraint(stmt *pgq.AlterTableStmt, constrai }, nil } +func parseOnDeleteAction(action string) (migrations.ForeignKeyReferenceOnDelete, error) { + switch action { + case "a": + return migrations.ForeignKeyReferenceOnDeleteNOACTION, nil + case "c": + return migrations.ForeignKeyReferenceOnDeleteCASCADE, nil + case "r": + return migrations.ForeignKeyReferenceOnDeleteRESTRICT, nil + case "d": + return migrations.ForeignKeyReferenceOnDeleteSETDEFAULT, nil + case "n": + return migrations.ForeignKeyReferenceOnDeleteSETNULL, nil + default: + return migrations.ForeignKeyReferenceOnDeleteNOACTION, fmt.Errorf("unknown delete action: %q", action) + } +} + func canConvertAlterTableAddForeignKeyConstraint(constraint *pgq.Constraint) bool { if constraint.SkipValidation { return false @@ -319,21 +327,12 @@ func convertAlterTableSetColumnDefault(stmt *pgq.AlterTableStmt, cmd *pgq.AlterT Up: PlaceHolderSQL, } - if c := cmd.GetDef().GetAConst(); c != nil { - if c.GetIsnull() { - // The default can be set to null - operation.Default = nullable.NewNullNullable[string]() - return operation, nil - } + def, err := extractDefault(cmd.GetDef()) + if err != nil { + return nil, err } - - // We're setting it to an expression - if cmd.GetDef() != nil { - def, err := pgq.DeparseExpr(cmd.GetDef()) - if err != nil { - return nil, fmt.Errorf("failed to deparse expression: %w", err) - } - operation.Default = nullable.NewNullableWithValue(def) + if def.IsSpecified() { + operation.Default = def return operation, nil } @@ -347,7 +346,25 @@ func convertAlterTableSetColumnDefault(stmt *pgq.AlterTableStmt, cmd *pgq.AlterT return nil, nil } -// convertAlterTableDropConstraint convert DROP CONSTRAINT SQL into an OpDropMultiColumnConstraint. +func extractDefault(node *pgq.Node) (nullable.Nullable[string], error) { + if c := node.GetAConst(); c != nil && c.GetIsnull() { + // The default can be set to null + return nullable.NewNullNullable[string](), nil + } + + // It's an expression + if node != nil { + def, err := pgq.DeparseExpr(node) + if err != nil { + return nil, fmt.Errorf("failed to deparse expression: %w", err) + } + return nullable.NewNullableWithValue(def), nil + } + + return nil, nil +} + +// convertAlterTableDropConstraint converts DROP CONSTRAINT SQL into an OpDropMultiColumnConstraint. // Because we are unable to infer the columns involved, placeholder migrations are used. // // SQL statements like the following are supported: @@ -380,6 +397,115 @@ func canConvertDropConstraint(cmd *pgq.AlterTableCmd) bool { return cmd.Behavior != pgq.DropBehavior_DROP_CASCADE } +// convertAlterTableAddColumn converts ADD COLUMN SQL into an OpAddColumn. +// +// See TestConvertAlterTableStatements and TestUnconvertableAlterTableStatements for statements we +// support. +func convertAlterTableAddColumn(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) (migrations.Operation, error) { + if !canConvertAddColumn(cmd) { + return nil, nil + } + + columnDef := cmd.GetDef().GetColumnDef() + if !canConvertColumnDef(columnDef) { + return nil, nil + } + + columnType, err := pgq.DeparseTypeName(columnDef.GetTypeName()) + if err != nil { + return nil, fmt.Errorf("failed to deparse type name: %w", err) + } + + operation := &migrations.OpAddColumn{ + Column: migrations.Column{ + Name: columnDef.GetColname(), + Type: columnType, + }, + Table: getQualifiedRelationName(stmt.GetRelation()), + Up: PlaceHolderSQL, + } + + if len(columnDef.GetConstraints()) > 0 { + for _, constraint := range columnDef.GetConstraints() { + switch constraint.GetConstraint().GetContype() { + case pgq.ConstrType_CONSTR_NULL: + operation.Column.Nullable = true + case pgq.ConstrType_CONSTR_PRIMARY: + operation.Column.Pk = true + case pgq.ConstrType_CONSTR_UNIQUE: + operation.Column.Unique = true + case pgq.ConstrType_CONSTR_CHECK: + raw, err := pgq.DeparseExpr(constraint.GetConstraint().GetRawExpr()) + if err != nil { + return nil, fmt.Errorf("failed to deparse raw expression: %w", err) + } + operation.Column.Check = &migrations.CheckConstraint{ + Constraint: raw, + Name: constraint.GetConstraint().GetConname(), + } + case pgq.ConstrType_CONSTR_DEFAULT: + defaultExpr := constraint.GetConstraint().GetRawExpr() + def, err := extractDefault(defaultExpr) + if err != nil { + return nil, err + } + if !def.IsNull() { + v := def.MustGet() + operation.Column.Default = &v + } + case pgq.ConstrType_CONSTR_FOREIGN: + onDelete, err := parseOnDeleteAction(constraint.GetConstraint().GetFkDelAction()) + if err != nil { + return nil, err + } + fk := &migrations.ForeignKeyReference{ + Name: constraint.GetConstraint().GetConname(), + OnDelete: onDelete, + Column: constraint.GetConstraint().GetPkAttrs()[0].GetString_().GetSval(), + Table: getQualifiedRelationName(constraint.GetConstraint().GetPktable()), + } + operation.Column.References = fk + } + } + } + + return operation, nil +} + +func canConvertAddColumn(cmd *pgq.AlterTableCmd) bool { + if cmd.GetMissingOk() { + return false + } + for _, constraint := range cmd.GetDef().GetColumnDef().GetConstraints() { + switch constraint.GetConstraint().GetContype() { + case pgq.ConstrType_CONSTR_DEFAULT, + pgq.ConstrType_CONSTR_NULL, + pgq.ConstrType_CONSTR_NOTNULL, + pgq.ConstrType_CONSTR_PRIMARY, + pgq.ConstrType_CONSTR_UNIQUE, + pgq.ConstrType_CONSTR_FOREIGN, + pgq.ConstrType_CONSTR_CHECK: + switch constraint.GetConstraint().GetFkUpdAction() { + case "r", "c", "n", "d": + // RESTRICT, CASCADE, SET NULL, SET DEFAULT + return false + case "a": + // NO ACTION, the default + break + } + case pgq.ConstrType_CONSTR_ATTR_DEFERRABLE, + pgq.ConstrType_CONSTR_ATTR_DEFERRED, + pgq.ConstrType_CONSTR_IDENTITY, + pgq.ConstrType_CONSTR_GENERATED: + return false + case pgq.ConstrType_CONSTR_ATTR_NOT_DEFERRABLE, pgq.ConstrType_CONSTR_ATTR_IMMEDIATE: + break + } + } + + return true +} + func convertAlterTableDropColumn(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) (migrations.Operation, error) { if !canConvertDropColumn(cmd) { return nil, nil diff --git a/pkg/sql2pgroll/alter_table_test.go b/pkg/sql2pgroll/alter_table_test.go index dd584db6..f74eace5 100644 --- a/pkg/sql2pgroll/alter_table_test.go +++ b/pkg/sql2pgroll/alter_table_test.go @@ -140,6 +140,88 @@ func TestConvertAlterTableStatements(t *testing.T) { sql: "ALTER TABLE schema.foo ADD CONSTRAINT bar CHECK (age > 0)", expectedOp: expect.CreateConstraintOp4, }, + + // Add column + { + sql: "ALTER TABLE foo ADD COLUMN bar int", + expectedOp: expect.AddColumnOp1, + }, + { + sql: "ALTER TABLE foo ADD COLUMN bar int NOT NULL", + expectedOp: expect.AddColumnOp1, + }, + { + sql: "ALTER TABLE schema.foo ADD COLUMN bar int", + expectedOp: expect.AddColumnOp2, + }, + { + sql: "ALTER TABLE foo ADD COLUMN bar int DEFAULT 123", + expectedOp: expect.AddColumnOp1WithDefault(ptr("123")), + }, + { + sql: "ALTER TABLE foo ADD COLUMN bar int DEFAULT 'baz'", + expectedOp: expect.AddColumnOp1WithDefault(ptr("'baz'")), + }, + { + sql: "ALTER TABLE foo ADD COLUMN bar int DEFAULT null", + expectedOp: expect.AddColumnOp1WithDefault(nil), + }, + { + sql: "ALTER TABLE foo ADD COLUMN bar int NULL", + expectedOp: expect.AddColumnOp3, + }, + { + sql: "ALTER TABLE foo ADD COLUMN bar int UNIQUE", + expectedOp: expect.AddColumnOp4, + }, + { + sql: "ALTER TABLE foo ADD COLUMN bar int UNIQUE NOT DEFERRABLE", + expectedOp: expect.AddColumnOp4, + }, + { + sql: "ALTER TABLE foo ADD COLUMN bar int UNIQUE INITIALLY IMMEDIATE", + expectedOp: expect.AddColumnOp4, + }, + { + sql: "ALTER TABLE foo ADD COLUMN bar int PRIMARY KEY", + expectedOp: expect.AddColumnOp5, + }, + { + sql: "ALTER TABLE foo ADD COLUMN bar int CHECK (bar > 0)", + expectedOp: expect.AddColumnOp6, + }, + { + sql: "ALTER TABLE foo ADD COLUMN bar int CONSTRAINT check_bar CHECK (bar > 0)", + expectedOp: expect.AddColumnOp7, + }, + { + sql: "ALTER TABLE foo ADD COLUMN bar int CONSTRAINT fk_baz REFERENCES baz (bar)", + expectedOp: expect.AddColumnOp8WithOnDeleteAction(migrations.ForeignKeyReferenceOnDeleteNOACTION), + }, + { + sql: "ALTER TABLE foo ADD COLUMN bar int CONSTRAINT fk_baz REFERENCES baz (bar) ON UPDATE NO ACTION", + expectedOp: expect.AddColumnOp8WithOnDeleteAction(migrations.ForeignKeyReferenceOnDeleteNOACTION), + }, + { + sql: "ALTER TABLE foo ADD COLUMN bar int CONSTRAINT fk_baz REFERENCES baz (bar) ON DELETE NO ACTION", + expectedOp: expect.AddColumnOp8WithOnDeleteAction(migrations.ForeignKeyReferenceOnDeleteNOACTION), + }, + { + sql: "ALTER TABLE foo ADD COLUMN bar int CONSTRAINT fk_baz REFERENCES baz (bar) ON DELETE RESTRICT", + expectedOp: expect.AddColumnOp8WithOnDeleteAction(migrations.ForeignKeyReferenceOnDeleteRESTRICT), + }, + { + sql: "ALTER TABLE foo ADD COLUMN bar int CONSTRAINT fk_baz REFERENCES baz (bar) ON DELETE SET NULL ", + expectedOp: expect.AddColumnOp8WithOnDeleteAction(migrations.ForeignKeyReferenceOnDeleteSETNULL), + }, + { + sql: "ALTER TABLE foo ADD COLUMN bar int CONSTRAINT fk_baz REFERENCES baz (bar) ON DELETE SET DEFAULT", + expectedOp: expect.AddColumnOp8WithOnDeleteAction(migrations.ForeignKeyReferenceOnDeleteSETDEFAULT), + }, + { + sql: "ALTER TABLE foo ADD COLUMN bar int CONSTRAINT fk_baz REFERENCES baz (bar) ON DELETE CASCADE", + expectedOp: expect.AddColumnOp8WithOnDeleteAction(migrations.ForeignKeyReferenceOnDeleteCASCADE), + }, } for _, tc := range tests { @@ -191,6 +273,19 @@ func TestUnconvertableAlterTableStatements(t *testing.T) { // representable by `OpCreateConstraint` "ALTER TABLE foo ADD CONSTRAINT bar CHECK (age > 0) NO INHERIT", "ALTER TABLE foo ADD CONSTRAINT bar CHECK (age > 0) NOT VALID", + + // ADD COLUMN cases not yet covered + "ALTER TABLE foo ADD COLUMN bar int REFERENCES bar (c) ON UPDATE RESTRICT", + "ALTER TABLE foo ADD COLUMN bar int REFERENCES bar (c) ON UPDATE CASCADE", + "ALTER TABLE foo ADD COLUMN bar int REFERENCES bar (c) ON UPDATE SET NULL", + "ALTER TABLE foo ADD COLUMN bar int REFERENCES bar (c) ON UPDATE SET DEFAULT", + "ALTER TABLE foo ADD COLUMN IF NOT EXISTS bar int", + "ALTER TABLE foo ADD COLUMN bar int UNIQUE DEFERRABLE", + "ALTER TABLE foo ADD COLUMN bar int UNIQUE INITIALLY DEFERRED", + "ALTER TABLE foo ADD COLUMN bar int GENERATED BY DEFAULT AS IDENTITY ", + "ALTER TABLE foo ADD COLUMN bar int GENERATED ALWAYS AS ( 123 ) STORED", + "ALTER TABLE foo ADD COLUMN bar int COLLATE en_US", + "ALTER TABLE foo ADD COLUMN bar int COMPRESSION pglz", } for _, sql := range tests { @@ -204,3 +299,7 @@ func TestUnconvertableAlterTableStatements(t *testing.T) { }) } } + +func ptr[T any](v T) *T { + return &v +} diff --git a/pkg/sql2pgroll/expect/add_column.go b/pkg/sql2pgroll/expect/add_column.go new file mode 100644 index 00000000..76de2583 --- /dev/null +++ b/pkg/sql2pgroll/expect/add_column.go @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: Apache-2.0 + +package expect + +import ( + "github.com/xataio/pgroll/pkg/migrations" + "github.com/xataio/pgroll/pkg/sql2pgroll" +) + +var AddColumnOp1 = &migrations.OpAddColumn{ + Table: "foo", + Up: sql2pgroll.PlaceHolderSQL, + Column: migrations.Column{ + Name: "bar", + Type: "int", + }, +} + +var AddColumnOp2 = &migrations.OpAddColumn{ + Table: "schema.foo", + Up: sql2pgroll.PlaceHolderSQL, + Column: migrations.Column{ + Name: "bar", + Type: "int", + }, +} + +func AddColumnOp1WithDefault(def *string) *migrations.OpAddColumn { + return &migrations.OpAddColumn{ + Table: "foo", + Up: sql2pgroll.PlaceHolderSQL, + Column: migrations.Column{ + Name: "bar", + Type: "int", + Default: def, + }, + } +} + +var AddColumnOp3 = &migrations.OpAddColumn{ + Table: "foo", + Up: sql2pgroll.PlaceHolderSQL, + Column: migrations.Column{ + Name: "bar", + Type: "int", + Nullable: true, + }, +} + +var AddColumnOp4 = &migrations.OpAddColumn{ + Table: "foo", + Up: sql2pgroll.PlaceHolderSQL, + Column: migrations.Column{ + Name: "bar", + Type: "int", + Unique: true, + }, +} + +var AddColumnOp5 = &migrations.OpAddColumn{ + Table: "foo", + Up: sql2pgroll.PlaceHolderSQL, + Column: migrations.Column{ + Name: "bar", + Type: "int", + Pk: true, + }, +} + +var AddColumnOp6 = &migrations.OpAddColumn{ + Table: "foo", + Up: sql2pgroll.PlaceHolderSQL, + Column: migrations.Column{ + Name: "bar", + Type: "int", + Check: &migrations.CheckConstraint{ + Constraint: "bar > 0", + Name: "", + }, + }, +} + +var AddColumnOp7 = &migrations.OpAddColumn{ + Table: "foo", + Up: sql2pgroll.PlaceHolderSQL, + Column: migrations.Column{ + Name: "bar", + Type: "int", + Check: &migrations.CheckConstraint{ + Constraint: "bar > 0", + Name: "check_bar", + }, + }, +} + +func AddColumnOp8WithOnDeleteAction(action migrations.ForeignKeyReferenceOnDelete) *migrations.OpAddColumn { + return &migrations.OpAddColumn{ + Table: "foo", + Up: sql2pgroll.PlaceHolderSQL, + Column: migrations.Column{ + Name: "bar", + Type: "int", + References: &migrations.ForeignKeyReference{ + Column: "bar", + Name: "fk_baz", + OnDelete: action, + Table: "baz", + }, + }, + } +}