diff --git a/.golangci.yml b/.golangci.yml index 123606e..531fadc 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -41,7 +41,9 @@ linters-settings: - func errors.New(text string) error - func fmt.Errorf(format string, a ...any) error - func fmt.Errorf(format string, a ...interface{}) error - - func github.com/kunitsucom/util.go/errors.Errorf(format string, a ...interface{}) error + - func github.com/kunitsucom/util.go/apperr.Errorf(format string, a ...interface{}) error + - var github.com/kunitsucom/ddlctl/pkg/apperr.Errorf func(format string, a ...any) error + - var github.com/kunitsucom/ddlctl/pkg/apperr.Errorf func(format string, a ...interface{}) error issues: diff --git a/README.md b/README.md index e621288..4a8163e 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,8 @@ # [ddlctl](https://github.com/kunitsucom/ddlctl) +> [!WARNING] +> This project is experimental. It is operational in the author's environment, but it is not known if it can be operated in other environments without trouble. + [![license](https://img.shields.io/github/license/kunitsucom/ddlctl)](LICENSE) [![pkg](https://pkg.go.dev/badge/github.com/kunitsucom/ddlctl)](https://pkg.go.dev/github.com/kunitsucom/ddlctl) [![goreportcard](https://goreportcard.com/badge/github.com/kunitsucom/ddlctl)](https://goreportcard.com/report/github.com/kunitsucom/ddlctl) @@ -27,7 +30,7 @@ - `show` subcommand - dialect - [x] Support `mysql` (beta) - - [x] Support `postgres` (beta) + - [x] Support `postgres` (alpha) - [x] Support `cockroachdb` (beta) - [x] Support `spanner` (alpha) - [ ] Support `sqlite3` @@ -36,14 +39,14 @@ - [x] Support `mysql` (alpha) - [x] Support `postgres` (alpha) - [x] Support `cockroachdb` (alpha) - - [ ] Support `spanner` + - [x] Support `spanner` (alpha) - [ ] Support `sqlite3` - `apply` subcommand - dialect - [x] Support `mysql` (alpha) - [x] Support `postgres` (alpha) - [x] Support `cockroachdb` (alpha) - - [ ] Support `spanner` + - [x] Support `spanner` (alpha) - [ ] Support `sqlite3` ## Example diff --git a/go.mod b/go.mod index 3c10f72..c5aa688 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.21.5 require ( github.com/go-sql-driver/mysql v1.7.1 github.com/googleapis/go-sql-spanner v1.1.1 - github.com/kunitsucom/util.go v0.0.61-0.20240112184211-8b1d5e248ad7 + github.com/kunitsucom/util.go v0.0.62-rc.1 github.com/lib/pq v1.10.9 ) diff --git a/go.sum b/go.sum index 16353ac..f890c4f 100644 --- a/go.sum +++ b/go.sum @@ -74,8 +74,8 @@ github.com/googleapis/gax-go/v2 v2.12.0 h1:A+gCJKdRfqXkr+BIRGtZLibNXf0m1f9E4HG56 github.com/googleapis/gax-go/v2 v2.12.0/go.mod h1:y+aIqrI5eb1YGMVJfuV3185Ts/D7qKpsEkdD5+I6QGU= github.com/googleapis/go-sql-spanner v1.1.1 h1:Z5kRckvrSNokM/x21BBy23kg9b0e9ikkderuuLVXfGY= github.com/googleapis/go-sql-spanner v1.1.1/go.mod h1:e12AKZmltQH/2XGqR/2SAPWPKshc5+WF4W7OGD9YcAw= -github.com/kunitsucom/util.go v0.0.61-0.20240112184211-8b1d5e248ad7 h1:u6fQAHydd8pzKpGNDsHrAwGrU/7MM0CoIS9RdN46+Wo= -github.com/kunitsucom/util.go v0.0.61-0.20240112184211-8b1d5e248ad7/go.mod h1:bYFf2JvRqVF1brBtpdt3xkkTGJBxmYBxZlItrc/lf7Y= +github.com/kunitsucom/util.go v0.0.62-rc.1 h1:IgyOfnSNrzj0K0bxjU3oJaUTBsIFkHPLFMEeM32IyE8= +github.com/kunitsucom/util.go v0.0.62-rc.1/go.mod h1:bYFf2JvRqVF1brBtpdt3xkkTGJBxmYBxZlItrc/lf7Y= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= diff --git a/pkg/errors/errors.go b/pkg/apperr/errors.go similarity index 85% rename from pkg/errors/errors.go rename to pkg/apperr/errors.go index 0544d0a..c2f2611 100644 --- a/pkg/errors/errors.go +++ b/pkg/apperr/errors.go @@ -1,6 +1,9 @@ -package errors +package apperr -import "errors" +import ( + "errors" + "fmt" +) var ( ErrNotSupported = errors.New("not supported") @@ -11,3 +14,6 @@ var ( ErrBothArgumentsIsDSN = errors.New("both arguments is dsn") ErrBothArgumentsAreNotDSNOrSQLFile = errors.New("both arguments are not dsn or sql file") ) + +//nolint:gochecknoglobals +var Errorf = fmt.Errorf diff --git a/pkg/ddl/cockroachdb/ddl.go b/pkg/ddl/cockroachdb/ddl.go index 64bca93..9e21cda 100644 --- a/pkg/ddl/cockroachdb/ddl.go +++ b/pkg/ddl/cockroachdb/ddl.go @@ -7,8 +7,8 @@ import ( ) const ( - Dialect = "cockroachdb" - DriverName = "postgres" // cockroachdb's driver is postgres + Dialect = "cockroachdb" //diff:ignore-line-postgres-cockroach + DriverName = "postgres" // cockroachdb's driver is postgres //diff:ignore-line-postgres-cockroach Indent = " " CommentPrefix = "-- " ) diff --git a/pkg/ddl/cockroachdb/ddl_table.go b/pkg/ddl/cockroachdb/ddl_table.go index b57fab3..28e8231 100644 --- a/pkg/ddl/cockroachdb/ddl_table.go +++ b/pkg/ddl/cockroachdb/ddl_table.go @@ -85,6 +85,7 @@ type ForeignKeyConstraint struct { Columns []*ColumnIdent Ref *Ident RefColumns []*ColumnIdent + OnAction string } var _ Constraint = (*ForeignKeyConstraint)(nil) @@ -101,6 +102,9 @@ func (c *ForeignKeyConstraint) String() string { str += " (" + stringz.JoinStringers(", ", c.Columns...) + ")" str += " REFERENCES " + c.Ref.String() str += " (" + stringz.JoinStringers(", ", c.RefColumns...) + ")" + if c.OnAction != "" { + str += " " + c.OnAction + } return str } @@ -127,6 +131,9 @@ func (c *ForeignKeyConstraint) StringForDiff() string { str += v.StringForDiff() } str += ")" + if c.OnAction != "" { + str += " " + c.OnAction + } return str } diff --git a/pkg/ddl/cockroachdb/ddl_table_test.go b/pkg/ddl/cockroachdb/ddl_table_test.go index 8070388..80095d1 100644 --- a/pkg/ddl/cockroachdb/ddl_table_test.go +++ b/pkg/ddl/cockroachdb/ddl_table_test.go @@ -66,11 +66,16 @@ func TestForeignKeyConstraint(t *testing.T) { Columns: []*ColumnIdent{{Ident: &Ident{Name: "group_id", QuotationMark: `"`, Raw: `"group_id"`}}}, Ref: &Ident{Name: "groups", QuotationMark: `"`, Raw: `"groups"`}, RefColumns: []*ColumnIdent{{Ident: &Ident{Name: "id", QuotationMark: `"`, Raw: `"id"`}}}, + OnAction: "ON DELETE NO ACTION", } - expected := `CONSTRAINT "fk_users_groups" FOREIGN KEY ("group_id") REFERENCES "groups" ("id")` - actual := foreignKeyConstraint.String() - require.Equal(t, expected, actual) + expectedString := `CONSTRAINT "fk_users_groups" FOREIGN KEY ("group_id") REFERENCES "groups" ("id") ON DELETE NO ACTION` + actualString := foreignKeyConstraint.String() + require.Equal(t, expectedString, actualString) + + expectedStringForDiff := `CONSTRAINT fk_users_groups FOREIGN KEY (group_id ASC) REFERENCES groups (id ASC) ON DELETE NO ACTION` + actualStringForDiff := foreignKeyConstraint.StringForDiff() + require.Equal(t, expectedStringForDiff, actualStringForDiff) t.Logf("✅: %s: foreignKeyConstraint: %#v", t.Name(), foreignKeyConstraint) }) diff --git a/pkg/ddl/cockroachdb/diff.go b/pkg/ddl/cockroachdb/diff.go index 48b0dc4..5610efa 100644 --- a/pkg/ddl/cockroachdb/diff.go +++ b/pkg/ddl/cockroachdb/diff.go @@ -6,6 +6,8 @@ import ( errorz "github.com/kunitsucom/util.go/errors" "github.com/kunitsucom/util.go/exp/diff/simplediff" + apperr "github.com/kunitsucom/ddlctl/pkg/apperr" + "github.com/kunitsucom/ddlctl/pkg/ddl" ) @@ -29,7 +31,7 @@ func Diff(before, after *DDL) (*DDL, error) { Name: s.Name, }) default: - return nil, errorz.Errorf("%s: %T: %w", s.GetNameForDiff(), s, ddl.ErrNotSupported) + return nil, apperr.Errorf("%s: %T: %w", s.GetNameForDiff(), s, ddl.ErrNotSupported) } } return result, nil @@ -49,7 +51,7 @@ func Diff(before, after *DDL) (*DDL, error) { Name: beforeStmt.Name, }) default: - return nil, errorz.Errorf("%s: %T: %w", beforeStmt.GetNameForDiff(), beforeStmt, ddl.ErrNotSupported) + return nil, apperr.Errorf("%s: %T: %w", beforeStmt.GetNameForDiff(), beforeStmt, ddl.ErrNotSupported) } } @@ -61,7 +63,7 @@ func Diff(before, after *DDL) (*DDL, error) { case *CreateIndexStmt: result.Stmts = append(result.Stmts, afterStmt) default: - return nil, errorz.Errorf("%s: %T: %w", afterStmt.GetNameForDiff(), afterStmt, ddl.ErrNotSupported) + return nil, apperr.Errorf("%s: %T: %w", afterStmt.GetNameForDiff(), afterStmt, ddl.ErrNotSupported) } } diff --git a/pkg/ddl/cockroachdb/diff_create_table.go b/pkg/ddl/cockroachdb/diff_create_table.go index 6dca58b..58032e6 100644 --- a/pkg/ddl/cockroachdb/diff_create_table.go +++ b/pkg/ddl/cockroachdb/diff_create_table.go @@ -3,9 +3,10 @@ package cockroachdb import ( "reflect" - errorz "github.com/kunitsucom/util.go/errors" "github.com/kunitsucom/util.go/exp/diff/simplediff" + apperr "github.com/kunitsucom/ddlctl/pkg/apperr" + "github.com/kunitsucom/ddlctl/pkg/ddl" ) @@ -53,7 +54,7 @@ func DiffCreateTable(before, after *CreateTableStmt, opts ...DiffCreateTableOpti }) return result, nil case (before == nil && after == nil) || reflect.DeepEqual(before, after) || before.String() == after.String(): - return nil, errorz.Errorf("before: %s, after: %s: %w", before.GetNameForDiff(), after.GetNameForDiff(), ddl.ErrNoDifference) + return nil, apperr.Errorf("before: %s, after: %s: %w", before.GetNameForDiff(), after.GetNameForDiff(), ddl.ErrNoDifference) } if before.Name.StringForDiff() != after.Name.StringForDiff() { @@ -181,7 +182,7 @@ func DiffCreateTable(before, after *CreateTableStmt, opts ...DiffCreateTableOpti } if len(result.Stmts) == 0 { - return nil, errorz.Errorf("before: %s, after: %s: %w", before.GetNameForDiff(), after.GetNameForDiff(), ddl.ErrNoDifference) + return nil, apperr.Errorf("before: %s, after: %s: %w", before.GetNameForDiff(), after.GetNameForDiff(), ddl.ErrNoDifference) } return result, nil diff --git a/pkg/ddl/cockroachdb/lexar.go b/pkg/ddl/cockroachdb/lexar.go index 13034a2..78eb823 100644 --- a/pkg/ddl/cockroachdb/lexar.go +++ b/pkg/ddl/cockroachdb/lexar.go @@ -6,11 +6,6 @@ import ( // MEMO: https://www.postgresql.jp/docs/11/datatype.html -const ( - QuotationChar = '"' - QuotationStr = string(QuotationChar) -) - // Token はSQL文のトークンを表す型です。 type Token struct { Type TokenType @@ -63,6 +58,8 @@ const ( TOKEN_DROP TokenType = "DROP" TOKEN_RENAME TokenType = "RENAME" TOKEN_TRUNCATE TokenType = "TRUNCATE" + TOKEN_DELETE TokenType = "DELETE" + TOKEN_UPDATE TokenType = "UPDATE" // OBJECT. TOKEN_TABLE TokenType = "TABLE" @@ -108,6 +105,9 @@ const ( TOKEN_NOT TokenType = "NOT" TOKEN_ASC TokenType = "ASC" TOKEN_DESC TokenType = "DESC" + TOKEN_CASCADE TokenType = "CASCADE" + TOKEN_NO TokenType = "NO" + TOKEN_ACTION TokenType = "ACTION" // CONSTRAINT. TOKEN_CONSTRAINT TokenType = "CONSTRAINT" @@ -155,6 +155,10 @@ func lookupIdent(ident string) TokenType { return TOKEN_RENAME case "TRUNCATE": return TOKEN_TRUNCATE + case "DELETE": + return TOKEN_DELETE + case "UPDATE": + return TOKEN_UPDATE case "TABLE": return TOKEN_TABLE case "INDEX": @@ -225,6 +229,12 @@ func lookupIdent(ident string) TokenType { return TOKEN_ASC case "DESC": return TOKEN_DESC + case "CASCADE": + return TOKEN_CASCADE + case "NO": + return TOKEN_NO + case "ACTION": + return TOKEN_ACTION case "CONSTRAINT": return TOKEN_CONSTRAINT case "PRIMARY": diff --git a/pkg/ddl/cockroachdb/lexar_test.go b/pkg/ddl/cockroachdb/lexar_test.go index 822402b..7d57dac 100644 --- a/pkg/ddl/cockroachdb/lexar_test.go +++ b/pkg/ddl/cockroachdb/lexar_test.go @@ -26,6 +26,8 @@ func Test_lookupIdent(t *testing.T) { {name: "success,DROP", input: "DROP", want: TOKEN_DROP}, {name: "success,RENAME", input: "RENAME", want: TOKEN_RENAME}, {name: "success,TRUNCATE", input: "TRUNCATE", want: TOKEN_TRUNCATE}, + {name: "success,DELETE", input: "DELETE", want: TOKEN_DELETE}, + {name: "success,UPDATE", input: "UPDATE", want: TOKEN_UPDATE}, {name: "success,TABLE", input: "TABLE", want: TOKEN_TABLE}, {name: "success,INDEX", input: "INDEX", want: TOKEN_INDEX}, {name: "success,VIEW", input: "VIEW", want: TOKEN_VIEW}, @@ -63,6 +65,7 @@ func Test_lookupIdent(t *testing.T) { {name: "success,NULL", input: "NULL", want: TOKEN_NULL}, {name: "success,ASC", input: "ASC", want: TOKEN_ASC}, {name: "success,DESC", input: "DESC", want: TOKEN_DESC}, + {name: "success,CASCADE", input: "CASCADE", want: TOKEN_CASCADE}, {name: "success,CONSTRAINT", input: "CONSTRAINT", want: TOKEN_CONSTRAINT}, {name: "success,PRIMARY", input: "PRIMARY", want: TOKEN_PRIMARY}, {name: "success,KEY", input: "KEY", want: TOKEN_KEY}, diff --git a/pkg/ddl/cockroachdb/parser.go b/pkg/ddl/cockroachdb/parser.go index c7af9b8..0191f16 100644 --- a/pkg/ddl/cockroachdb/parser.go +++ b/pkg/ddl/cockroachdb/parser.go @@ -8,10 +8,11 @@ import ( "runtime" "strings" - errorz "github.com/kunitsucom/util.go/errors" filepathz "github.com/kunitsucom/util.go/path/filepath" stringz "github.com/kunitsucom/util.go/strings" + apperr "github.com/kunitsucom/ddlctl/pkg/apperr" + "github.com/kunitsucom/ddlctl/pkg/ddl" "github.com/kunitsucom/ddlctl/pkg/ddl/logs" ) @@ -83,7 +84,7 @@ LabelDDL: case TOKEN_CREATE: stmt, err := p.parseCreateStatement() if err != nil { - return nil, errorz.Errorf("parseCreateStatement: %w", err) + return nil, apperr.Errorf("parseCreateStatement: %w", err) } d.Stmts = append(d.Stmts, stmt) case TOKEN_CLOSE_PAREN: @@ -93,7 +94,7 @@ LabelDDL: case TOKEN_EOF: break LabelDDL default: - return nil, errorz.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) } p.nextToken() @@ -110,7 +111,7 @@ func (p *Parser) parseCreateStatement() (Stmt, error) { //nolint:ireturn case TOKEN_INDEX, TOKEN_UNIQUE: return p.parseCreateIndexStmt() default: - return nil, errorz.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) } } @@ -123,11 +124,11 @@ func (p *Parser) parseCreateTableStmt() (*CreateTableStmt, error) { if p.isPeekToken(TOKEN_IF) { p.nextToken() // current = IF if err := p.checkPeekToken(TOKEN_NOT); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = NOT if err := p.checkPeekToken(TOKEN_EXISTS); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = EXISTS createTableStmt.IfNotExists = true @@ -135,7 +136,7 @@ func (p *Parser) parseCreateTableStmt() (*CreateTableStmt, error) { p.nextToken() // current = table_name if err := p.checkCurrentToken(TOKEN_IDENT); err != nil { - return nil, errorz.Errorf("checkCurrentToken: %w", err) + return nil, apperr.Errorf("checkCurrentToken: %w", err) } createTableStmt.Name = NewObjectName(p.currentToken.Literal.Str) @@ -144,7 +145,7 @@ func (p *Parser) parseCreateTableStmt() (*CreateTableStmt, error) { p.nextToken() // current = ( if err := p.checkCurrentToken(TOKEN_OPEN_PAREN); err != nil { - return nil, errorz.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) + return nil, apperr.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) } p.nextToken() // current = column_name @@ -155,7 +156,7 @@ LabelColumns: case p.isCurrentToken(TOKEN_IDENT): column, constraints, err := p.parseColumn(createTableStmt.Name.Name) if err != nil { - return nil, errorz.Errorf(errFmtPrefix+"parseColumn: %w", err) + return nil, apperr.Errorf(errFmtPrefix+"parseColumn: %w", err) } createTableStmt.Columns = append(createTableStmt.Columns, column) if len(constraints) > 0 { @@ -166,7 +167,7 @@ LabelColumns: case isConstraint(p.currentToken.Type): constraint, err := p.parseTableConstraint(createTableStmt.Name.Name) if err != nil { - return nil, errorz.Errorf(errFmtPrefix+"parseConstraint: %w", err) + return nil, apperr.Errorf(errFmtPrefix+"parseConstraint: %w", err) } createTableStmt.Constraints = createTableStmt.Constraints.Append(constraint) case p.isCurrentToken(TOKEN_COMMA): @@ -177,10 +178,10 @@ LabelColumns: case TOKEN_SEMICOLON, TOKEN_EOF: break LabelColumns default: - return nil, errorz.Errorf(errFmtPrefix+"peekToken=%#v: %w", p.peekToken, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf(errFmtPrefix+"peekToken=%#v: %w", p.peekToken, ddl.ErrUnexpectedToken) } default: - return nil, errorz.Errorf(errFmtPrefix+"currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf(errFmtPrefix+"currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) } } @@ -199,11 +200,11 @@ func (p *Parser) parseCreateIndexStmt() (*CreateIndexStmt, error) { if p.isPeekToken(TOKEN_IF) { p.nextToken() // current = IF if err := p.checkPeekToken(TOKEN_NOT); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = NOT if err := p.checkPeekToken(TOKEN_EXISTS); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = EXISTS createIndexStmt.IfNotExists = true @@ -211,7 +212,7 @@ func (p *Parser) parseCreateIndexStmt() (*CreateIndexStmt, error) { p.nextToken() // current = index_name if err := p.checkCurrentToken(TOKEN_IDENT); err != nil { - return nil, errorz.Errorf("checkCurrentToken: %w", err) + return nil, apperr.Errorf("checkCurrentToken: %w", err) } createIndexStmt.Name = NewObjectName(p.currentToken.Literal.Str) @@ -220,13 +221,13 @@ func (p *Parser) parseCreateIndexStmt() (*CreateIndexStmt, error) { p.nextToken() // current = ON if err := p.checkCurrentToken(TOKEN_ON); err != nil { - return nil, errorz.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) + return nil, apperr.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) } p.nextToken() // current = table_name if err := p.checkCurrentToken(TOKEN_IDENT); err != nil { - return nil, errorz.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) + return nil, apperr.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) } createIndexStmt.TableName = NewObjectName(p.currentToken.Literal.Str) @@ -240,12 +241,12 @@ func (p *Parser) parseCreateIndexStmt() (*CreateIndexStmt, error) { } if err := p.checkCurrentToken(TOKEN_OPEN_PAREN); err != nil { - return nil, errorz.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) + return nil, apperr.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) } idents, err := p.parseColumnIdents() if err != nil { - return nil, errorz.Errorf(errFmtPrefix+"parseColumnIdents: %w", err) + return nil, apperr.Errorf(errFmtPrefix+"parseColumnIdents: %w", err) } createIndexStmt.Columns = idents @@ -259,7 +260,7 @@ func (p *Parser) parseColumn(tableName *Ident) (*Column, []Constraint, error) { constraints := make(Constraints, 0) if err := p.checkCurrentToken(TOKEN_IDENT); err != nil { - return nil, nil, errorz.Errorf("checkCurrentToken: %w", err) + return nil, nil, apperr.Errorf("checkCurrentToken: %w", err) } column.Name = NewRawIdent(p.currentToken.Literal.Str) @@ -271,7 +272,7 @@ func (p *Parser) parseColumn(tableName *Ident) (*Column, []Constraint, error) { case isDataType(p.currentToken.Type): dataType, err := p.parseDataType() if err != nil { - return nil, nil, errorz.Errorf(errFmtPrefix+"parseDataType: %w", err) + return nil, nil, apperr.Errorf(errFmtPrefix+"parseDataType: %w", err) } column.DataType = dataType @@ -281,7 +282,7 @@ func (p *Parser) parseColumn(tableName *Ident) (*Column, []Constraint, error) { switch p.currentToken.Type { //nolint:exhaustive case TOKEN_NOT: if err := p.checkPeekToken(TOKEN_NULL); err != nil { - return nil, nil, errorz.Errorf(errFmtPrefix+"checkPeekToken: %w", err) + return nil, nil, apperr.Errorf(errFmtPrefix+"checkPeekToken: %w", err) } p.nextToken() // current = NULL column.NotNull = true @@ -291,7 +292,7 @@ func (p *Parser) parseColumn(tableName *Ident) (*Column, []Constraint, error) { p.nextToken() // current = DEFAULT def, err := p.parseColumnDefault() if err != nil { - return nil, nil, errorz.Errorf(errFmtPrefix+"parseColumnDefault: %w", err) + return nil, nil, apperr.Errorf(errFmtPrefix+"parseColumnDefault: %w", err) } column.Default = def continue @@ -304,7 +305,7 @@ func (p *Parser) parseColumn(tableName *Ident) (*Column, []Constraint, error) { cs, err := p.parseColumnConstraints(tableName, column) if err != nil { - return nil, nil, errorz.Errorf(errFmtPrefix+"parseColumnConstraints: %w", err) + return nil, nil, apperr.Errorf(errFmtPrefix+"parseColumnConstraints: %w", err) } if len(cs) > 0 { for _, c := range cs { @@ -312,7 +313,7 @@ func (p *Parser) parseColumn(tableName *Ident) (*Column, []Constraint, error) { } } default: - return nil, nil, errorz.Errorf(errFmtPrefix+"currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + return nil, nil, apperr.Errorf(errFmtPrefix+"currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) } return column, constraints, nil @@ -330,7 +331,7 @@ LabelDefault: case TOKEN_OPEN_PAREN: ids, err := p.parseExpr() if err != nil { - return nil, errorz.Errorf("parseExpr: %w", err) + return nil, apperr.Errorf("parseExpr: %w", err) } def.Value = def.Value.Append(ids...) continue @@ -355,7 +356,7 @@ LabelDefault: if isConstraint(p.currentToken.Type) { break LabelDefault } - return nil, errorz.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) } p.nextToken() @@ -364,11 +365,12 @@ LabelDefault: return def, nil } +//nolint:cyclop func (p *Parser) parseExpr() ([]*Ident, error) { idents := make([]*Ident, 0) if err := p.checkCurrentToken(TOKEN_OPEN_PAREN); err != nil { - return nil, errorz.Errorf("checkCurrentToken: %w", err) + return nil, apperr.Errorf("checkCurrentToken: %w", err) } idents = append(idents, NewRawIdent(p.currentToken.Literal.Str)) p.nextToken() // current = IDENT @@ -379,7 +381,7 @@ LabelExpr: case TOKEN_OPEN_PAREN: ids, err := p.parseExpr() if err != nil { - return nil, errorz.Errorf("parseExpr: %w", err) + return nil, apperr.Errorf("parseExpr: %w", err) } idents = append(idents, ids...) continue @@ -396,9 +398,13 @@ LabelExpr: } idents = append(idents, NewRawIdent(value)) case TOKEN_EOF: - return nil, errorz.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) default: - idents = append(idents, NewRawIdent(p.currentToken.Literal.Str)) + if isReservedValue(p.currentToken.Type) { + idents = append(idents, NewRawIdent(p.currentToken.Type.String())) + } else { + idents = append(idents, NewRawIdent(p.currentToken.Literal.Str)) + } } p.nextToken() @@ -416,7 +422,7 @@ LabelConstraints: switch p.currentToken.Type { //nolint:exhaustive case TOKEN_PRIMARY: if err := p.checkPeekToken(TOKEN_KEY); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = KEY constraints = constraints.Append(&PrimaryKeyConstraint{ @@ -425,7 +431,7 @@ LabelConstraints: }) case TOKEN_REFERENCES: if err := p.checkPeekToken(TOKEN_IDENT); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = table_name constraint := &ForeignKeyConstraint{ @@ -436,7 +442,30 @@ LabelConstraints: p.nextToken() // current = ( idents, err := p.parseColumnIdents() if err != nil { - return nil, errorz.Errorf("parseColumnIdents: %w", err) + return nil, apperr.Errorf("parseColumnIdents: %w", err) + } + // TODO: support ON DELETE, ON UPDATE + //nolint:nestif + if p.isCurrentToken(TOKEN_ON) { + onAction := p.currentToken.Literal.String() // current = ON + p.nextToken() // current = DELETE or UPDATE + if err := p.checkCurrentToken(TOKEN_DELETE, TOKEN_UPDATE); err != nil { + return nil, apperr.Errorf("checkCurrentToken: %w", err) + } + onAction += " " + p.currentToken.Literal.String() + if err := p.checkPeekToken(TOKEN_CASCADE, TOKEN_NO); err != nil { + return nil, apperr.Errorf("checkPeekToken: %w", err) + } + p.nextToken() // current = CASCADE or NO + onAction += " " + p.currentToken.Literal.String() // current = CASCADE or NO + if p.isCurrentToken(TOKEN_NO) { + if err := p.checkPeekToken(TOKEN_ACTION); err != nil { + return nil, apperr.Errorf("checkPeekToken: %w", err) + } + p.nextToken() // current = ACTION + onAction += " " + p.currentToken.Literal.String() // current = ACTION + } + constraint.OnAction = onAction } constraint.RefColumns = idents constraints = constraints.Append(constraint) @@ -448,7 +477,7 @@ LabelConstraints: }) case TOKEN_CHECK: if err := p.checkPeekToken(TOKEN_OPEN_PAREN); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = ( constraint := &CheckConstraint{ @@ -456,14 +485,14 @@ LabelConstraints: } idents, err := p.parseExpr() if err != nil { - return nil, errorz.Errorf("parseExpr: %w", err) + return nil, apperr.Errorf("parseExpr: %w", err) } constraint.Expr = constraint.Expr.Append(idents...) constraints = constraints.Append(constraint) case TOKEN_IDENT, TOKEN_COMMA, TOKEN_CLOSE_PAREN: break LabelConstraints default: - return nil, errorz.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) } p.nextToken() @@ -478,7 +507,7 @@ func (p *Parser) parseTableConstraint(tableName *Ident) (Constraint, error) { // if p.isCurrentToken(TOKEN_CONSTRAINT) { p.nextToken() // current = constraint_name if p.currentToken.Type != TOKEN_IDENT { - return nil, errorz.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) } constraintName = NewRawIdent(p.currentToken.Literal.Str) p.nextToken() // current = PRIMARY or CHECK //diff:ignore-line-postgres-cockroach @@ -487,16 +516,16 @@ func (p *Parser) parseTableConstraint(tableName *Ident) (Constraint, error) { // switch p.currentToken.Type { //nolint:exhaustive case TOKEN_PRIMARY: if err := p.checkPeekToken(TOKEN_KEY); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = KEY if err := p.checkPeekToken(TOKEN_OPEN_PAREN); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = ( idents, err := p.parseColumnIdents() if err != nil { - return nil, errorz.Errorf("parseColumnIdents: %w", err) + return nil, apperr.Errorf("parseColumnIdents: %w", err) } if constraintName == nil { constraintName = NewRawIdent(fmt.Sprintf("%s_pkey", tableName.StringForDiff())) @@ -507,30 +536,49 @@ func (p *Parser) parseTableConstraint(tableName *Ident) (Constraint, error) { // }, nil case TOKEN_FOREIGN: if err := p.checkPeekToken(TOKEN_KEY); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = KEY if err := p.checkPeekToken(TOKEN_OPEN_PAREN); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = ( idents, err := p.parseColumnIdents() if err != nil { - return nil, errorz.Errorf("parseColumnIdents: %w", err) + return nil, apperr.Errorf("parseColumnIdents: %w", err) } if err := p.checkCurrentToken(TOKEN_REFERENCES); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = ref_table_name if err := p.checkCurrentToken(TOKEN_IDENT); err != nil { - return nil, errorz.Errorf("checkCurrentToken: %w", err) + return nil, apperr.Errorf("checkCurrentToken: %w", err) } refName := NewRawIdent(p.currentToken.Literal.Str) p.nextToken() // current = ( identsRef, err := p.parseColumnIdents() if err != nil { - return nil, errorz.Errorf("parseColumnIdents: %w", err) + return nil, apperr.Errorf("parseColumnIdents: %w", err) + } + // TODO: support ON DELETE, ON UPDATE + var onAction string + if p.isCurrentToken(TOKEN_ON) { + onAction = p.currentToken.Literal.String() // current = ON + p.nextToken() // current = DELETE or UPDATE + if err := p.checkCurrentToken(TOKEN_DELETE, TOKEN_UPDATE); err != nil { + return nil, apperr.Errorf("checkCurrentToken: %w", err) + } + onAction += " " + p.currentToken.Literal.String() + if err := p.checkPeekToken(TOKEN_CASCADE, TOKEN_NO); err != nil { + return nil, apperr.Errorf("checkPeekToken: %w", err) + } + p.nextToken() // current = CASCADE or NO + onAction += " " + p.currentToken.Literal.String() // current = CASCADE or NO + if p.isCurrentToken(TOKEN_NO) && p.isPeekToken(TOKEN_ACTION) { + p.nextToken() // current = ACTION + onAction += " " + p.currentToken.Literal.String() // current = ACTION + } } if constraintName == nil { name := tableName.StringForDiff() @@ -545,6 +593,7 @@ func (p *Parser) parseTableConstraint(tableName *Ident) (Constraint, error) { // Columns: idents, Ref: refName, RefColumns: identsRef, + OnAction: onAction, }, nil case TOKEN_UNIQUE, TOKEN_INDEX: //diff:ignore-line-postgres-cockroach @@ -552,28 +601,28 @@ func (p *Parser) parseTableConstraint(tableName *Ident) (Constraint, error) { // if p.isCurrentToken(TOKEN_UNIQUE) { //diff:ignore-line-postgres-cockroach c.Unique = true //diff:ignore-line-postgres-cockroach if err := p.checkPeekToken(TOKEN_INDEX); err != nil { //diff:ignore-line-postgres-cockroach - return nil, errorz.Errorf("checkPeekToken: %w", err) //diff:ignore-line-postgres-cockroach + return nil, apperr.Errorf("checkPeekToken: %w", err) //diff:ignore-line-postgres-cockroach } //diff:ignore-line-postgres-cockroach p.nextToken() // current = INDEX //diff:ignore-line-postgres-cockroach } //diff:ignore-line-postgres-cockroach p.nextToken() // current = index_name //diff:ignore-line-postgres-cockroach if err := p.checkCurrentToken(TOKEN_IDENT); err != nil { //diff:ignore-line-postgres-cockroach - return nil, errorz.Errorf("checkCurrentToken: %w", err) //diff:ignore-line-postgres-cockroach + return nil, apperr.Errorf("checkCurrentToken: %w", err) //diff:ignore-line-postgres-cockroach } //diff:ignore-line-postgres-cockroach constraintName := NewRawIdent(p.currentToken.Literal.Str) //diff:ignore-line-postgres-cockroach if err := p.checkPeekToken(TOKEN_OPEN_PAREN); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = ( idents, err := p.parseColumnIdents() if err != nil { - return nil, errorz.Errorf("parseColumnIdents: %w", err) + return nil, apperr.Errorf("parseColumnIdents: %w", err) } c.Name = constraintName c.Columns = idents return c, nil default: - return nil, errorz.Errorf("currentToken=%s: %w", p.currentToken.Type, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf("currentToken=%s: %w", p.currentToken.Type, ddl.ErrUnexpectedToken) } } @@ -588,12 +637,12 @@ func (p *Parser) parseDataType() (*DataType, error) { p.nextToken() // current = WITH dataType.Name += " " + p.currentToken.Literal.String() if err := p.checkPeekToken(TOKEN_TIME); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = TIME dataType.Name += " " + p.currentToken.Literal.String() if err := p.checkPeekToken(TOKEN_ZONE); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = ZONE dataType.Name += " " + p.currentToken.Literal.String() @@ -604,7 +653,7 @@ func (p *Parser) parseDataType() (*DataType, error) { case TOKEN_DOUBLE: dataType.Name = p.currentToken.Literal.String() if err := p.checkPeekToken(TOKEN_PRECISION); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = PRECISION dataType.Name += " " + p.currentToken.Literal.String() @@ -612,7 +661,7 @@ func (p *Parser) parseDataType() (*DataType, error) { case TOKEN_CHARACTER: dataType.Name = p.currentToken.Literal.String() if err := p.checkPeekToken(TOKEN_VARYING); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = VARYING dataType.Name += " " + p.currentToken.Literal.String() @@ -626,7 +675,7 @@ func (p *Parser) parseDataType() (*DataType, error) { p.nextToken() // current = ( idents, err := p.parseIdents() if err != nil { - return nil, errorz.Errorf("parseIdents: %w", err) + return nil, apperr.Errorf("parseIdents: %w", err) } dataType.Expr = dataType.Expr.Append(idents...) } @@ -659,7 +708,7 @@ LabelIdents: p.nextToken() break LabelIdents default: - return nil, errorz.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) } p.nextToken() } @@ -680,7 +729,7 @@ LabelIdents: case TOKEN_CLOSE_PAREN: break LabelIdents case TOKEN_EOF, TOKEN_ILLEGAL: - return nil, errorz.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) default: idents = append(idents, NewRawIdent(p.currentToken.Literal.Str)) } @@ -757,7 +806,7 @@ func (p *Parser) checkCurrentToken(expectedTypes ...TokenType) error { return nil } } - return errorz.Errorf("currentToken: expected=%s, but got=%#v: %w", stringz.JoinStringers(",", expectedTypes...), p.currentToken, ddl.ErrUnexpectedToken) + return apperr.Errorf("currentToken: expected=%s, but got=%#v: %w", stringz.JoinStringers(",", expectedTypes...), p.currentToken, ddl.ErrUnexpectedToken) } func (p *Parser) isPeekToken(expectedTypes ...TokenType) bool { @@ -775,5 +824,5 @@ func (p *Parser) checkPeekToken(expectedTypes ...TokenType) error { return nil } } - return errorz.Errorf("peekToken: expected=%s, but got=%#v: %w", stringz.JoinStringers(",", expectedTypes...), p.peekToken, ddl.ErrUnexpectedToken) + return apperr.Errorf("peekToken: expected=%s, but got=%#v: %w", stringz.JoinStringers(",", expectedTypes...), p.peekToken, ddl.ErrUnexpectedToken) } diff --git a/pkg/ddl/cockroachdb/parser_test.go b/pkg/ddl/cockroachdb/parser_test.go index 8254051..15aa07c 100644 --- a/pkg/ddl/cockroachdb/parser_test.go +++ b/pkg/ddl/cockroachdb/parser_test.go @@ -21,17 +21,13 @@ func TestParser_Parse(t *testing.T) { }) logs.TraceLog = log.New(os.Stderr, "TRACE: ", log.LstdFlags|log.Lshortfile) - successTests := []struct { - name string - input string - wantErr error - wantStr string - }{ - { - name: "success,CREATE_TABLE", - input: `CREATE TABLE "groups" ("id" UUID NOT NULL PRIMARY KEY, description TEXT); CREATE TABLE "users" (id UUID NOT NULL, group_id UUID NOT NULL REFERENCES "groups" ("id"), "name" VARCHAR(255) NOT NULL UNIQUE, "age" INT DEFAULT 0 CHECK ("age" >= 0), description TEXT, PRIMARY KEY ("id"));`, - wantErr: nil, - wantStr: `CREATE TABLE "groups" ( + t.Run("success,CREATE_TABLE", func(t *testing.T) { + l := NewLexer(`CREATE TABLE "groups" ("id" UUID NOT NULL PRIMARY KEY, description TEXT); CREATE TABLE "users" (id UUID NOT NULL, group_id UUID NOT NULL REFERENCES "groups" ("id") ON DELETE NO ACTION, "name" VARCHAR(255) NOT NULL UNIQUE, "age" INT DEFAULT 0 CHECK ("age" >= 0), description TEXT, PRIMARY KEY ("id"));`) + p := NewParser(l) + actualDDL, err := p.Parse() + require.NoError(t, err) + + const expected = `CREATE TABLE "groups" ( "id" UUID NOT NULL, description TEXT, CONSTRAINT groups_pkey PRIMARY KEY ("id") @@ -43,15 +39,23 @@ CREATE TABLE "users" ( "age" INT DEFAULT 0, description TEXT, CONSTRAINT users_pkey PRIMARY KEY ("id"), - CONSTRAINT users_group_id_fkey FOREIGN KEY (group_id) REFERENCES "groups" ("id"), + CONSTRAINT users_group_id_fkey FOREIGN KEY (group_id) REFERENCES "groups" ("id") ON DELETE NO ACTION, UNIQUE INDEX users_unique_name ("name"), CONSTRAINT users_age_check CHECK ("age" >= 0) ); -`, - }, - { - name: "success,complex_defaults", - input: `-- table: complex_defaults +` + actual := actualDDL.String() + + if !assert.Equal(t, expected, actual) { + t.Fail() + } + + t.Logf("✅: %s: actual: %%#v: \n%#v", t.Name(), actualDDL) + t.Logf("✅: %s: actual: %%s: \n%s", t.Name(), actualDDL) + }) + + t.Run("success,complex_defaults", func(t *testing.T) { + l := NewLexer(`-- table: complex_defaults CREATE TABLE IF NOT EXISTS complex_defaults ( -- id is the primary key. id SERIAL PRIMARY KEY, @@ -63,9 +67,12 @@ CREATE TABLE IF NOT EXISTS complex_defaults ( json_data JSONB DEFAULT '{}', calculated_value INTEGER DEFAULT (SELECT COUNT(*) FROM another_table) ); -`, - wantErr: nil, - wantStr: `CREATE TABLE IF NOT EXISTS complex_defaults ( +`) + p := NewParser(l) + actualDDL, err := p.Parse() + require.NoError(t, err) + + const expected = `CREATE TABLE IF NOT EXISTS complex_defaults ( id SERIAL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, @@ -76,11 +83,19 @@ CREATE TABLE IF NOT EXISTS complex_defaults ( calculated_value INTEGER DEFAULT (SELECT COUNT(*) FROM another_table), CONSTRAINT complex_defaults_pkey PRIMARY KEY (id) ); -`, - }, - { - name: "success,CREATE_TABLE_TYPE_ANNOTATION", - input: `CREATE TABLE IF NOT EXISTS public.users ( +` + actual := actualDDL.String() + + if !assert.Equal(t, expected, actual) { + t.Fail() + } + + t.Logf("✅: %s: actual: %%#v: \n%#v", t.Name(), actualDDL) + t.Logf("✅: %s: actual: %%s: \n%s", t.Name(), actualDDL) + }) + + t.Run("success,CREATE_TABLE_TYPE_ANNOTATION", func(t *testing.T) { + l := NewLexer(`CREATE TABLE IF NOT EXISTS public.users ( user_id UUID NOT NULL, username VARCHAR(256) NOT NULL, is_verified BOOL NOT NULL DEFAULT false, @@ -89,9 +104,12 @@ CREATE TABLE IF NOT EXISTS complex_defaults ( CONSTRAINT users_pkey PRIMARY KEY (user_id ASC), INDEX users_idx_by_username (username DESC) ); -`, - wantErr: nil, - wantStr: `CREATE TABLE IF NOT EXISTS public.users ( +`) + p := NewParser(l) + actualDDL, err := p.Parse() + require.NoError(t, err) + + const expected = `CREATE TABLE IF NOT EXISTS public.users ( user_id UUID NOT NULL, username VARCHAR(256) NOT NULL, is_verified BOOL NOT NULL DEFAULT false, @@ -100,9 +118,23 @@ CREATE TABLE IF NOT EXISTS complex_defaults ( CONSTRAINT users_pkey PRIMARY KEY (user_id ASC), INDEX users_idx_by_username (username DESC) ); -`, - }, - } +` + actual := actualDDL.String() + + if !assert.Equal(t, expected, actual) { + t.Fail() + } + + t.Logf("✅: %s: actual: %%#v: \n%#v", t.Name(), actualDDL) + t.Logf("✅: %s: actual: %%s: \n%s", t.Name(), actualDDL) + }) + + successTests := []struct { + name string + input string + wantErr error + wantStr string + }{} for _, tt := range successTests { tt := tt @@ -249,6 +281,26 @@ CREATE TABLE IF NOT EXISTS complex_defaults ( input: `CREATE TABLE "users" ("id" UUID, PRIMARY KEY (NOT`, wantErr: ddl.ErrUnexpectedToken, }, + { + name: "failure,CREATE_TABLE_table_name_column_name_REFERENCES_INVALID", + input: `CREATE TABLE "users" ("id" UUID REFERENCES foo (foo_id) ON`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_REFERENCES_INVALID", + input: `CREATE TABLE "users" ("id" UUID REFERENCES foo (foo_id) ON DELETE`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_REFERENCES_INVALID", + input: `CREATE TABLE "users" ("id" UUID REFERENCES foo (foo_id) ON DELETE NO`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_REFERENCES_INVALID", + input: `CREATE TABLE "users" ("id" UUID REFERENCES foo (foo_id) ON DELETE NO ACTION`, + wantErr: ddl.ErrUnexpectedToken, + }, { name: "failure,CREATE_TABLE_table_name_column_name_CONSTRAINT_INVALID_FOREIGN", input: `CREATE TABLE "users" ("id" UUID, FOREIGN NOT`, @@ -284,6 +336,26 @@ CREATE TABLE IF NOT EXISTS complex_defaults ( input: `CREATE TABLE "users" ("id" UUID, FOREIGN KEY ("group_id") REFERENCES "groups" ("id")`, wantErr: ddl.ErrUnexpectedToken, }, + { + name: "failure,CREATE_TABLE_table_name_column_name_CONSTRAINT_FOREIGN_KEY_IDENTS_REFERENCES_ON_INVALID", + input: `CREATE TABLE "users" ("id" UUID, FOREIGN KEY ("group_id") REFERENCES "groups" ("id") ON`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_CONSTRAINT_FOREIGN_KEY_IDENTS_REFERENCES_ON_DELETE_INVALID", + input: `CREATE TABLE "users" ("id" UUID, FOREIGN KEY ("group_id") REFERENCES "groups" ("id") ON DELETE`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_CONSTRAINT_FOREIGN_KEY_IDENTS_REFERENCES_ON_DELETE_NO_INVALID", + input: `CREATE TABLE "users" ("id" UUID, FOREIGN KEY ("group_id") REFERENCES "groups" ("id") ON DELETE NO`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_CONSTRAINT_FOREIGN_KEY_IDENTS_REFERENCES_ON_DELETE_NO_INVALID", + input: `CREATE TABLE "users" ("id" UUID, FOREIGN KEY ("group_id") REFERENCES "groups" ("id") ON DELETE NO ACTION`, + wantErr: ddl.ErrUnexpectedToken, + }, { name: "failure,CREATE_TABLE_table_name_column_name_CONSTRAINT_UNIQUE_INVALID", input: `CREATE TABLE "users" ("id" UUID, UNIQUE NOT`, @@ -414,6 +486,16 @@ func TestParser_parseColumn(t *testing.T) { func TestParser_parseExpr(t *testing.T) { t.Parallel() + t.Run("success,isReservedValue", func(t *testing.T) { + t.Parallel() + + p := NewParser(NewLexer(`(null)`)) + p.nextToken() + p.nextToken() + _, err := p.parseExpr() + require.NoError(t, err) + }) + t.Run("failure,invalid", func(t *testing.T) { t.Parallel() diff --git a/pkg/ddl/diff-cockroachdb-spanner.sh b/pkg/ddl/diff-cockroachdb-spanner.sh new file mode 100755 index 0000000..5d6c269 --- /dev/null +++ b/pkg/ddl/diff-cockroachdb-spanner.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash +set -Eeuo pipefail + +# https://github.com/ginokent/cdiff/blob/cbf77fa4186b309c829be3b15fa00b99e563de7c/bin/cdiff#L36 +cdiff() { ( + if command -v diff-so-fancy >/dev/null; then + diff -u "$@" | diff-so-fancy + else + if [ -t 0 ]; then + P=printf C="\033" R=$($P "$C\[31m") + G=$($P "$C\[32m") + B=$($P "$C\[36m") + W=$($P "$C\[1m") + N=$($P "$C\[0m") + fi + diff -u "$@" | sed "s/^\(@@..*@@\)$/${B-}\1${N-}/;s/^\(+.*\)/${G-}\1${N-}/;s/^\(-.*\)/${R-}\1${N-}/;s/^${G-}\(+++ [^ ].*\)/${W-}\1/;s/^${R-}\(--- [^ ].*\)/${W-}\1/;" + fi +); } +export -f cdiff + +diff_envs() { cdiff "$@" | perl -pe "s/(Only in .*: .*)/\033\[1;33m\1\033\[0m/"; } +export -f diff_envs + +cd "$(dirname "$0")" + +diff_envs \ + --recursive \ + --exclude="*_test.go" \ + --ignore-blank-lines \ + --ignore-space-change \ + --ignore-matching-lines="//diff:ignore-line-spanner-cockroach" \ + --ignore-matching-lines="package spanner" \ + --ignore-matching-lines="package cockroachdb" \ + cockroachdb \ + spanner | + less --tabs=4 -RFX diff --git a/pkg/ddl/errors.go b/pkg/ddl/errors.go index 1aeec77..7247f1d 100644 --- a/pkg/ddl/errors.go +++ b/pkg/ddl/errors.go @@ -1,9 +1,12 @@ package ddl -import "errors" +import ( + "errors" +) var ( - ErrUnexpectedToken = errors.New("unexpected token") - ErrNoDifference = errors.New("no difference") - ErrNotSupported = errors.New("not supported") + ErrUnexpectedToken = errors.New("unexpected token") + ErrNoDifference = errors.New("no difference") + ErrNotSupported = errors.New("not supported") + ErrAlterOptionNotSupported = errors.New("alter option not supported") ) diff --git a/pkg/ddl/mysql/ddl_table_alter.go b/pkg/ddl/mysql/ddl_table_alter.go index 64e1500..604bb83 100644 --- a/pkg/ddl/mysql/ddl_table_alter.go +++ b/pkg/ddl/mysql/ddl_table_alter.go @@ -7,6 +7,7 @@ import ( ) // MEMO: https://dev.mysql.com/doc/refman/8.0/ja/alter-table.html +// NOTE: https://dev.mysql.com/doc/refman/8.0/ja/alter-table-examples.html var _ Stmt = (*AlterTableStmt)(nil) @@ -49,18 +50,16 @@ func (s *AlterTableStmt) String() string { case *DropColumn: str += "DROP COLUMN " + a.Name.String() case *AlterColumn: - str += "ALTER COLUMN " + a.Name.String() + " " switch ca := a.Action.(type) { - case *AlterColumnSetDataType: - str += "SET DATA TYPE " + ca.DataType.String() + case *AlterColumnDataType: + str += "MODIFY " + a.Name.String() + " " + ca.DataType.String() + if ca.NotNull { + str += " NOT NULL" + } case *AlterColumnSetDefault: - str += "SET " + ca.Default.String() + str += "ALTER " + a.Name.String() + " " + "SET " + ca.Default.String() case *AlterColumnDropDefault: - str += "DROP DEFAULT" - case *AlterColumnSetNotNull: - str += "SET NOT NULL" - case *AlterColumnDropNotNull: - str += "DROP NOT NULL" + str += "ALTER " + a.Name.String() + " " + "DROP DEFAULT" } case *AddConstraint: str += "ADD " + a.Constraint.String() @@ -160,14 +159,16 @@ type AlterColumnAction interface { GoString() string } -// AlterColumnSetDataType represents ALTER TABLE table_name ALTER COLUMN column_name SET DATA TYPE. -type AlterColumnSetDataType struct { +// AlterColumnDataType represents ALTER TABLE table_name MODIFY column_name data_type NOT NULL. +// NOTE: https://dev.mysql.com/doc/refman/8.0/ja/alter-table-examples.html +type AlterColumnDataType struct { DataType *DataType + NotNull bool } -func (*AlterColumnSetDataType) isAlterColumnAction() {} +func (*AlterColumnDataType) isAlterColumnAction() {} -func (s *AlterColumnSetDataType) GoString() string { return internal.GoString(*s) } +func (s *AlterColumnDataType) GoString() string { return internal.GoString(*s) } // AlterColumnSetDefault represents ALTER TABLE table_name ALTER COLUMN column_name SET DEFAULT. type AlterColumnSetDefault struct { @@ -185,20 +186,6 @@ func (*AlterColumnDropDefault) isAlterColumnAction() {} func (s *AlterColumnDropDefault) GoString() string { return internal.GoString(*s) } -// AlterColumnSetNotNull represents ALTER TABLE table_name ALTER COLUMN column_name SET NOT NULL. -type AlterColumnSetNotNull struct{} - -func (*AlterColumnSetNotNull) isAlterColumnAction() {} - -func (s *AlterColumnSetNotNull) GoString() string { return internal.GoString(*s) } - -// AlterColumnDropNotNull represents ALTER TABLE table_name ALTER COLUMN column_name DROP NOT NULL. -type AlterColumnDropNotNull struct{} - -func (*AlterColumnDropNotNull) isAlterColumnAction() {} - -func (s *AlterColumnDropNotNull) GoString() string { return internal.GoString(*s) } - // AddConstraint represents ALTER TABLE table_name ADD CONSTRAINT. type AddConstraint struct { Constraint Constraint diff --git a/pkg/ddl/mysql/ddl_table_alter_test.go b/pkg/ddl/mysql/ddl_table_alter_test.go index cdb4744..d200cf1 100644 --- a/pkg/ddl/mysql/ddl_table_alter_test.go +++ b/pkg/ddl/mysql/ddl_table_alter_test.go @@ -25,11 +25,9 @@ func Test_isAlterTableAction(t *testing.T) { func Test_isAlterColumnAction(t *testing.T) { t.Parallel() - (&AlterColumnSetDataType{}).isAlterColumnAction() + (&AlterColumnDataType{}).isAlterColumnAction() (&AlterColumnSetDefault{}).isAlterColumnAction() (&AlterColumnDropDefault{}).isAlterColumnAction() - (&AlterColumnSetNotNull{}).isAlterColumnAction() - (&AlterColumnDropNotNull{}).isAlterColumnAction() } func TestAlterTableStmt_String(t *testing.T) { @@ -134,11 +132,11 @@ func TestAlterTableStmt_String(t *testing.T) { Name: &ObjectName{Name: &Ident{Name: "users", QuotationMark: `"`, Raw: `"users"`}}, Action: &AlterColumn{ Name: &Ident{Name: "age", QuotationMark: `"`, Raw: `"age"`}, - Action: &AlterColumnSetDataType{DataType: &DataType{Name: "INTEGER"}}, + Action: &AlterColumnDataType{DataType: &DataType{Name: "INTEGER"}}, }, } - expected := `ALTER TABLE "users" ALTER COLUMN "age" SET DATA TYPE INTEGER;` + "\n" + expected := `ALTER TABLE "users" MODIFY "age" INTEGER;` + "\n" actual := stmt.String() if !assert.Equal(t, expected, actual) { @@ -158,7 +156,7 @@ func TestAlterTableStmt_String(t *testing.T) { }, } - expected := `ALTER TABLE "users" ALTER COLUMN "age" SET DEFAULT 0;` + "\n" + expected := `ALTER TABLE "users" ALTER "age" SET DEFAULT 0;` + "\n" actual := stmt.String() if !assert.Equal(t, expected, actual) { @@ -178,47 +176,7 @@ func TestAlterTableStmt_String(t *testing.T) { }, } - expected := `ALTER TABLE "users" ALTER COLUMN "age" DROP DEFAULT;` + "\n" - actual := stmt.String() - - if !assert.Equal(t, expected, actual) { - assert.Equal(t, fmt.Sprintf("%#v", expected), fmt.Sprintf("%#v", actual)) - } - t.Logf("✅: %s: stmt: %#v", t.Name(), stmt) - }) - - t.Run("success,AlterColumnSetNotNull", func(t *testing.T) { - t.Parallel() - - stmt := &AlterTableStmt{ - Name: &ObjectName{Name: &Ident{Name: "users", QuotationMark: `"`, Raw: `"users"`}}, - Action: &AlterColumn{ - Name: &Ident{Name: "age", QuotationMark: `"`, Raw: `"age"`}, - Action: &AlterColumnSetNotNull{}, - }, - } - - expected := `ALTER TABLE "users" ALTER COLUMN "age" SET NOT NULL;` + "\n" - actual := stmt.String() - - if !assert.Equal(t, expected, actual) { - assert.Equal(t, fmt.Sprintf("%#v", expected), fmt.Sprintf("%#v", actual)) - } - t.Logf("✅: %s: stmt: %#v", t.Name(), stmt) - }) - - t.Run("success,AlterColumnDropNotNull", func(t *testing.T) { - t.Parallel() - - stmt := &AlterTableStmt{ - Name: &ObjectName{Name: &Ident{Name: "users", QuotationMark: `"`, Raw: `"users"`}}, - Action: &AlterColumn{ - Name: &Ident{Name: "age", QuotationMark: `"`, Raw: `"age"`}, - Action: &AlterColumnDropNotNull{}, - }, - } - - expected := `ALTER TABLE "users" ALTER COLUMN "age" DROP NOT NULL;` + "\n" + expected := `ALTER TABLE "users" ALTER "age" DROP DEFAULT;` + "\n" actual := stmt.String() if !assert.Equal(t, expected, actual) { diff --git a/pkg/ddl/mysql/diff.go b/pkg/ddl/mysql/diff.go index 8a8e35f..5df42d8 100644 --- a/pkg/ddl/mysql/diff.go +++ b/pkg/ddl/mysql/diff.go @@ -6,6 +6,8 @@ import ( errorz "github.com/kunitsucom/util.go/errors" "github.com/kunitsucom/util.go/exp/diff/simplediff" + apperr "github.com/kunitsucom/ddlctl/pkg/apperr" + "github.com/kunitsucom/ddlctl/pkg/ddl" ) @@ -29,7 +31,7 @@ func Diff(before, after *DDL) (*DDL, error) { Name: s.Name, }) default: - return nil, errorz.Errorf("%s: %T: %w", s.GetNameForDiff(), s, ddl.ErrNotSupported) + return nil, apperr.Errorf("%s: %T: %w", s.GetNameForDiff(), s, ddl.ErrNotSupported) } } return result, nil @@ -49,7 +51,7 @@ func Diff(before, after *DDL) (*DDL, error) { Name: beforeStmt.Name, }) default: - return nil, errorz.Errorf("%s: %T: %w", beforeStmt.GetNameForDiff(), beforeStmt, ddl.ErrNotSupported) + return nil, apperr.Errorf("%s: %T: %w", beforeStmt.GetNameForDiff(), beforeStmt, ddl.ErrNotSupported) } } @@ -61,7 +63,7 @@ func Diff(before, after *DDL) (*DDL, error) { case *CreateIndexStmt: result.Stmts = append(result.Stmts, afterStmt) default: - return nil, errorz.Errorf("%s: %T: %w", afterStmt.GetNameForDiff(), afterStmt, ddl.ErrNotSupported) + return nil, apperr.Errorf("%s: %T: %w", afterStmt.GetNameForDiff(), afterStmt, ddl.ErrNotSupported) } } diff --git a/pkg/ddl/mysql/diff_create_table.go b/pkg/ddl/mysql/diff_create_table.go index 8e9c2ee..1525734 100644 --- a/pkg/ddl/mysql/diff_create_table.go +++ b/pkg/ddl/mysql/diff_create_table.go @@ -3,9 +3,10 @@ package mysql import ( "reflect" - errorz "github.com/kunitsucom/util.go/errors" "github.com/kunitsucom/util.go/exp/diff/simplediff" + apperr "github.com/kunitsucom/ddlctl/pkg/apperr" + "github.com/kunitsucom/ddlctl/pkg/ddl" ) @@ -53,7 +54,7 @@ func DiffCreateTable(before, after *CreateTableStmt, opts ...DiffCreateTableOpti }) return result, nil case (before == nil && after == nil) || reflect.DeepEqual(before, after) || before.String() == after.String(): - return nil, errorz.Errorf("before: %s, after: %s: %w", before.GetNameForDiff(), after.GetNameForDiff(), ddl.ErrNoDifference) + return nil, apperr.Errorf("before: %s, after: %s: %w", before.GetNameForDiff(), after.GetNameForDiff(), ddl.ErrNoDifference) } if before.Name.StringForDiff() != after.Name.StringForDiff() { @@ -181,7 +182,7 @@ func DiffCreateTable(before, after *CreateTableStmt, opts ...DiffCreateTableOpti } if len(result.Stmts) == 0 { - return nil, errorz.Errorf("before: %s, after: %s: %w", before.GetNameForDiff(), after.GetNameForDiff(), ddl.ErrNoDifference) + return nil, apperr.Errorf("before: %s, after: %s: %w", before.GetNameForDiff(), after.GetNameForDiff(), ddl.ErrNoDifference) } return result, nil @@ -203,16 +204,33 @@ func (config *DiffCreateTableConfig) diffCreateTableColumn(ddls *DDL, before, af continue } - if beforeColumn.DataType.StringForDiff() != afterColumn.DataType.StringForDiff() { - // ALTER TABLE table_name ALTER COLUMN column_name SET DATA TYPE data_type; + if beforeColumn.DataType.StringForDiff() != afterColumn.DataType.StringForDiff() || + beforeColumn.NotNull && !afterColumn.NotNull || + !beforeColumn.NotNull && afterColumn.NotNull { + // ALTER TABLE table_name MODIFY column_name data_type NOT NULL; ddls.Stmts = append(ddls.Stmts, &AlterTableStmt{ Comment: simplediff.Diff(beforeColumn.String(), afterColumn.String()).String(), Name: after.Name, Action: &AlterColumn{ - Name: afterColumn.Name, - Action: &AlterColumnSetDataType{DataType: afterColumn.DataType}, + Name: afterColumn.Name, + Action: &AlterColumnDataType{ + DataType: afterColumn.DataType, + NotNull: afterColumn.NotNull, + }, }, }) + + if afterColumn.Default != nil { + // ALTER TABLE table_name ALTER COLUMN column_name SET DEFAULT default_value; + ddls.Stmts = append(ddls.Stmts, &AlterTableStmt{ + Comment: simplediff.Diff(beforeColumn.String(), afterColumn.String()).String(), + Name: after.Name, + Action: &AlterColumn{ + Name: afterColumn.Name, + Action: &AlterColumnSetDefault{Default: afterColumn.Default}, + }, + }) + } } switch { @@ -237,29 +255,6 @@ func (config *DiffCreateTableConfig) diffCreateTableColumn(ddls *DDL, before, af }, }) } - - switch { - case beforeColumn.NotNull && !afterColumn.NotNull: - // ALTER TABLE table_name ALTER COLUMN column_name DROP NOT NULL; - ddls.Stmts = append(ddls.Stmts, &AlterTableStmt{ - Comment: simplediff.Diff(beforeColumn.String(), afterColumn.String()).String(), - Name: after.Name, - Action: &AlterColumn{ - Name: afterColumn.Name, - Action: &AlterColumnDropNotNull{}, - }, - }) - case !beforeColumn.NotNull && afterColumn.NotNull: - // ALTER TABLE table_name ALTER COLUMN column_name SET NOT NULL; - ddls.Stmts = append(ddls.Stmts, &AlterTableStmt{ - Comment: simplediff.Diff(beforeColumn.String(), afterColumn.String()).String(), - Name: after.Name, - Action: &AlterColumn{ - Name: afterColumn.Name, - Action: &AlterColumnSetNotNull{}, - }, - }) - } } for _, afterColumn := range onlyLeftColumn(after.Columns, before.Columns) { diff --git a/pkg/ddl/mysql/diff_create_table_test.go b/pkg/ddl/mysql/diff_create_table_test.go index 9962a76..7c336b2 100644 --- a/pkg/ddl/mysql/diff_create_table_test.go +++ b/pkg/ddl/mysql/diff_create_table_test.go @@ -123,10 +123,13 @@ ALTER TABLE "users" DROP COLUMN "age"; expectedStr := `-- -"name" VARCHAR(255) NOT NULL -- +"name" TEXT NOT NULL -ALTER TABLE "users" ALTER COLUMN "name" SET DATA TYPE TEXT; +ALTER TABLE "users" MODIFY "name" TEXT NOT NULL; -- -"age" INT DEFAULT 0 -- +"age" BIGINT DEFAULT 0 -ALTER TABLE "users" ALTER COLUMN "age" SET DATA TYPE BIGINT; +ALTER TABLE "users" MODIFY "age" BIGINT; +-- -"age" INT DEFAULT 0 +-- +"age" BIGINT DEFAULT 0 +ALTER TABLE "users" ALTER "age" SET DEFAULT 0; -- - -- +UNIQUE KEY users_unique_name (name) CREATE UNIQUE INDEX users_unique_name ON "users" ("name"); @@ -149,7 +152,7 @@ CREATE UNIQUE INDEX users_unique_name ON "users" ("name"); expectedStr := `-- -"age" INT DEFAULT 0 -- +"age" INT -ALTER TABLE "users" ALTER COLUMN "age" DROP DEFAULT; +ALTER TABLE "users" ALTER "age" DROP DEFAULT; ` actual, err := DiffCreateTable( @@ -174,7 +177,7 @@ ALTER TABLE "users" ALTER COLUMN "age" DROP DEFAULT; expectedStr := `-- -"age" INT -- +"age" INT DEFAULT 0 -ALTER TABLE "users" ALTER COLUMN "age" SET DEFAULT 0; +ALTER TABLE "users" ALTER "age" SET DEFAULT 0; -- -CONSTRAINT users_age_check CHECK ("age" >= 0) -- + ALTER TABLE "users" DROP CONSTRAINT users_age_check; @@ -253,7 +256,10 @@ ALTER TABLE "public.app_users" ADD CONSTRAINT app_users_age_check CHECK ("age" > expectedStr := `-- -"age" INT DEFAULT 0 -- +"age" INTEGER NOT NULL DEFAULT 0 -ALTER TABLE "users" ALTER COLUMN "age" SET NOT NULL; +ALTER TABLE "users" MODIFY "age" INTEGER NOT NULL; +-- -"age" INT DEFAULT 0 +-- +"age" INTEGER NOT NULL DEFAULT 0 +ALTER TABLE "users" ALTER "age" SET DEFAULT 0; ` actual, err := DiffCreateTable( @@ -280,7 +286,10 @@ ALTER TABLE "users" ALTER COLUMN "age" SET NOT NULL; expectedStr := `-- -"age" INT NOT NULL DEFAULT 0 -- +"age" INT DEFAULT 0 -ALTER TABLE "users" ALTER COLUMN "age" DROP NOT NULL; +ALTER TABLE "users" MODIFY "age" INT; +-- -"age" INT NOT NULL DEFAULT 0 +-- +"age" INT DEFAULT 0 +ALTER TABLE "users" ALTER "age" SET DEFAULT 0; ` actual, err := DiffCreateTable( @@ -393,7 +402,7 @@ CREATE UNIQUE INDEX users_unique_name ON "users" ("id", name); expectedStr := `-- -"age" INT NOT NULL DEFAULT 0 -- +"age" INT NOT NULL DEFAULT ((0 + 3) - 1 * 4 / 2) -ALTER TABLE "users" ALTER COLUMN "age" SET DEFAULT ((0 + 3) - 1 * 4 / 2); +ALTER TABLE "users" ALTER "age" SET DEFAULT ((0 + 3) - 1 * 4 / 2); ` actual, err := DiffCreateTable( diff --git a/pkg/ddl/mysql/diff_test.go b/pkg/ddl/mysql/diff_test.go index 564f41f..9b75fd5 100644 --- a/pkg/ddl/mysql/diff_test.go +++ b/pkg/ddl/mysql/diff_test.go @@ -388,7 +388,7 @@ CREATE UNIQUE INDEX IF NOT EXISTS public.users_idx_by_username ON public.users ( expected := `-- -username VARCHAR(10) NOT NULL -- +username VARCHAR(11) NOT NULL -ALTER TABLE public.users ALTER COLUMN username SET DATA TYPE VARCHAR(11); +ALTER TABLE public.users MODIFY username VARCHAR(11) NOT NULL; ` actual, err := Diff(before, after) require.NoError(t, err) diff --git a/pkg/ddl/mysql/parser.go b/pkg/ddl/mysql/parser.go index a4c5e3e..d75073e 100644 --- a/pkg/ddl/mysql/parser.go +++ b/pkg/ddl/mysql/parser.go @@ -7,10 +7,11 @@ import ( "runtime" "strings" - errorz "github.com/kunitsucom/util.go/errors" filepathz "github.com/kunitsucom/util.go/path/filepath" stringz "github.com/kunitsucom/util.go/strings" + apperr "github.com/kunitsucom/ddlctl/pkg/apperr" + "github.com/kunitsucom/ddlctl/pkg/ddl" "github.com/kunitsucom/ddlctl/pkg/ddl/logs" ) @@ -82,7 +83,7 @@ LabelDDL: case TOKEN_CREATE: stmt, err := p.parseCreateStatement() if err != nil { - return nil, errorz.Errorf("parseCreateStatement: %w", err) + return nil, apperr.Errorf("parseCreateStatement: %w", err) } d.Stmts = append(d.Stmts, stmt) case TOKEN_CLOSE_PAREN: @@ -92,7 +93,7 @@ LabelDDL: case TOKEN_EOF: break LabelDDL default: - return nil, errorz.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) } p.nextToken() @@ -109,7 +110,7 @@ func (p *Parser) parseCreateStatement() (Stmt, error) { //nolint:ireturn case TOKEN_INDEX, TOKEN_UNIQUE: return p.parseCreateIndexStmt() default: - return nil, errorz.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) } } @@ -122,11 +123,11 @@ func (p *Parser) parseCreateTableStmt() (*CreateTableStmt, error) { if p.isPeekToken(TOKEN_IF) { p.nextToken() // current = IF if err := p.checkPeekToken(TOKEN_NOT); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = NOT if err := p.checkPeekToken(TOKEN_EXISTS); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = EXISTS createTableStmt.IfNotExists = true @@ -134,7 +135,7 @@ func (p *Parser) parseCreateTableStmt() (*CreateTableStmt, error) { p.nextToken() // current = table_name if err := p.checkCurrentToken(TOKEN_IDENT); err != nil { - return nil, errorz.Errorf("checkCurrentToken: %w", err) + return nil, apperr.Errorf("checkCurrentToken: %w", err) } createTableStmt.Name = NewObjectName(p.currentToken.Literal.Str) @@ -143,7 +144,7 @@ func (p *Parser) parseCreateTableStmt() (*CreateTableStmt, error) { p.nextToken() // current = ( if err := p.checkCurrentToken(TOKEN_OPEN_PAREN); err != nil { - return nil, errorz.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) + return nil, apperr.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) } p.nextToken() // current = column_name @@ -154,7 +155,7 @@ LabelColumns: case p.isCurrentToken(TOKEN_IDENT): column, constraints, err := p.parseColumn(createTableStmt.Name.Name) if err != nil { - return nil, errorz.Errorf(errFmtPrefix+"parseColumn: %w", err) + return nil, apperr.Errorf(errFmtPrefix+"parseColumn: %w", err) } createTableStmt.Columns = append(createTableStmt.Columns, column) if len(constraints) > 0 { @@ -165,7 +166,7 @@ LabelColumns: case isConstraint(p.currentToken.Type): constraint, err := p.parseTableConstraint(createTableStmt.Name.Name) if err != nil { - return nil, errorz.Errorf(errFmtPrefix+"parseConstraint: %w", err) + return nil, apperr.Errorf(errFmtPrefix+"parseConstraint: %w", err) } createTableStmt.Constraints = createTableStmt.Constraints.Append(constraint) case p.isCurrentToken(TOKEN_COMMA): @@ -175,7 +176,7 @@ LabelColumns: p.nextToken() break LabelColumns default: - return nil, errorz.Errorf(errFmtPrefix+"currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf(errFmtPrefix+"currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) } } @@ -187,43 +188,43 @@ LabelTableOptions: opt.Name = "ENGINE" p.nextToken() // current = `=` if err := p.checkCurrentToken(TOKEN_EQUAL); err != nil { - return nil, errorz.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) + return nil, apperr.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) } p.nextToken() // current = TOKEN_IDENT if err := p.checkCurrentToken(TOKEN_IDENT); err != nil { - return nil, errorz.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) + return nil, apperr.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) } opt.Value = NewRawIdent(p.currentToken.Literal.Str) case TOKEN_DEFAULT: if err := p.checkPeekToken(TOKEN_CHARSET); err != nil { - return nil, errorz.Errorf(errFmtPrefix+"checkPeekToken: %w", err) + return nil, apperr.Errorf(errFmtPrefix+"checkPeekToken: %w", err) } p.nextToken() // current = CHARSET opt.Name = "DEFAULT CHARSET" p.nextToken() // current = `=` if err := p.checkCurrentToken(TOKEN_EQUAL); err != nil { - return nil, errorz.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) + return nil, apperr.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) } p.nextToken() // current = TOKEN_IDENT if err := p.checkCurrentToken(TOKEN_IDENT); err != nil { - return nil, errorz.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) + return nil, apperr.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) } opt.Value = NewRawIdent(p.currentToken.Literal.Str) case TOKEN_COLLATE: opt.Name = "COLLATE" p.nextToken() // current = `=` if err := p.checkCurrentToken(TOKEN_EQUAL); err != nil { - return nil, errorz.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) + return nil, apperr.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) } p.nextToken() // current = TOKEN_IDENT if err := p.checkCurrentToken(TOKEN_IDENT); err != nil { - return nil, errorz.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) + return nil, apperr.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) } opt.Value = NewRawIdent(p.currentToken.Literal.Str) case TOKEN_SEMICOLON, TOKEN_EOF: break LabelTableOptions default: - return nil, errorz.Errorf(errFmtPrefix+"peekToken=%#v: %w", p.peekToken, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf(errFmtPrefix+"peekToken=%#v: %w", p.peekToken, ddl.ErrUnexpectedToken) } createTableStmt.Options = append(createTableStmt.Options, opt) p.nextToken() @@ -244,11 +245,11 @@ func (p *Parser) parseCreateIndexStmt() (*CreateIndexStmt, error) { if p.isPeekToken(TOKEN_IF) { p.nextToken() // current = IF if err := p.checkPeekToken(TOKEN_NOT); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = NOT if err := p.checkPeekToken(TOKEN_EXISTS); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = EXISTS createIndexStmt.IfNotExists = true @@ -256,7 +257,7 @@ func (p *Parser) parseCreateIndexStmt() (*CreateIndexStmt, error) { p.nextToken() // current = index_name if err := p.checkCurrentToken(TOKEN_IDENT); err != nil { - return nil, errorz.Errorf("checkCurrentToken: %w", err) + return nil, apperr.Errorf("checkCurrentToken: %w", err) } createIndexStmt.Name = NewObjectName(p.currentToken.Literal.Str) @@ -265,13 +266,13 @@ func (p *Parser) parseCreateIndexStmt() (*CreateIndexStmt, error) { p.nextToken() // current = ON if err := p.checkCurrentToken(TOKEN_ON); err != nil { - return nil, errorz.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) + return nil, apperr.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) } p.nextToken() // current = table_name if err := p.checkCurrentToken(TOKEN_IDENT); err != nil { - return nil, errorz.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) + return nil, apperr.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) } createIndexStmt.TableName = NewObjectName(p.currentToken.Literal.Str) @@ -285,12 +286,12 @@ func (p *Parser) parseCreateIndexStmt() (*CreateIndexStmt, error) { } if err := p.checkCurrentToken(TOKEN_OPEN_PAREN); err != nil { - return nil, errorz.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) + return nil, apperr.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) } idents, err := p.parseColumnIdents() if err != nil { - return nil, errorz.Errorf(errFmtPrefix+"parseColumnIdents: %w", err) + return nil, apperr.Errorf(errFmtPrefix+"parseColumnIdents: %w", err) } createIndexStmt.Columns = idents @@ -304,7 +305,7 @@ func (p *Parser) parseColumn(tableName *Ident) (*Column, []Constraint, error) { constraints := make(Constraints, 0) if err := p.checkCurrentToken(TOKEN_IDENT); err != nil { - return nil, nil, errorz.Errorf("checkCurrentToken: %w", err) + return nil, nil, apperr.Errorf("checkCurrentToken: %w", err) } column.Name = NewRawIdent(p.currentToken.Literal.Str) @@ -316,7 +317,7 @@ func (p *Parser) parseColumn(tableName *Ident) (*Column, []Constraint, error) { case isDataType(p.currentToken.Type): dataType, err := p.parseDataType() if err != nil { - return nil, nil, errorz.Errorf(errFmtPrefix+"parseDataType: %w", err) + return nil, nil, apperr.Errorf(errFmtPrefix+"parseDataType: %w", err) } column.DataType = dataType @@ -326,7 +327,7 @@ func (p *Parser) parseColumn(tableName *Ident) (*Column, []Constraint, error) { switch p.currentToken.Type { //nolint:exhaustive case TOKEN_NOT: if err := p.checkPeekToken(TOKEN_NULL); err != nil { - return nil, nil, errorz.Errorf(errFmtPrefix+"checkPeekToken: %w", err) + return nil, nil, apperr.Errorf(errFmtPrefix+"checkPeekToken: %w", err) } p.nextToken() // current = NULL column.NotNull = true @@ -336,7 +337,7 @@ func (p *Parser) parseColumn(tableName *Ident) (*Column, []Constraint, error) { p.nextToken() // current = default_value def, err := p.parseColumnDefault() if err != nil { - return nil, nil, errorz.Errorf(errFmtPrefix+"parseColumnDefault: %w", err) + return nil, nil, apperr.Errorf(errFmtPrefix+"parseColumnDefault: %w", err) } column.Default = def continue @@ -349,7 +350,7 @@ func (p *Parser) parseColumn(tableName *Ident) (*Column, []Constraint, error) { cs, err := p.parseColumnConstraints(tableName, column) if err != nil { - return nil, nil, errorz.Errorf(errFmtPrefix+"parseColumnConstraints: %w", err) + return nil, nil, apperr.Errorf(errFmtPrefix+"parseColumnConstraints: %w", err) } if len(cs) > 0 { for _, c := range cs { @@ -357,7 +358,7 @@ func (p *Parser) parseColumn(tableName *Ident) (*Column, []Constraint, error) { } } default: - return nil, nil, errorz.Errorf(errFmtPrefix+"currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + return nil, nil, apperr.Errorf(errFmtPrefix+"currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) } return column, constraints, nil @@ -375,7 +376,7 @@ LabelDefault: case TOKEN_OPEN_PAREN: ids, err := p.parseExpr() if err != nil { - return nil, errorz.Errorf("parseExpr: %w", err) + return nil, apperr.Errorf("parseExpr: %w", err) } def.Value = def.Value.Append(ids...) continue @@ -402,7 +403,7 @@ LabelDefault: if isConstraint(p.currentToken.Type) { break LabelDefault } - return nil, errorz.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) } p.nextToken() @@ -411,11 +412,12 @@ LabelDefault: return def, nil } +//nolint:funlen,cyclop func (p *Parser) parseExpr() ([]*Ident, error) { idents := make([]*Ident, 0) if err := p.checkCurrentToken(TOKEN_OPEN_PAREN); err != nil { - return nil, errorz.Errorf("checkCurrentToken: %w", err) + return nil, apperr.Errorf("checkCurrentToken: %w", err) } idents = append(idents, NewRawIdent(p.currentToken.Literal.Str)) p.nextToken() // current = IDENT @@ -426,7 +428,7 @@ LabelExpr: case TOKEN_OPEN_PAREN: ids, err := p.parseExpr() if err != nil { - return nil, errorz.Errorf("parseExpr: %w", err) + return nil, apperr.Errorf("parseExpr: %w", err) } idents = append(idents, ids...) continue @@ -443,9 +445,13 @@ LabelExpr: } idents = append(idents, NewRawIdent(value)) case TOKEN_EOF: - return nil, errorz.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) default: - idents = append(idents, NewRawIdent(p.currentToken.Literal.Str)) + if isReservedValue(p.currentToken.Type) { + idents = append(idents, NewRawIdent(p.currentToken.Type.String())) + } else { + idents = append(idents, NewRawIdent(p.currentToken.Literal.Str)) + } } p.nextToken() @@ -463,7 +469,7 @@ LabelConstraints: switch p.currentToken.Type { //nolint:exhaustive case TOKEN_PRIMARY: if err := p.checkPeekToken(TOKEN_KEY); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = KEY constraints = constraints.Append(&PrimaryKeyConstraint{ @@ -472,7 +478,7 @@ LabelConstraints: }) case TOKEN_REFERENCES: if err := p.checkPeekToken(TOKEN_IDENT); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = table_name constraint := &ForeignKeyConstraint{ @@ -483,7 +489,7 @@ LabelConstraints: p.nextToken() // current = ( idents, err := p.parseColumnIdents() if err != nil { - return nil, errorz.Errorf("parseColumnIdents: %w", err) + return nil, apperr.Errorf("parseColumnIdents: %w", err) } constraint.RefColumns = idents constraints = constraints.Append(constraint) @@ -495,7 +501,7 @@ LabelConstraints: }) case TOKEN_CHECK: if err := p.checkPeekToken(TOKEN_OPEN_PAREN); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = ( constraint := &CheckConstraint{ @@ -503,14 +509,14 @@ LabelConstraints: } idents, err := p.parseExpr() if err != nil { - return nil, errorz.Errorf("parseExpr: %w", err) + return nil, apperr.Errorf("parseExpr: %w", err) } constraint.Expr = constraint.Expr.Append(idents...) constraints = constraints.Append(constraint) case TOKEN_IDENT, TOKEN_COMMA, TOKEN_CLOSE_PAREN: break LabelConstraints default: - return nil, errorz.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) } p.nextToken() @@ -525,7 +531,7 @@ func (p *Parser) parseTableConstraint(tableName *Ident) (Constraint, error) { // if p.isCurrentToken(TOKEN_CONSTRAINT) { p.nextToken() // current = constraint_name if p.currentToken.Type != TOKEN_IDENT { - return nil, errorz.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) } constraintName = NewRawIdent(p.currentToken.Literal.Str) p.nextToken() // current = PRIMARY or CHECK @@ -534,16 +540,16 @@ func (p *Parser) parseTableConstraint(tableName *Ident) (Constraint, error) { // switch p.currentToken.Type { //nolint:exhaustive case TOKEN_PRIMARY: if err := p.checkPeekToken(TOKEN_KEY); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = KEY if err := p.checkPeekToken(TOKEN_OPEN_PAREN); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = ( idents, err := p.parseColumnIdents() if err != nil { - return nil, errorz.Errorf("parseColumnIdents: %w", err) + return nil, apperr.Errorf("parseColumnIdents: %w", err) } return &PrimaryKeyConstraint{ Name: NewRawIdent("PRIMARY KEY"), @@ -551,30 +557,30 @@ func (p *Parser) parseTableConstraint(tableName *Ident) (Constraint, error) { // }, nil case TOKEN_FOREIGN: if err := p.checkPeekToken(TOKEN_KEY); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = KEY if err := p.checkPeekToken(TOKEN_OPEN_PAREN); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = ( idents, err := p.parseColumnIdents() if err != nil { - return nil, errorz.Errorf("parseColumnIdents: %w", err) + return nil, apperr.Errorf("parseColumnIdents: %w", err) } if err := p.checkCurrentToken(TOKEN_REFERENCES); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = ref_table_name if err := p.checkCurrentToken(TOKEN_IDENT); err != nil { - return nil, errorz.Errorf("checkCurrentToken: %w", err) + return nil, apperr.Errorf("checkCurrentToken: %w", err) } refName := NewRawIdent(p.currentToken.Literal.Str) p.nextToken() // current = ( identsRef, err := p.parseColumnIdents() if err != nil { - return nil, errorz.Errorf("parseColumnIdents: %w", err) + return nil, apperr.Errorf("parseColumnIdents: %w", err) } if constraintName == nil { name := tableName.StringForDiff() @@ -598,20 +604,20 @@ func (p *Parser) parseTableConstraint(tableName *Ident) (Constraint, error) { // p.nextToken() // current = KEY or INDEX } if err := p.checkCurrentToken(TOKEN_INDEX, TOKEN_KEY); err != nil { - return nil, errorz.Errorf("checkCurrentToken: %w", err) + return nil, apperr.Errorf("checkCurrentToken: %w", err) } p.nextToken() // current = index_name if err := p.checkCurrentToken(TOKEN_IDENT); err != nil { - return nil, errorz.Errorf("checkCurrentToken: %w", err) + return nil, apperr.Errorf("checkCurrentToken: %w", err) } constraintName := NewRawIdent(p.currentToken.Literal.Str) if err := p.checkPeekToken(TOKEN_OPEN_PAREN); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = ( idents, err := p.parseColumnIdents() if err != nil { - return nil, errorz.Errorf("parseColumnIdents: %w", err) + return nil, apperr.Errorf("parseColumnIdents: %w", err) } c.Name = constraintName c.Columns = idents @@ -619,12 +625,12 @@ func (p *Parser) parseTableConstraint(tableName *Ident) (Constraint, error) { // case TOKEN_CHECK: constraint := &CheckConstraint{} if err := p.checkPeekToken(TOKEN_OPEN_PAREN); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = ( idents, err := p.parseExpr() if err != nil { - return nil, errorz.Errorf("parseExpr: %w", err) + return nil, apperr.Errorf("parseExpr: %w", err) } if constraintName == nil { // TODO: handle CONSTRAINT name @@ -634,7 +640,7 @@ func (p *Parser) parseTableConstraint(tableName *Ident) (Constraint, error) { // constraint.Expr = constraint.Expr.Append(idents...) return constraint, nil default: - return nil, errorz.Errorf("currentToken=%s: %w", p.currentToken.Type, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf("currentToken=%s: %w", p.currentToken.Type, ddl.ErrUnexpectedToken) } } @@ -662,7 +668,7 @@ func (p *Parser) parseDataType() (*DataType, error) { case TOKEN_DOUBLE: dataType.Name = p.currentToken.Literal.String() if err := p.checkPeekToken(TOKEN_PRECISION); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = PRECISION dataType.Name += " " + p.currentToken.Literal.String() @@ -670,7 +676,7 @@ func (p *Parser) parseDataType() (*DataType, error) { case TOKEN_CHARACTER: dataType.Name = p.currentToken.Literal.String() if err := p.checkPeekToken(TOKEN_VARYING); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = VARYING dataType.Name += " " + p.currentToken.Literal.String() @@ -687,7 +693,7 @@ func (p *Parser) parseDataType() (*DataType, error) { p.nextToken() // current = ( idents, err := p.parseIdents() if err != nil { - return nil, errorz.Errorf("parseIdents: %w", err) + return nil, apperr.Errorf("parseIdents: %w", err) } dataType.Expr = dataType.Expr.Append(idents...) } @@ -722,7 +728,7 @@ LabelIdents: p.nextToken() break LabelIdents default: - return nil, errorz.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) } p.nextToken() } @@ -743,7 +749,7 @@ LabelIdents: case TOKEN_CLOSE_PAREN: break LabelIdents case TOKEN_EOF, TOKEN_ILLEGAL: - return nil, errorz.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) default: idents = append(idents, NewRawIdent(p.currentToken.Literal.Str)) } @@ -825,7 +831,7 @@ func (p *Parser) checkCurrentToken(expectedTypes ...TokenType) error { return nil } } - return errorz.Errorf("currentToken: expected=%s, but got=%#v: %w", stringz.JoinStringers(",", expectedTypes...), p.currentToken, ddl.ErrUnexpectedToken) + return apperr.Errorf("currentToken: expected=%s, but got=%#v: %w", stringz.JoinStringers(",", expectedTypes...), p.currentToken, ddl.ErrUnexpectedToken) } func (p *Parser) isPeekToken(expectedTypes ...TokenType) bool { @@ -843,5 +849,5 @@ func (p *Parser) checkPeekToken(expectedTypes ...TokenType) error { return nil } } - return errorz.Errorf("peekToken: expected=%s, but got=%#v: %w", stringz.JoinStringers(",", expectedTypes...), p.peekToken, ddl.ErrUnexpectedToken) + return apperr.Errorf("peekToken: expected=%s, but got=%#v: %w", stringz.JoinStringers(",", expectedTypes...), p.peekToken, ddl.ErrUnexpectedToken) } diff --git a/pkg/ddl/postgres/ddl.go b/pkg/ddl/postgres/ddl.go index 93e5ccd..9b92e59 100644 --- a/pkg/ddl/postgres/ddl.go +++ b/pkg/ddl/postgres/ddl.go @@ -7,8 +7,8 @@ import ( ) const ( - Dialect = "postgres" - DriverName = "postgres" + Dialect = "postgres" //diff:ignore-line-postgres-cockroach + DriverName = "postgres" //diff:ignore-line-postgres-cockroach Indent = " " CommentPrefix = "-- " ) diff --git a/pkg/ddl/postgres/ddl_table.go b/pkg/ddl/postgres/ddl_table.go index 809cac7..5794981 100644 --- a/pkg/ddl/postgres/ddl_table.go +++ b/pkg/ddl/postgres/ddl_table.go @@ -73,6 +73,7 @@ type ForeignKeyConstraint struct { Columns []*ColumnIdent Ref *Ident RefColumns []*ColumnIdent + OnAction string } var _ Constraint = (*ForeignKeyConstraint)(nil) @@ -89,6 +90,9 @@ func (c *ForeignKeyConstraint) String() string { str += " (" + stringz.JoinStringers(", ", c.Columns...) + ")" str += " REFERENCES " + c.Ref.String() str += " (" + stringz.JoinStringers(", ", c.RefColumns...) + ")" + if c.OnAction != "" { + str += " " + c.OnAction + } return str } @@ -115,6 +119,9 @@ func (c *ForeignKeyConstraint) StringForDiff() string { str += v.StringForDiff() } str += ")" + if c.OnAction != "" { + str += " " + c.OnAction + } return str } diff --git a/pkg/ddl/postgres/ddl_table_test.go b/pkg/ddl/postgres/ddl_table_test.go index 67d3e2b..1938225 100644 --- a/pkg/ddl/postgres/ddl_table_test.go +++ b/pkg/ddl/postgres/ddl_table_test.go @@ -50,12 +50,17 @@ func TestForeignKeyConstraint(t *testing.T) { Columns: []*ColumnIdent{{Ident: &Ident{Name: "group_id", QuotationMark: `"`, Raw: `"group_id"`}}}, Ref: &Ident{Name: "groups", QuotationMark: `"`, Raw: `"groups"`}, RefColumns: []*ColumnIdent{{Ident: &Ident{Name: "id", QuotationMark: `"`, Raw: `"id"`}}}, + OnAction: "ON DELETE NO ACTION", } - expected := `CONSTRAINT "fk_users_groups" FOREIGN KEY ("group_id") REFERENCES "groups" ("id")` + expected := `CONSTRAINT "fk_users_groups" FOREIGN KEY ("group_id") REFERENCES "groups" ("id") ON DELETE NO ACTION` actual := foreignKeyConstraint.String() require.Equal(t, expected, actual) + expectedForDiff := `CONSTRAINT fk_users_groups FOREIGN KEY (group_id) REFERENCES groups (id) ON DELETE NO ACTION` + actualForDiff := foreignKeyConstraint.StringForDiff() + require.Equal(t, expectedForDiff, actualForDiff) + t.Logf("✅: %s: foreignKeyConstraint: %#v", t.Name(), foreignKeyConstraint) }) } diff --git a/pkg/ddl/postgres/diff.go b/pkg/ddl/postgres/diff.go index d1d22e8..8db9540 100644 --- a/pkg/ddl/postgres/diff.go +++ b/pkg/ddl/postgres/diff.go @@ -6,6 +6,8 @@ import ( errorz "github.com/kunitsucom/util.go/errors" "github.com/kunitsucom/util.go/exp/diff/simplediff" + apperr "github.com/kunitsucom/ddlctl/pkg/apperr" + "github.com/kunitsucom/ddlctl/pkg/ddl" ) @@ -29,7 +31,7 @@ func Diff(before, after *DDL) (*DDL, error) { Name: s.Name, }) default: - return nil, errorz.Errorf("%s: %T: %w", s.GetNameForDiff(), s, ddl.ErrNotSupported) + return nil, apperr.Errorf("%s: %T: %w", s.GetNameForDiff(), s, ddl.ErrNotSupported) } } return result, nil @@ -49,7 +51,7 @@ func Diff(before, after *DDL) (*DDL, error) { Name: beforeStmt.Name, }) default: - return nil, errorz.Errorf("%s: %T: %w", beforeStmt.GetNameForDiff(), beforeStmt, ddl.ErrNotSupported) + return nil, apperr.Errorf("%s: %T: %w", beforeStmt.GetNameForDiff(), beforeStmt, ddl.ErrNotSupported) } } @@ -61,7 +63,7 @@ func Diff(before, after *DDL) (*DDL, error) { case *CreateIndexStmt: result.Stmts = append(result.Stmts, afterStmt) default: - return nil, errorz.Errorf("%s: %T: %w", afterStmt.GetNameForDiff(), afterStmt, ddl.ErrNotSupported) + return nil, apperr.Errorf("%s: %T: %w", afterStmt.GetNameForDiff(), afterStmt, ddl.ErrNotSupported) } } diff --git a/pkg/ddl/postgres/diff_create_table.go b/pkg/ddl/postgres/diff_create_table.go index 5b1568a..7394425 100644 --- a/pkg/ddl/postgres/diff_create_table.go +++ b/pkg/ddl/postgres/diff_create_table.go @@ -3,9 +3,10 @@ package postgres import ( "reflect" - errorz "github.com/kunitsucom/util.go/errors" "github.com/kunitsucom/util.go/exp/diff/simplediff" + apperr "github.com/kunitsucom/ddlctl/pkg/apperr" + "github.com/kunitsucom/ddlctl/pkg/ddl" ) @@ -53,7 +54,7 @@ func DiffCreateTable(before, after *CreateTableStmt, opts ...DiffCreateTableOpti }) return result, nil case (before == nil && after == nil) || reflect.DeepEqual(before, after) || before.String() == after.String(): - return nil, errorz.Errorf("before: %s, after: %s: %w", before.GetNameForDiff(), after.GetNameForDiff(), ddl.ErrNoDifference) + return nil, apperr.Errorf("before: %s, after: %s: %w", before.GetNameForDiff(), after.GetNameForDiff(), ddl.ErrNoDifference) } if before.Name.StringForDiff() != after.Name.StringForDiff() { @@ -130,7 +131,7 @@ func DiffCreateTable(before, after *CreateTableStmt, opts ...DiffCreateTableOpti } if len(result.Stmts) == 0 { - return nil, errorz.Errorf("before: %s, after: %s: %w", before.GetNameForDiff(), after.GetNameForDiff(), ddl.ErrNoDifference) + return nil, apperr.Errorf("before: %s, after: %s: %w", before.GetNameForDiff(), after.GetNameForDiff(), ddl.ErrNoDifference) } return result, nil diff --git a/pkg/ddl/postgres/lexar.go b/pkg/ddl/postgres/lexar.go index 1c59c8a..e020a1d 100644 --- a/pkg/ddl/postgres/lexar.go +++ b/pkg/ddl/postgres/lexar.go @@ -6,11 +6,6 @@ import ( // MEMO: https://www.postgresql.jp/docs/11/datatype.html -const ( - QuotationChar = '"' - QuotationStr = string(QuotationChar) -) - // Token はSQL文のトークンを表す型です。 type Token struct { Type TokenType @@ -62,6 +57,8 @@ const ( TOKEN_DROP TokenType = "DROP" TOKEN_RENAME TokenType = "RENAME" TOKEN_TRUNCATE TokenType = "TRUNCATE" + TOKEN_DELETE TokenType = "DELETE" + TOKEN_UPDATE TokenType = "UPDATE" // OBJECT. TOKEN_TABLE TokenType = "TABLE" @@ -107,6 +104,9 @@ const ( TOKEN_NOT TokenType = "NOT" TOKEN_ASC TokenType = "ASC" TOKEN_DESC TokenType = "DESC" + TOKEN_CASCADE TokenType = "CASCADE" + TOKEN_NO TokenType = "NO" + TOKEN_ACTION TokenType = "ACTION" // CONSTRAINT. TOKEN_CONSTRAINT TokenType = "CONSTRAINT" @@ -154,6 +154,10 @@ func lookupIdent(ident string) TokenType { return TOKEN_RENAME case "TRUNCATE": return TOKEN_TRUNCATE + case "DELETE": + return TOKEN_DELETE + case "UPDATE": + return TOKEN_UPDATE case "TABLE": return TOKEN_TABLE case "INDEX": @@ -222,6 +226,12 @@ func lookupIdent(ident string) TokenType { return TOKEN_ASC case "DESC": return TOKEN_DESC + case "CASCADE": + return TOKEN_CASCADE + case "NO": + return TOKEN_NO + case "ACTION": + return TOKEN_ACTION case "CONSTRAINT": return TOKEN_CONSTRAINT case "PRIMARY": diff --git a/pkg/ddl/postgres/lexar_test.go b/pkg/ddl/postgres/lexar_test.go index 0a3f7a9..1acea8c 100644 --- a/pkg/ddl/postgres/lexar_test.go +++ b/pkg/ddl/postgres/lexar_test.go @@ -26,6 +26,8 @@ func Test_lookupIdent(t *testing.T) { {name: "success,DROP", input: "DROP", want: TOKEN_DROP}, {name: "success,RENAME", input: "RENAME", want: TOKEN_RENAME}, {name: "success,TRUNCATE", input: "TRUNCATE", want: TOKEN_TRUNCATE}, + {name: "success,DELETE", input: "DELETE", want: TOKEN_DELETE}, + {name: "success,UPDATE", input: "UPDATE", want: TOKEN_UPDATE}, {name: "success,TABLE", input: "TABLE", want: TOKEN_TABLE}, {name: "success,INDEX", input: "INDEX", want: TOKEN_INDEX}, {name: "success,VIEW", input: "VIEW", want: TOKEN_VIEW}, @@ -62,6 +64,9 @@ func Test_lookupIdent(t *testing.T) { {name: "success,NULL", input: "NULL", want: TOKEN_NULL}, {name: "success,ASC", input: "ASC", want: TOKEN_ASC}, {name: "success,DESC", input: "DESC", want: TOKEN_DESC}, + {name: "success,CASCADE", input: "CASCADE", want: TOKEN_CASCADE}, + {name: "success,NO", input: "NO", want: TOKEN_NO}, + {name: "success,ACTION", input: "ACTION", want: TOKEN_ACTION}, {name: "success,CONSTRAINT", input: "CONSTRAINT", want: TOKEN_CONSTRAINT}, {name: "success,PRIMARY", input: "PRIMARY", want: TOKEN_PRIMARY}, {name: "success,KEY", input: "KEY", want: TOKEN_KEY}, diff --git a/pkg/ddl/postgres/parser.go b/pkg/ddl/postgres/parser.go index 040d18b..dacd3bc 100644 --- a/pkg/ddl/postgres/parser.go +++ b/pkg/ddl/postgres/parser.go @@ -8,10 +8,11 @@ import ( "runtime" "strings" - errorz "github.com/kunitsucom/util.go/errors" filepathz "github.com/kunitsucom/util.go/path/filepath" stringz "github.com/kunitsucom/util.go/strings" + apperr "github.com/kunitsucom/ddlctl/pkg/apperr" + "github.com/kunitsucom/ddlctl/pkg/ddl" "github.com/kunitsucom/ddlctl/pkg/ddl/logs" ) @@ -83,7 +84,7 @@ LabelDDL: case TOKEN_CREATE: stmt, err := p.parseCreateStatement() if err != nil { - return nil, errorz.Errorf("parseCreateStatement: %w", err) + return nil, apperr.Errorf("parseCreateStatement: %w", err) } d.Stmts = append(d.Stmts, stmt) case TOKEN_CLOSE_PAREN: @@ -93,7 +94,7 @@ LabelDDL: case TOKEN_EOF: break LabelDDL default: - return nil, errorz.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) } p.nextToken() @@ -110,7 +111,7 @@ func (p *Parser) parseCreateStatement() (Stmt, error) { //nolint:ireturn case TOKEN_INDEX, TOKEN_UNIQUE: return p.parseCreateIndexStmt() default: - return nil, errorz.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) } } @@ -123,11 +124,11 @@ func (p *Parser) parseCreateTableStmt() (*CreateTableStmt, error) { if p.isPeekToken(TOKEN_IF) { p.nextToken() // current = IF if err := p.checkPeekToken(TOKEN_NOT); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = NOT if err := p.checkPeekToken(TOKEN_EXISTS); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = EXISTS createTableStmt.IfNotExists = true @@ -135,7 +136,7 @@ func (p *Parser) parseCreateTableStmt() (*CreateTableStmt, error) { p.nextToken() // current = table_name if err := p.checkCurrentToken(TOKEN_IDENT); err != nil { - return nil, errorz.Errorf("checkCurrentToken: %w", err) + return nil, apperr.Errorf("checkCurrentToken: %w", err) } createTableStmt.Name = NewObjectName(p.currentToken.Literal.Str) @@ -144,7 +145,7 @@ func (p *Parser) parseCreateTableStmt() (*CreateTableStmt, error) { p.nextToken() // current = ( if err := p.checkCurrentToken(TOKEN_OPEN_PAREN); err != nil { - return nil, errorz.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) + return nil, apperr.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) } p.nextToken() // current = column_name @@ -155,7 +156,7 @@ LabelColumns: case p.isCurrentToken(TOKEN_IDENT): column, constraints, err := p.parseColumn(createTableStmt.Name.Name) if err != nil { - return nil, errorz.Errorf(errFmtPrefix+"parseColumn: %w", err) + return nil, apperr.Errorf(errFmtPrefix+"parseColumn: %w", err) } createTableStmt.Columns = append(createTableStmt.Columns, column) if len(constraints) > 0 { @@ -166,7 +167,7 @@ LabelColumns: case isConstraint(p.currentToken.Type): constraint, err := p.parseTableConstraint(createTableStmt.Name.Name) if err != nil { - return nil, errorz.Errorf(errFmtPrefix+"parseConstraint: %w", err) + return nil, apperr.Errorf(errFmtPrefix+"parseConstraint: %w", err) } createTableStmt.Constraints = createTableStmt.Constraints.Append(constraint) case p.isCurrentToken(TOKEN_COMMA): @@ -177,10 +178,10 @@ LabelColumns: case TOKEN_SEMICOLON, TOKEN_EOF: break LabelColumns default: - return nil, errorz.Errorf(errFmtPrefix+"peekToken=%#v: %w", p.peekToken, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf(errFmtPrefix+"peekToken=%#v: %w", p.peekToken, ddl.ErrUnexpectedToken) } default: - return nil, errorz.Errorf(errFmtPrefix+"currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf(errFmtPrefix+"currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) } } @@ -199,11 +200,11 @@ func (p *Parser) parseCreateIndexStmt() (*CreateIndexStmt, error) { if p.isPeekToken(TOKEN_IF) { p.nextToken() // current = IF if err := p.checkPeekToken(TOKEN_NOT); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = NOT if err := p.checkPeekToken(TOKEN_EXISTS); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = EXISTS createIndexStmt.IfNotExists = true @@ -211,7 +212,7 @@ func (p *Parser) parseCreateIndexStmt() (*CreateIndexStmt, error) { p.nextToken() // current = index_name if err := p.checkCurrentToken(TOKEN_IDENT); err != nil { - return nil, errorz.Errorf("checkCurrentToken: %w", err) + return nil, apperr.Errorf("checkCurrentToken: %w", err) } createIndexStmt.Name = NewObjectName(p.currentToken.Literal.Str) @@ -220,13 +221,13 @@ func (p *Parser) parseCreateIndexStmt() (*CreateIndexStmt, error) { p.nextToken() // current = ON if err := p.checkCurrentToken(TOKEN_ON); err != nil { - return nil, errorz.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) + return nil, apperr.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) } p.nextToken() // current = table_name if err := p.checkCurrentToken(TOKEN_IDENT); err != nil { - return nil, errorz.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) + return nil, apperr.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) } createIndexStmt.TableName = NewObjectName(p.currentToken.Literal.Str) @@ -240,12 +241,12 @@ func (p *Parser) parseCreateIndexStmt() (*CreateIndexStmt, error) { } if err := p.checkCurrentToken(TOKEN_OPEN_PAREN); err != nil { - return nil, errorz.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) + return nil, apperr.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) } idents, err := p.parseColumnIdents() if err != nil { - return nil, errorz.Errorf(errFmtPrefix+"parseColumnIdents: %w", err) + return nil, apperr.Errorf(errFmtPrefix+"parseColumnIdents: %w", err) } createIndexStmt.Columns = idents @@ -259,7 +260,7 @@ func (p *Parser) parseColumn(tableName *Ident) (*Column, []Constraint, error) { constraints := make(Constraints, 0) if err := p.checkCurrentToken(TOKEN_IDENT); err != nil { - return nil, nil, errorz.Errorf("checkCurrentToken: %w", err) + return nil, nil, apperr.Errorf("checkCurrentToken: %w", err) } column.Name = NewRawIdent(p.currentToken.Literal.Str) @@ -271,7 +272,7 @@ func (p *Parser) parseColumn(tableName *Ident) (*Column, []Constraint, error) { case isDataType(p.currentToken.Type): dataType, err := p.parseDataType() if err != nil { - return nil, nil, errorz.Errorf(errFmtPrefix+"parseDataType: %w", err) + return nil, nil, apperr.Errorf(errFmtPrefix+"parseDataType: %w", err) } column.DataType = dataType @@ -281,7 +282,7 @@ func (p *Parser) parseColumn(tableName *Ident) (*Column, []Constraint, error) { switch p.currentToken.Type { //nolint:exhaustive case TOKEN_NOT: if err := p.checkPeekToken(TOKEN_NULL); err != nil { - return nil, nil, errorz.Errorf(errFmtPrefix+"checkPeekToken: %w", err) + return nil, nil, apperr.Errorf(errFmtPrefix+"checkPeekToken: %w", err) } p.nextToken() // current = NULL column.NotNull = true @@ -291,7 +292,7 @@ func (p *Parser) parseColumn(tableName *Ident) (*Column, []Constraint, error) { p.nextToken() // current = DEFAULT def, err := p.parseColumnDefault() if err != nil { - return nil, nil, errorz.Errorf(errFmtPrefix+"parseColumnDefault: %w", err) + return nil, nil, apperr.Errorf(errFmtPrefix+"parseColumnDefault: %w", err) } column.Default = def continue @@ -304,7 +305,7 @@ func (p *Parser) parseColumn(tableName *Ident) (*Column, []Constraint, error) { cs, err := p.parseColumnConstraints(tableName, column) if err != nil { - return nil, nil, errorz.Errorf(errFmtPrefix+"parseColumnConstraints: %w", err) + return nil, nil, apperr.Errorf(errFmtPrefix+"parseColumnConstraints: %w", err) } if len(cs) > 0 { for _, c := range cs { @@ -312,7 +313,7 @@ func (p *Parser) parseColumn(tableName *Ident) (*Column, []Constraint, error) { } } default: - return nil, nil, errorz.Errorf(errFmtPrefix+"currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + return nil, nil, apperr.Errorf(errFmtPrefix+"currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) } return column, constraints, nil @@ -330,7 +331,7 @@ LabelDefault: case TOKEN_OPEN_PAREN: ids, err := p.parseExpr() if err != nil { - return nil, errorz.Errorf("parseExpr: %w", err) + return nil, apperr.Errorf("parseExpr: %w", err) } def.Value = def.Value.Append(ids...) continue @@ -355,7 +356,7 @@ LabelDefault: if isConstraint(p.currentToken.Type) { break LabelDefault } - return nil, errorz.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) } p.nextToken() @@ -364,11 +365,12 @@ LabelDefault: return def, nil } +//nolint:cyclop func (p *Parser) parseExpr() ([]*Ident, error) { idents := make([]*Ident, 0) if err := p.checkCurrentToken(TOKEN_OPEN_PAREN); err != nil { - return nil, errorz.Errorf("checkCurrentToken: %w", err) + return nil, apperr.Errorf("checkCurrentToken: %w", err) } idents = append(idents, NewRawIdent(p.currentToken.Literal.Str)) p.nextToken() // current = IDENT @@ -379,7 +381,7 @@ LabelExpr: case TOKEN_OPEN_PAREN: ids, err := p.parseExpr() if err != nil { - return nil, errorz.Errorf("parseExpr: %w", err) + return nil, apperr.Errorf("parseExpr: %w", err) } idents = append(idents, ids...) continue @@ -396,9 +398,13 @@ LabelExpr: } idents = append(idents, NewRawIdent(value)) case TOKEN_EOF: - return nil, errorz.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) default: - idents = append(idents, NewRawIdent(p.currentToken.Literal.Str)) + if isReservedValue(p.currentToken.Type) { + idents = append(idents, NewRawIdent(p.currentToken.Type.String())) + } else { + idents = append(idents, NewRawIdent(p.currentToken.Literal.Str)) + } } p.nextToken() @@ -416,7 +422,7 @@ LabelConstraints: switch p.currentToken.Type { //nolint:exhaustive case TOKEN_PRIMARY: if err := p.checkPeekToken(TOKEN_KEY); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = KEY constraints = constraints.Append(&PrimaryKeyConstraint{ @@ -425,7 +431,7 @@ LabelConstraints: }) case TOKEN_REFERENCES: if err := p.checkPeekToken(TOKEN_IDENT); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = table_name constraint := &ForeignKeyConstraint{ @@ -436,7 +442,30 @@ LabelConstraints: p.nextToken() // current = ( idents, err := p.parseColumnIdents() if err != nil { - return nil, errorz.Errorf("parseColumnIdents: %w", err) + return nil, apperr.Errorf("parseColumnIdents: %w", err) + } + // TODO: support ON DELETE, ON UPDATE + //nolint:nestif + if p.isCurrentToken(TOKEN_ON) { + onAction := p.currentToken.Literal.String() // current = ON + p.nextToken() // current = DELETE or UPDATE + if err := p.checkCurrentToken(TOKEN_DELETE, TOKEN_UPDATE); err != nil { + return nil, apperr.Errorf("checkCurrentToken: %w", err) + } + onAction += " " + p.currentToken.Literal.String() + if err := p.checkPeekToken(TOKEN_CASCADE, TOKEN_NO); err != nil { + return nil, apperr.Errorf("checkPeekToken: %w", err) + } + p.nextToken() // current = CASCADE or NO + onAction += " " + p.currentToken.Literal.String() // current = CASCADE or NO + if p.isCurrentToken(TOKEN_NO) { + if err := p.checkPeekToken(TOKEN_ACTION); err != nil { + return nil, apperr.Errorf("checkPeekToken: %w", err) + } + p.nextToken() // current = ACTION + onAction += " " + p.currentToken.Literal.String() // current = ACTION + } + constraint.OnAction = onAction } constraint.RefColumns = idents constraints = constraints.Append(constraint) @@ -447,7 +476,7 @@ LabelConstraints: }) case TOKEN_CHECK: if err := p.checkPeekToken(TOKEN_OPEN_PAREN); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = ( constraint := &CheckConstraint{ @@ -455,14 +484,14 @@ LabelConstraints: } idents, err := p.parseExpr() if err != nil { - return nil, errorz.Errorf("parseExpr: %w", err) + return nil, apperr.Errorf("parseExpr: %w", err) } constraint.Expr = constraint.Expr.Append(idents...) constraints = constraints.Append(constraint) case TOKEN_IDENT, TOKEN_COMMA, TOKEN_CLOSE_PAREN: break LabelConstraints default: - return nil, errorz.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) } p.nextToken() @@ -477,7 +506,7 @@ func (p *Parser) parseTableConstraint(tableName *Ident) (Constraint, error) { // if p.isCurrentToken(TOKEN_CONSTRAINT) { p.nextToken() // current = constraint_name if p.currentToken.Type != TOKEN_IDENT { - return nil, errorz.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) } constraintName = NewRawIdent(p.currentToken.Literal.Str) p.nextToken() // current = PRIMARY or CHECK or UNIQUE //diff:ignore-line-postgres-cockroach @@ -486,16 +515,16 @@ func (p *Parser) parseTableConstraint(tableName *Ident) (Constraint, error) { // switch p.currentToken.Type { //nolint:exhaustive case TOKEN_PRIMARY: if err := p.checkPeekToken(TOKEN_KEY); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = KEY if err := p.checkPeekToken(TOKEN_OPEN_PAREN); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = ( idents, err := p.parseColumnIdents() if err != nil { - return nil, errorz.Errorf("parseColumnIdents: %w", err) + return nil, apperr.Errorf("parseColumnIdents: %w", err) } if constraintName == nil { constraintName = NewRawIdent(fmt.Sprintf("%s_pkey", tableName.StringForDiff())) @@ -506,30 +535,49 @@ func (p *Parser) parseTableConstraint(tableName *Ident) (Constraint, error) { // }, nil case TOKEN_FOREIGN: if err := p.checkPeekToken(TOKEN_KEY); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = KEY if err := p.checkPeekToken(TOKEN_OPEN_PAREN); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = ( idents, err := p.parseColumnIdents() if err != nil { - return nil, errorz.Errorf("parseColumnIdents: %w", err) + return nil, apperr.Errorf("parseColumnIdents: %w", err) } if err := p.checkCurrentToken(TOKEN_REFERENCES); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = ref_table_name if err := p.checkCurrentToken(TOKEN_IDENT); err != nil { - return nil, errorz.Errorf("checkCurrentToken: %w", err) + return nil, apperr.Errorf("checkCurrentToken: %w", err) } refName := NewRawIdent(p.currentToken.Literal.Str) p.nextToken() // current = ( identsRef, err := p.parseColumnIdents() if err != nil { - return nil, errorz.Errorf("parseColumnIdents: %w", err) + return nil, apperr.Errorf("parseColumnIdents: %w", err) + } + // TODO: support ON DELETE, ON UPDATE + var onAction string + if p.isCurrentToken(TOKEN_ON) { + onAction = p.currentToken.Literal.String() // current = ON + p.nextToken() // current = DELETE or UPDATE + if err := p.checkCurrentToken(TOKEN_DELETE, TOKEN_UPDATE); err != nil { + return nil, apperr.Errorf("checkCurrentToken: %w", err) + } + onAction += " " + p.currentToken.Literal.String() + if err := p.checkPeekToken(TOKEN_CASCADE, TOKEN_NO); err != nil { + return nil, apperr.Errorf("checkPeekToken: %w", err) + } + p.nextToken() // current = CASCADE or NO + onAction += " " + p.currentToken.Literal.String() // current = CASCADE or NO + if p.isCurrentToken(TOKEN_NO) && p.isPeekToken(TOKEN_ACTION) { + p.nextToken() // current = ACTION + onAction += " " + p.currentToken.Literal.String() // current = ACTION + } } if constraintName == nil { name := tableName.StringForDiff() @@ -544,17 +592,18 @@ func (p *Parser) parseTableConstraint(tableName *Ident) (Constraint, error) { // Columns: idents, Ref: refName, RefColumns: identsRef, + OnAction: onAction, }, nil case TOKEN_UNIQUE: //diff:ignore-line-postgres-cockroach c := &UniqueConstraint{} //diff:ignore-line-postgres-cockroach if err := p.checkPeekToken(TOKEN_OPEN_PAREN); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = ( idents, err := p.parseColumnIdents() if err != nil { - return nil, errorz.Errorf("parseColumnIdents: %w", err) + return nil, apperr.Errorf("parseColumnIdents: %w", err) } if constraintName == nil { //diff:ignore-line-postgres-cockroach name := fmt.Sprintf("%s_unique", tableName.StringForDiff()) //diff:ignore-line-postgres-cockroach @@ -567,7 +616,7 @@ func (p *Parser) parseTableConstraint(tableName *Ident) (Constraint, error) { // c.Columns = idents return c, nil default: - return nil, errorz.Errorf("currentToken=%s: %w", p.currentToken.Type, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf("currentToken=%s: %w", p.currentToken.Type, ddl.ErrUnexpectedToken) } } @@ -585,12 +634,12 @@ func (p *Parser) parseDataType() (*DataType, error) { p.nextToken() // current = WITH dataType.Name += " " + p.currentToken.Literal.String() if err := p.checkPeekToken(TOKEN_TIME); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = TIME dataType.Name += " " + p.currentToken.Literal.String() if err := p.checkPeekToken(TOKEN_ZONE); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = ZONE dataType.Name += " " + p.currentToken.Literal.String() @@ -601,7 +650,7 @@ func (p *Parser) parseDataType() (*DataType, error) { case TOKEN_DOUBLE: dataType.Name = p.currentToken.Literal.String() if err := p.checkPeekToken(TOKEN_PRECISION); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = PRECISION dataType.Name += " " + p.currentToken.Literal.String() @@ -609,7 +658,7 @@ func (p *Parser) parseDataType() (*DataType, error) { case TOKEN_CHARACTER: dataType.Name = p.currentToken.Literal.String() if err := p.checkPeekToken(TOKEN_VARYING); err != nil { - return nil, errorz.Errorf("checkPeekToken: %w", err) + return nil, apperr.Errorf("checkPeekToken: %w", err) } p.nextToken() // current = VARYING dataType.Name += " " + p.currentToken.Literal.String() @@ -623,7 +672,7 @@ func (p *Parser) parseDataType() (*DataType, error) { p.nextToken() // current = ( idents, err := p.parseIdents() if err != nil { - return nil, errorz.Errorf("parseIdents: %w", err) + return nil, apperr.Errorf("parseIdents: %w", err) } dataType.Expr = dataType.Expr.Append(idents...) } @@ -648,7 +697,7 @@ LabelIdents: p.nextToken() break LabelIdents default: - return nil, errorz.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) } p.nextToken() } @@ -669,12 +718,13 @@ LabelIdents: case TOKEN_CLOSE_PAREN: break LabelIdents case TOKEN_EOF, TOKEN_ILLEGAL: - return nil, errorz.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) default: idents = append(idents, NewRawIdent(p.currentToken.Literal.Str)) } p.nextToken() } + return idents, nil } @@ -743,7 +793,7 @@ func (p *Parser) checkCurrentToken(expectedTypes ...TokenType) error { return nil } } - return errorz.Errorf("currentToken: expected=%s, but got=%#v: %w", stringz.JoinStringers(",", expectedTypes...), p.currentToken, ddl.ErrUnexpectedToken) + return apperr.Errorf("currentToken: expected=%s, but got=%#v: %w", stringz.JoinStringers(",", expectedTypes...), p.currentToken, ddl.ErrUnexpectedToken) } func (p *Parser) isPeekToken(expectedTypes ...TokenType) bool { @@ -761,5 +811,5 @@ func (p *Parser) checkPeekToken(expectedTypes ...TokenType) error { return nil } } - return errorz.Errorf("peekToken: expected=%s, but got=%#v: %w", stringz.JoinStringers(",", expectedTypes...), p.peekToken, ddl.ErrUnexpectedToken) + return apperr.Errorf("peekToken: expected=%s, but got=%#v: %w", stringz.JoinStringers(",", expectedTypes...), p.peekToken, ddl.ErrUnexpectedToken) } diff --git a/pkg/ddl/postgres/parser_test.go b/pkg/ddl/postgres/parser_test.go index f7985b5..dfe2208 100644 --- a/pkg/ddl/postgres/parser_test.go +++ b/pkg/ddl/postgres/parser_test.go @@ -224,6 +224,26 @@ CREATE TABLE IF NOT EXISTS complex_defaults ( input: `CREATE TABLE "users" ("id" UUID, PRIMARY KEY (NOT`, wantErr: ddl.ErrUnexpectedToken, }, + { + name: "failure,CREATE_TABLE_table_name_column_name_REFERENCES_INVALID", + input: `CREATE TABLE "users" ("id" UUID REFERENCES foo (foo_id) ON`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_REFERENCES_INVALID", + input: `CREATE TABLE "users" ("id" UUID REFERENCES foo (foo_id) ON DELETE`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_REFERENCES_INVALID", + input: `CREATE TABLE "users" ("id" UUID REFERENCES foo (foo_id) ON DELETE NO`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_REFERENCES_INVALID", + input: `CREATE TABLE "users" ("id" UUID REFERENCES foo (foo_id) ON DELETE NO ACTION`, + wantErr: ddl.ErrUnexpectedToken, + }, { name: "failure,CREATE_TABLE_table_name_column_name_CONSTRAINT_INVALID_FOREIGN", input: `CREATE TABLE "users" ("id" UUID, FOREIGN NOT`, @@ -259,6 +279,26 @@ CREATE TABLE IF NOT EXISTS complex_defaults ( input: `CREATE TABLE "users" ("id" UUID, FOREIGN KEY ("group_id") REFERENCES "groups" ("id")`, wantErr: ddl.ErrUnexpectedToken, }, + { + name: "failure,CREATE_TABLE_table_name_column_name_CONSTRAINT_FOREIGN_KEY_IDENTS_REFERENCES_ON_INVALID", + input: `CREATE TABLE "users" ("id" UUID, FOREIGN KEY ("group_id") REFERENCES "groups" ("id") ON`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_CONSTRAINT_FOREIGN_KEY_IDENTS_REFERENCES_ON_DELETE_INVALID", + input: `CREATE TABLE "users" ("id" UUID, FOREIGN KEY ("group_id") REFERENCES "groups" ("id") ON DELETE`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_CONSTRAINT_FOREIGN_KEY_IDENTS_REFERENCES_ON_DELETE_NO_INVALID", + input: `CREATE TABLE "users" ("id" UUID, FOREIGN KEY ("group_id") REFERENCES "groups" ("id") ON DELETE NO`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_CONSTRAINT_FOREIGN_KEY_IDENTS_REFERENCES_ON_DELETE_NO_INVALID", + input: `CREATE TABLE "users" ("id" UUID, FOREIGN KEY ("group_id") REFERENCES "groups" ("id") ON DELETE NO ACTION`, + wantErr: ddl.ErrUnexpectedToken, + }, { name: "failure,CREATE_TABLE_table_name_column_name_CONSTRAINT_UNIQUE_INVALID", input: `CREATE TABLE "users" ("id" UUID, UNIQUE NOT`, @@ -384,6 +424,16 @@ func TestParser_parseColumn(t *testing.T) { func TestParser_parseExpr(t *testing.T) { t.Parallel() + t.Run("success,isReservedValue", func(t *testing.T) { + t.Parallel() + + p := NewParser(NewLexer(`(null)`)) + p.nextToken() + p.nextToken() + _, err := p.parseExpr() + require.NoError(t, err) + }) + t.Run("failure,invalid", func(t *testing.T) { t.Parallel() diff --git a/pkg/ddl/spanner/ddl.go b/pkg/ddl/spanner/ddl.go new file mode 100644 index 0000000..cda8dbb --- /dev/null +++ b/pkg/ddl/spanner/ddl.go @@ -0,0 +1,150 @@ +package spanner + +import ( + stringz "github.com/kunitsucom/util.go/strings" + + "github.com/kunitsucom/ddlctl/pkg/ddl/internal" +) + +const ( + Dialect = "spanner" + DriverName = "spanner" + Indent = " " + CommentPrefix = "-- " +) + +type Verb string + +const ( + VerbCreate Verb = "CREATE" + VerbAlter Verb = "ALTER" + VerbDrop Verb = "DROP" + VerbRename Verb = "RENAME" + VerbTruncate Verb = "TRUNCATE" +) + +type Object string + +const ( + ObjectTable Object = "TABLE" + ObjectIndex Object = "INDEX" + ObjectView Object = "VIEW" +) + +type Action string + +const ( + ActionAdd Action = "ADD" + ActionDrop Action = "DROP" + ActionAlter Action = "ALTER" + ActionRename Action = "RENAME" +) + +type Stmt interface { + isStmt() + GetNameForDiff() string + String() string +} + +type DDL struct { + Stmts []Stmt +} + +func (d *DDL) String() string { + if d == nil { + return "" + } + return stringz.JoinStringers("", d.Stmts...) +} + +type Ident struct { + Name string + QuotationMark string + Raw string +} + +func (i *Ident) GoString() string { return internal.GoString(*i) } + +func (i *Ident) String() string { + if i == nil { + return "" + } + return i.Raw +} + +func (i *Ident) StringForDiff() string { + if i == nil { + return "" + } + return i.Name +} + +type ColumnIdent struct { + Ident *Ident + Order *Order //diff:ignore-line-postgres-cockroach +} + +type Order struct{ Desc bool } //diff:ignore-line-postgres-cockroach + +func (i *ColumnIdent) GoString() string { return internal.GoString(*i) } + +func (i *ColumnIdent) String() string { + str := i.Ident.String() + if i.Order != nil { //diff:ignore-line-postgres-cockroach + if i.Order.Desc { //diff:ignore-line-postgres-cockroach + str += " DESC" //diff:ignore-line-postgres-cockroach + } else { //diff:ignore-line-postgres-cockroach + str += " ASC" //diff:ignore-line-postgres-cockroach + } //diff:ignore-line-postgres-cockroach + } //diff:ignore-line-postgres-cockroach + return str +} + +func (i *ColumnIdent) StringForDiff() string { + str := i.Ident.StringForDiff() + if i.Order != nil && i.Order.Desc { //diff:ignore-line-postgres-cockroach + str += " DESC" //diff:ignore-line-postgres-cockroach + } else { //diff:ignore-line-postgres-cockroach + str += " ASC" //diff:ignore-line-postgres-cockroach + } //diff:ignore-line-postgres-cockroach + return str +} + +type DataType struct { + Name string + Type TokenType + Expr *Expr +} + +func (s *DataType) String() string { + if s == nil { + return "" + } + str := s.Name + if s.Expr != nil && len(s.Expr.Idents) > 0 { + str += "(" + s.Expr.String() + ")" + } + return str +} + +func (s *DataType) StringForDiff() string { + if s == nil { + return "" + } + var str string + if s.Type != "" { + str += string(s.Type) + } else { + str += string(TOKEN_ILLEGAL) + } + + if s.Expr != nil && len(s.Expr.Idents) > 0 { + str += "(" + for _, ident := range s.Expr.Idents { + str += ident.StringForDiff() + } + str += ")" + } + + return str +} diff --git a/pkg/ddl/spanner/ddl_index_create.go b/pkg/ddl/spanner/ddl_index_create.go new file mode 100644 index 0000000..22ec886 --- /dev/null +++ b/pkg/ddl/spanner/ddl_index_create.go @@ -0,0 +1,76 @@ +package spanner + +import ( + "strings" + + stringz "github.com/kunitsucom/util.go/strings" + + "github.com/kunitsucom/ddlctl/pkg/ddl/internal" +) + +// MEMO: https://cloud.google.com/spanner/docs/reference/standard-sql/data-definition-language#create-index + +var _ Stmt = (*CreateIndexStmt)(nil) + +type CreateIndexStmt struct { + Comment string + Unique bool + IfNotExists bool + Name *ObjectName + TableName *ObjectName + Using []*Ident + Columns []*ColumnIdent +} + +func (s *CreateIndexStmt) GetNameForDiff() string { + return s.Name.StringForDiff() +} + +func (s *CreateIndexStmt) String() string { + var str string + if s.Comment != "" { + comments := strings.Split(s.Comment, "\n") + for i := range comments { + if comments[i] != "" { + str += CommentPrefix + comments[i] + "\n" + } + } + } + str += "CREATE " + if s.Unique { + str += "UNIQUE " + } + str += "INDEX " + if s.IfNotExists { + str += "IF NOT EXISTS " + } + str += s.Name.String() + " ON " + s.TableName.String() + if len(s.Using) > 0 { + str += " USING " + str += stringz.JoinStringers(" ", s.Using...) + } + str += " (" + stringz.JoinStringers(", ", s.Columns...) + ");\n" + return str +} + +func (s *CreateIndexStmt) StringForDiff() string { + str := "CREATE " + if s.Unique { + str += "UNIQUE " + } + str += "INDEX " + str += s.Name.StringForDiff() + " ON " + s.TableName.StringForDiff() + // TODO: add USING + str += " (" + for i, c := range s.Columns { + if i > 0 { + str += ", " + } + str += c.StringForDiff() + } + str += ");\n" + return str +} + +func (*CreateIndexStmt) isStmt() {} +func (s *CreateIndexStmt) GoString() string { return internal.GoString(*s) } diff --git a/pkg/ddl/spanner/ddl_index_create_test.go b/pkg/ddl/spanner/ddl_index_create_test.go new file mode 100644 index 0000000..624635c --- /dev/null +++ b/pkg/ddl/spanner/ddl_index_create_test.go @@ -0,0 +1,51 @@ +package spanner + +import ( + "testing" + + "github.com/kunitsucom/util.go/testing/require" +) + +func TestCreateIndexStmt_GetNameForDiff(t *testing.T) { + t.Parallel() + + t.Run("success,", func(t *testing.T) { + t.Parallel() + + stmt := &CreateIndexStmt{Name: &ObjectName{Name: &Ident{Name: "test", QuotationMark: `"`, Raw: `"test"`}}} + expected := "test" + actual := stmt.GetNameForDiff() + + require.Equal(t, expected, actual) + }) +} + +func TestCreateIndexStmt_String(t *testing.T) { + t.Parallel() + + t.Run("success,", func(t *testing.T) { + t.Parallel() + + stmt := &CreateIndexStmt{ + Comment: "test comment content", + IfNotExists: true, + Name: &ObjectName{Name: &Ident{Name: "test", QuotationMark: `"`, Raw: `"test"`}}, + TableName: &ObjectName{Name: &Ident{Name: "users", QuotationMark: `"`, Raw: `"users"`}}, + Using: []*Ident{{Name: "btree", QuotationMark: ``, Raw: `btree`}}, + Columns: []*ColumnIdent{ + { + Ident: &Ident{Name: "id", QuotationMark: `"`, Raw: `"id"`}, + Order: &Order{Desc: false}, + }, + }, + } + expected := `-- test comment content +CREATE INDEX IF NOT EXISTS "test" ON "users" USING btree ("id" ASC); +` + actual := stmt.String() + + require.Equal(t, expected, actual) + + t.Logf("✅: %s: stmt: %#v", t.Name(), stmt) + }) +} diff --git a/pkg/ddl/spanner/ddl_index_drop.go b/pkg/ddl/spanner/ddl_index_drop.go new file mode 100644 index 0000000..8289901 --- /dev/null +++ b/pkg/ddl/spanner/ddl_index_drop.go @@ -0,0 +1,42 @@ +package spanner + +import ( + "strings" + + "github.com/kunitsucom/ddlctl/pkg/ddl/internal" +) + +// MEMO: https://cloud.google.com/spanner/docs/reference/standard-sql/data-definition-language#drop-index + +var _ Stmt = (*DropIndexStmt)(nil) + +type DropIndexStmt struct { + Comment string + IfExists bool + Name *ObjectName +} + +func (s *DropIndexStmt) GetNameForDiff() string { + return s.Name.StringForDiff() +} + +func (s *DropIndexStmt) String() string { + var str string + if s.Comment != "" { + comments := strings.Split(s.Comment, "\n") + for i := range comments { + if comments[i] != "" { + str += CommentPrefix + comments[i] + "\n" + } + } + } + str += "DROP INDEX " + if s.IfExists { + str += "IF EXISTS " + } + str += s.Name.String() + ";\n" + return str +} + +func (*DropIndexStmt) isStmt() {} +func (s *DropIndexStmt) GoString() string { return internal.GoString(*s) } diff --git a/pkg/ddl/spanner/ddl_index_drop_test.go b/pkg/ddl/spanner/ddl_index_drop_test.go new file mode 100644 index 0000000..b708020 --- /dev/null +++ b/pkg/ddl/spanner/ddl_index_drop_test.go @@ -0,0 +1,40 @@ +package spanner + +import ( + "testing" + + "github.com/kunitsucom/util.go/testing/require" +) + +func TestDropIndexStmt_GetNameForDiff(t *testing.T) { + t.Parallel() + + t.Run("success,", func(t *testing.T) { + t.Parallel() + + stmt := &DropIndexStmt{Name: &ObjectName{Name: &Ident{Name: "test", QuotationMark: `"`, Raw: `"test"`}}} + expected := "test" + actual := stmt.GetNameForDiff() + + require.Equal(t, expected, actual) + }) +} + +func TestDropIndexStmt_String(t *testing.T) { + t.Parallel() + + t.Run("success,", func(t *testing.T) { + t.Parallel() + + stmt := &DropIndexStmt{ + IfExists: true, + Name: &ObjectName{Name: &Ident{Name: "test", QuotationMark: `"`, Raw: `"test"`}}, + } + expected := `DROP INDEX IF EXISTS "test";` + "\n" + actual := stmt.String() + + require.Equal(t, expected, actual) + + t.Logf("✅: %s: stmt: %#v", t.Name(), stmt) + }) +} diff --git a/pkg/ddl/spanner/ddl_table.go b/pkg/ddl/spanner/ddl_table.go new file mode 100644 index 0000000..68279f7 --- /dev/null +++ b/pkg/ddl/spanner/ddl_table.go @@ -0,0 +1,331 @@ +package spanner + +import ( + //diff:ignore-line-postgres-cockroach + "strings" + + stringz "github.com/kunitsucom/util.go/strings" + + "github.com/kunitsucom/ddlctl/pkg/ddl/internal" +) + +type Constraint interface { + isConstraint() + GetName() *Ident + GoString() string + String() string + StringForDiff() string +} + +type Constraints []Constraint + +func (constraints Constraints) Append(constraint Constraint) Constraints { + for i := range constraints { + if constraints[i].GetName().Name == constraint.GetName().Name { + constraints[i] = constraint + return constraints + } + } + constraints = append(constraints, constraint) + + return constraints +} + +// ForeignKeyConstraint represents a FOREIGN KEY constraint. +type ForeignKeyConstraint struct { + Name *Ident + Columns []*ColumnIdent + Ref *Ident + RefColumns []*ColumnIdent +} + +var _ Constraint = (*ForeignKeyConstraint)(nil) + +func (*ForeignKeyConstraint) isConstraint() {} +func (c *ForeignKeyConstraint) GetName() *Ident { return c.Name } +func (c *ForeignKeyConstraint) GoString() string { return internal.GoString(*c) } +func (c *ForeignKeyConstraint) String() string { + var str string + if c.Name != nil { + str += "CONSTRAINT " + c.Name.String() + " " + } + str += "FOREIGN KEY" + str += " (" + stringz.JoinStringers(", ", c.Columns...) + ")" + str += " REFERENCES " + c.Ref.String() + str += " (" + stringz.JoinStringers(", ", c.RefColumns...) + ")" + return str +} + +func (c *ForeignKeyConstraint) StringForDiff() string { + var str string + if c.Name != nil { + str += "CONSTRAINT " + c.Name.StringForDiff() + " " + } + str += "FOREIGN KEY" + str += " (" + for i, v := range c.Columns { + if i != 0 { + str += ", " + } + str += v.StringForDiff() + } + str += ")" + str += " REFERENCES " + c.Ref.Name + str += " (" + for i, v := range c.RefColumns { + if i != 0 { + str += ", " + } + str += v.StringForDiff() + } + str += ")" + return str +} + +// IndexConstraint represents a UNIQUE constraint. //diff:ignore-line-postgres-cockroach. +type IndexConstraint struct { //diff:ignore-line-postgres-cockroach + Name *Ident + Unique bool //diff:ignore-line-postgres-cockroach + Columns []*ColumnIdent +} + +// CheckConstraint represents a CHECK constraint. +type CheckConstraint struct { + Name *Ident + Expr *Expr +} + +var _ Constraint = (*CheckConstraint)(nil) + +func (*CheckConstraint) isConstraint() {} +func (c *CheckConstraint) GetName() *Ident { return c.Name } +func (c *CheckConstraint) GoString() string { return internal.GoString(*c) } +func (c *CheckConstraint) String() string { + var str string + if c.Name != nil { + str += "CONSTRAINT " + c.Name.String() + " " + } + str += "CHECK " + str += c.Expr.String() + return str +} + +func (c *CheckConstraint) StringForDiff() string { + var str string + if c.Name != nil { + str += "CONSTRAINT " + c.Name.StringForDiff() + " " + } + str += "CHECK " + for i, v := range c.Expr.Idents { + if i != 0 { + str += " " + } + str += v.StringForDiff() + } + return str +} + +func NewObjectName(name string) *ObjectName { + objName := &ObjectName{} + + tableName := NewRawIdent(name) + switch name := strings.Split(tableName.Name, "."); len(name) { //nolint:exhaustive + case 2: + // CREATE TABLE "schema.table" + objName.Schema = NewRawIdent(tableName.QuotationMark + name[0] + tableName.QuotationMark) + objName.Name = NewRawIdent(tableName.QuotationMark + name[1] + tableName.QuotationMark) + default: + // CREATE TABLE "table" + objName.Name = tableName + } + + return objName +} + +type ObjectName struct { + Schema *Ident + Name *Ident +} + +func (t *ObjectName) String() string { + if t == nil { + return "" + } + if t.Schema != nil { + return t.Name.QuotationMark + t.Schema.StringForDiff() + "." + t.Name.StringForDiff() + t.Name.QuotationMark + } + return t.Name.String() +} + +func (t *ObjectName) StringForDiff() string { + if t == nil { + return "" + } + if t.Schema != nil { + return t.Schema.StringForDiff() + "." + t.Name.StringForDiff() + } + return t.Name.StringForDiff() +} + +type Column struct { + Name *Ident + DataType *DataType + Default *Default + NotNull bool + Options *Expr +} + +type Default struct { + Value *Expr +} + +func (d *Expr) Append(idents ...*Ident) *Expr { + if d == nil { + d = &Expr{Idents: idents} + return d + } + d.Idents = append(d.Idents, idents...) + return d +} + +type Expr struct { + Idents []*Ident +} + +func (d *Expr) GoString() string { return internal.GoString(*d) } + +//nolint:cyclop +func (d *Expr) String() string { + if d == nil || len(d.Idents) == 0 { + return "" + } + + var str string + for i := range d.Idents { + switch { + // MEMO: backup + // case i != 0 && (d.Idents[i-1].String() == "||" || d.Idents[i].String() == "||"): + // str += " " + case i == 0 || + d.Idents[i-1].String() == "(" || d.Idents[i].String() == "(" || + d.Idents[i].String() == ")" || + d.Idents[i-1].String() == "::" || d.Idents[i].String() == "::" || + d.Idents[i-1].String() == ":::" || d.Idents[i].String() == ":::" || //diff:ignore-line-postgres-cockroach + d.Idents[i].String() == ",": + // noop + default: + str += " " + } + str += d.Idents[i].String() + } + + return str +} + +func (d *Expr) StringForDiff() string { + if d == nil || len(d.Idents) == 0 { + return "" + } + + var str string + for i, v := range d.Idents { + if i != 0 { + str += " " + } + str += v.StringForDiff() + } + + return str +} + +func (d *Default) GoString() string { return internal.GoString(*d) } + +func (d *Default) String() string { + if d == nil { + return "" + } + if d.Value != nil { + return "DEFAULT " + d.Value.String() + } + return "" +} + +func (d *Default) StringForDiff() string { + if d == nil { + return "" + } + if e := d.Value; e != nil { + str := "DEFAULT (" + for i, v := range d.Value.Idents { + if i != 0 { + str += " " + } + str += v.StringForDiff() + } + str += ")" + return str + } + return "" +} + +func (c *Column) String() string { + str := c.Name.String() + " " + + c.DataType.String() + if c.NotNull { //diff:ignore-line-postgres-cockroach + str += " NOT NULL" //diff:ignore-line-postgres-cockroach + } //diff:ignore-line-postgres-cockroach + if c.Default != nil { //diff:ignore-line-postgres-cockroach + str += " " + c.Default.String() //diff:ignore-line-postgres-cockroach + } + if c.Options != nil && len(c.Options.Idents) > 0 { //diff:ignore-line-postgres-cockroach + str += " OPTIONS " + c.Options.String() //diff:ignore-line-postgres-cockroach + } + return str +} + +func (c *Column) GoString() string { return internal.GoString(*c) } + +type Option struct { + Name string + Value *Expr +} + +func (o *Option) String() string { + if o.Value == nil { + return "" + } + return o.Name + " " + o.Value.String() +} + +func (o *Option) StringForDiff() string { + if o.Value == nil { + return "" + } + return o.Name + " " + o.Value.StringForDiff() +} + +func (o *Option) GoString() string { return internal.GoString(*o) } + +type Options []*Option + +func (o Options) String() string { + var str string + for i, v := range o { + if i != 0 { + str += ",\n" + } + str += v.String() + } + return str +} + +func (o Options) StringForDiff() string { + var str string + for i, v := range o { + if i != 0 { + str += ", " + } + str += v.StringForDiff() + } + return str +} diff --git a/pkg/ddl/spanner/ddl_table_alter.go b/pkg/ddl/spanner/ddl_table_alter.go new file mode 100644 index 0000000..5fa644e --- /dev/null +++ b/pkg/ddl/spanner/ddl_table_alter.go @@ -0,0 +1,231 @@ +package spanner + +import ( + "strings" + + "github.com/kunitsucom/ddlctl/pkg/ddl/internal" +) + +// MEMO: https://cloud.google.com/spanner/docs/reference/standard-sql/data-definition-language#alter_table + +var _ Stmt = (*AlterTableStmt)(nil) + +type AlterTableStmt struct { + Comment string + Indent string + Name *ObjectName + Action AlterTableAction +} + +func (*AlterTableStmt) isStmt() {} + +func (s *AlterTableStmt) GetNameForDiff() string { + return s.Name.StringForDiff() +} + +//nolint:cyclop,funlen +func (s *AlterTableStmt) String() string { + var str string + if s.Comment != "" { + comments := strings.Split(s.Comment, "\n") + for i := range comments { + if comments[i] != "" { + str += CommentPrefix + comments[i] + "\n" + } + } + } + str += "ALTER TABLE " + str += s.Name.String() + " " + switch a := s.Action.(type) { + case *RenameTable: + str += "RENAME TO " + str += a.NewName.String() + case *RenameColumn: + str += "RENAME COLUMN " + a.Name.String() + " TO " + a.NewName.String() + case *RenameConstraint: + str += "RENAME CONSTRAINT " + a.Name.String() + " TO " + a.NewName.String() + case *AddColumn: + str += "ADD COLUMN " + a.Column.String() + case *DropColumn: + str += "DROP COLUMN " + a.Name.String() + case *AlterColumn: + str += "ALTER COLUMN " + a.Name.String() + " " + switch ca := a.Action.(type) { + case *AlterColumnDataType: + str += ca.DataType.String() + if ca.NotNull { + str += " NOT NULL" + } + case *AlterColumnSetDefault: + str += "SET " + ca.Default.String() + case *AlterColumnDropDefault: + str += "DROP DEFAULT" + case *AlterColumnSetOptions: + str += "SET OPTIONS " + ca.Options.String() + case *AlterColumnDropOptions: + str += "DROP OPTIONS" + } + case *AddConstraint: + str += "ADD " + a.Constraint.String() + if a.NotValid { + str += " NOT VALID" + } + case *DropConstraint: + str += "DROP CONSTRAINT " + a.Name.String() + case *AlterConstraint: + str += "ALTER CONSTRAINT " + a.Name.String() + " " + if a.Deferrable { + str += "DEFERRABLE" + } else { + str += "NOT DEFERRABLE" + } + if a.InitiallyDeferred { + str += " INITIALLY DEFERRED" + } else { + str += " INITIALLY IMMEDIATE" + } + } + + return str + ";\n" +} + +func (s *AlterTableStmt) GoString() string { return internal.GoString(*s) } + +type AlterTableAction interface { + isAlterTableAction() + GoString() string +} + +// RenameTable represents ALTER TABLE table_name RENAME TO new_table_name. +type RenameTable struct { + NewName *ObjectName +} + +func (*RenameTable) isAlterTableAction() {} + +func (s *RenameTable) GoString() string { return internal.GoString(*s) } + +// RenameConstraint represents ALTER TABLE table_name RENAME COLUMN. +type RenameConstraint struct { + Name *Ident + NewName *Ident +} + +func (*RenameConstraint) isAlterTableAction() {} + +func (s *RenameConstraint) GoString() string { return internal.GoString(*s) } + +// RenameColumn represents ALTER TABLE table_name RENAME COLUMN. +type RenameColumn struct { + Name *Ident + NewName *Ident +} + +func (*RenameColumn) isAlterTableAction() {} + +func (s *RenameColumn) GoString() string { return internal.GoString(*s) } + +// AddColumn represents ALTER TABLE table_name ADD COLUMN. +type AddColumn struct { + Column *Column +} + +func (*AddColumn) isAlterTableAction() {} + +func (s *AddColumn) GoString() string { return internal.GoString(*s) } + +// DropColumn represents ALTER TABLE table_name DROP COLUMN. +type DropColumn struct { + Name *Ident +} + +func (*DropColumn) isAlterTableAction() {} + +func (s *DropColumn) GoString() string { return internal.GoString(*s) } + +// AlterColumn represents ALTER TABLE table_name ALTER COLUMN. +type AlterColumn struct { + Name *Ident + Action AlterColumnAction +} + +func (*AlterColumn) isAlterTableAction() {} + +func (s *AlterColumn) GoString() string { return internal.GoString(*s) } + +type AlterColumnAction interface { + isAlterColumnAction() + GoString() string +} + +// AlterColumnDataType represents ALTER TABLE table_name ALTER COLUMN column_name data_type NOT NULL. +type AlterColumnDataType struct { + DataType *DataType + NotNull bool +} + +func (*AlterColumnDataType) isAlterColumnAction() {} + +func (s *AlterColumnDataType) GoString() string { return internal.GoString(*s) } + +// AlterColumnSetDefault represents ALTER TABLE table_name ALTER COLUMN column_name SET DEFAULT. +type AlterColumnSetDefault struct { + Default *Default +} + +func (*AlterColumnSetDefault) isAlterColumnAction() {} + +func (s *AlterColumnSetDefault) GoString() string { return internal.GoString(*s) } + +// AlterColumnDropDefault represents ALTER TABLE table_name ALTER COLUMN column_name DROP DEFAULT. +type AlterColumnDropDefault struct{} + +func (*AlterColumnDropDefault) isAlterColumnAction() {} + +func (s *AlterColumnDropDefault) GoString() string { return internal.GoString(*s) } + +// AlterColumnSetOptions represents ALTER TABLE table_name ALTER COLUMN column_name SET OPTIONS. +type AlterColumnSetOptions struct { + Options *Expr +} + +func (*AlterColumnSetOptions) isAlterColumnAction() {} + +func (s *AlterColumnSetOptions) GoString() string { return internal.GoString(*s) } + +// AlterColumnDropOptions represents ALTER TABLE table_name ALTER COLUMN column_name DROP OPTIONS. +type AlterColumnDropOptions struct{} + +func (*AlterColumnDropOptions) isAlterColumnAction() {} + +func (s *AlterColumnDropOptions) GoString() string { return internal.GoString(*s) } + +// AddConstraint represents ALTER TABLE table_name ADD CONSTRAINT. +type AddConstraint struct { + Constraint Constraint + NotValid bool +} + +func (*AddConstraint) isAlterTableAction() {} + +func (s *AddConstraint) GoString() string { return internal.GoString(*s) } + +// DropConstraint represents ALTER TABLE table_name DROP CONSTRAINT. +type DropConstraint struct { + Name *Ident +} + +func (*DropConstraint) isAlterTableAction() {} + +func (s *DropConstraint) GoString() string { return internal.GoString(*s) } + +// AlterConstraint represents ALTER TABLE table_name ALTER CONSTRAINT. +type AlterConstraint struct { + Name *Ident + Deferrable bool + InitiallyDeferred bool +} + +func (*AlterConstraint) isAlterTableAction() {} + +func (s *AlterConstraint) GoString() string { return internal.GoString(*s) } diff --git a/pkg/ddl/spanner/ddl_table_alter_test.go b/pkg/ddl/spanner/ddl_table_alter_test.go new file mode 100644 index 0000000..c16a919 --- /dev/null +++ b/pkg/ddl/spanner/ddl_table_alter_test.go @@ -0,0 +1,288 @@ +package spanner + +import ( + "fmt" + "testing" + + "github.com/kunitsucom/util.go/testing/assert" + "github.com/kunitsucom/util.go/testing/require" +) + +func Test_isAlterTableAction(t *testing.T) { + t.Parallel() + + (&RenameTable{}).isAlterTableAction() + (&RenameConstraint{}).isAlterTableAction() + (&RenameColumn{}).isAlterTableAction() + (&AddColumn{}).isAlterTableAction() + (&DropColumn{}).isAlterTableAction() + (&AlterColumn{}).isAlterTableAction() + (&AddConstraint{}).isAlterTableAction() + (&DropConstraint{}).isAlterTableAction() + (&AlterConstraint{}).isAlterTableAction() + (&AlterColumnDataType{}).isAlterColumnAction() + (&AlterColumnSetOptions{}).isAlterColumnAction() + (&AlterColumnDropOptions{}).isAlterColumnAction() +} + +func Test_isAlterColumnAction(t *testing.T) { + t.Parallel() + + (&AlterColumnDataType{}).isAlterColumnAction() + (&AlterColumnSetDefault{}).isAlterColumnAction() + (&AlterColumnDropDefault{}).isAlterColumnAction() +} + +func TestAlterTableStmt_String(t *testing.T) { + t.Parallel() + + t.Run("success,RenameTable", func(t *testing.T) { + t.Parallel() + + stmt := &AlterTableStmt{ + Name: &ObjectName{Name: &Ident{Name: "users", QuotationMark: `"`, Raw: `"users"`}}, + Action: &RenameTable{ + NewName: &ObjectName{Name: &Ident{Name: "accounts", QuotationMark: `"`, Raw: `"accounts"`}}, + }, + } + + expected := `ALTER TABLE "users" RENAME TO "accounts";` + "\n" + actual := stmt.String() + + if !assert.Equal(t, expected, actual) { + assert.Equal(t, fmt.Sprintf("%#v", expected), fmt.Sprintf("%#v", actual)) + } + t.Logf("✅: %s: stmt: %#v", t.Name(), stmt) + }) + + t.Run("success,RenameColumn", func(t *testing.T) { + t.Parallel() + + stmt := &AlterTableStmt{ + Name: &ObjectName{Name: &Ident{Name: "users", QuotationMark: `"`, Raw: `"users"`}}, + Action: &RenameColumn{Name: &Ident{Name: "name", QuotationMark: `"`, Raw: `"name"`}, NewName: &Ident{Name: "username", QuotationMark: `"`, Raw: `"username"`}}, + } + + expected := `ALTER TABLE "users" RENAME COLUMN "name" TO "username";` + "\n" + actual := stmt.String() + + if !assert.Equal(t, expected, actual) { + assert.Equal(t, fmt.Sprintf("%#v", expected), fmt.Sprintf("%#v", actual)) + } + t.Logf("✅: %s: stmt: %#v", t.Name(), stmt) + }) + + t.Run("success,RenameConstraint", func(t *testing.T) { + t.Parallel() + + stmt := &AlterTableStmt{ + Name: &ObjectName{Name: &Ident{Name: "users", QuotationMark: `"`, Raw: `"users"`}}, + Action: &RenameConstraint{Name: &Ident{Name: "users_pkey", QuotationMark: `"`, Raw: `"users_pkey"`}, NewName: &Ident{Name: "users_id_pkey", QuotationMark: `"`, Raw: `"users_id_pkey"`}}, + } + + expected := `ALTER TABLE "users" RENAME CONSTRAINT "users_pkey" TO "users_id_pkey";` + "\n" + actual := stmt.String() + + if !assert.Equal(t, expected, actual) { + assert.Equal(t, fmt.Sprintf("%#v", expected), fmt.Sprintf("%#v", actual)) + } + t.Logf("✅: %s: stmt: %#v", t.Name(), stmt) + }) + + t.Run("success,AddColumn", func(t *testing.T) { + t.Parallel() + + stmt := &AlterTableStmt{ + Name: &ObjectName{Name: &Ident{Name: "users", QuotationMark: `"`, Raw: `"users"`}}, + Action: &AddColumn{ + Column: &Column{ + Name: &Ident{Name: "age", QuotationMark: `"`, Raw: `"age"`}, + DataType: &DataType{Name: "INTEGER"}, + }, + }, + } + + expected := `ALTER TABLE "users" ADD COLUMN "age" INTEGER;` + "\n" + actual := stmt.String() + + if !assert.Equal(t, expected, actual) { + assert.Equal(t, fmt.Sprintf("%#v", expected), fmt.Sprintf("%#v", actual)) + } + t.Logf("✅: %s: stmt: %#v", t.Name(), stmt) + }) + + t.Run("success,DropColumn", func(t *testing.T) { + t.Parallel() + + stmt := &AlterTableStmt{ + Name: &ObjectName{Name: &Ident{Name: "users", QuotationMark: `"`, Raw: `"users"`}}, + Action: &DropColumn{Name: &Ident{Name: "age", QuotationMark: `"`, Raw: `"age"`}}, + } + + expected := `ALTER TABLE "users" DROP COLUMN "age";` + "\n" + actual := stmt.String() + + if !assert.Equal(t, expected, actual) { + assert.Equal(t, fmt.Sprintf("%#v", expected), fmt.Sprintf("%#v", actual)) + } + t.Logf("✅: %s: stmt: %#v", t.Name(), stmt) + }) + + t.Run("success,AlterColumnDataType", func(t *testing.T) { + t.Parallel() + + stmt := &AlterTableStmt{ + Name: &ObjectName{Name: &Ident{Name: "users", QuotationMark: `"`, Raw: `"users"`}}, + Action: &AlterColumn{ + Name: &Ident{Name: "age", QuotationMark: `"`, Raw: `"age"`}, + Action: &AlterColumnDataType{DataType: &DataType{Name: "INT64"}}, + }, + } + + expected := `ALTER TABLE "users" ALTER COLUMN "age" INT64;` + "\n" + actual := stmt.String() + + if !assert.Equal(t, expected, actual) { + assert.Equal(t, fmt.Sprintf("%#v", expected), fmt.Sprintf("%#v", actual)) + } + t.Logf("✅: %s: stmt: %#v", t.Name(), stmt) + }) + + t.Run("success,AlterColumnSetDefault", func(t *testing.T) { + t.Parallel() + + stmt := &AlterTableStmt{ + Name: &ObjectName{Name: &Ident{Name: "users", QuotationMark: `"`, Raw: `"users"`}}, + Action: &AlterColumn{ + Name: &Ident{Name: "age", QuotationMark: `"`, Raw: `"age"`}, + Action: &AlterColumnSetDefault{Default: &Default{Value: &Expr{[]*Ident{{Name: "0", Raw: "0"}}}}}, + }, + } + + expected := `ALTER TABLE "users" ALTER COLUMN "age" SET DEFAULT 0;` + "\n" + actual := stmt.String() + + if !assert.Equal(t, expected, actual) { + assert.Equal(t, fmt.Sprintf("%#v", expected), fmt.Sprintf("%#v", actual)) + } + t.Logf("✅: %s: stmt: %#v", t.Name(), stmt) + }) + + t.Run("success,AlterColumnDropDefault", func(t *testing.T) { + t.Parallel() + + stmt := &AlterTableStmt{ + Name: &ObjectName{Name: &Ident{Name: "users", QuotationMark: `"`, Raw: `"users"`}}, + Action: &AlterColumn{ + Name: &Ident{Name: "age", QuotationMark: `"`, Raw: `"age"`}, + Action: &AlterColumnDropDefault{}, + }, + } + + expected := `ALTER TABLE "users" ALTER COLUMN "age" DROP DEFAULT;` + "\n" + actual := stmt.String() + + if !assert.Equal(t, expected, actual) { + assert.Equal(t, fmt.Sprintf("%#v", expected), fmt.Sprintf("%#v", actual)) + } + t.Logf("✅: %s: stmt: %#v", t.Name(), stmt) + }) + + t.Run("success,AddConstraint", func(t *testing.T) { + t.Parallel() + + stmt := &AlterTableStmt{ + Name: &ObjectName{Name: &Ident{Name: "groups", QuotationMark: `"`, Raw: `"groups"`}}, + Action: &AddConstraint{ + Constraint: &CheckConstraint{ + Name: &Ident{Name: "groups_yyyymmdd_chk", QuotationMark: `"`, Raw: `"groups_yyyymmdd_chk"`}, + Expr: &Expr{Idents: []*Ident{ + NewRawIdent("("), + NewRawIdent(`"yyyymmdd"`), + NewRawIdent(">"), + NewRawIdent("0"), + NewRawIdent(")"), + }}, + }, + }, + } + + expected := `ALTER TABLE "groups" ADD CONSTRAINT "groups_yyyymmdd_chk" CHECK ("yyyymmdd" > 0);` + "\n" + actual := stmt.String() + + if !assert.Equal(t, expected, actual) { + assert.Equal(t, fmt.Sprintf("%#v", expected), fmt.Sprintf("%#v", actual)) + } + t.Logf("✅: %s: stmt: %#v", t.Name(), stmt) + }) + + t.Run("success,DropConstraint", func(t *testing.T) { + t.Parallel() + + stmt := &AlterTableStmt{ + Name: &ObjectName{Name: &Ident{Name: "groups", QuotationMark: `"`, Raw: `"groups"`}}, + Action: &DropConstraint{Name: &Ident{Name: "groups_pkey", QuotationMark: `"`, Raw: `"groups_pkey"`}}, + } + + expected := `ALTER TABLE "groups" DROP CONSTRAINT "groups_pkey";` + "\n" + actual := stmt.String() + + if !assert.Equal(t, expected, actual) { + assert.Equal(t, fmt.Sprintf("%#v", expected), fmt.Sprintf("%#v", actual)) + } + t.Logf("✅: %s: stmt: %#v", t.Name(), stmt) + }) + + t.Run("success,AlterConstraint,DEFERRABLE", func(t *testing.T) { + t.Parallel() + + stmt := &AlterTableStmt{ + Name: &ObjectName{Name: &Ident{Name: "groups", QuotationMark: `"`, Raw: `"groups"`}}, + Action: &AlterConstraint{ + Name: &Ident{Name: "groups_pkey", QuotationMark: `"`, Raw: `"groups_pkey"`}, + Deferrable: true, + InitiallyDeferred: true, + }, + } + + expected := `ALTER TABLE "groups" ALTER CONSTRAINT "groups_pkey" DEFERRABLE INITIALLY DEFERRED;` + "\n" + actual := stmt.String() + + if !assert.Equal(t, expected, actual) { + assert.Equal(t, fmt.Sprintf("%#v", expected), fmt.Sprintf("%#v", actual)) + } + t.Logf("✅: %s: stmt: %#v", t.Name(), stmt) + }) + + t.Run("success,AlterConstraint,NOT_DEFERRABLE", func(t *testing.T) { + t.Parallel() + + stmt := &AlterTableStmt{ + Name: &ObjectName{Name: &Ident{Name: "groups", QuotationMark: `"`, Raw: `"groups"`}}, + Action: &AlterConstraint{ + Name: &Ident{Name: "groups_pkey", QuotationMark: `"`, Raw: `"groups_pkey"`}, + Deferrable: false, + InitiallyDeferred: false, + }, + } + + expected := `ALTER TABLE "groups" ALTER CONSTRAINT "groups_pkey" NOT DEFERRABLE INITIALLY IMMEDIATE;` + "\n" + actual := stmt.String() + + if !assert.Equal(t, expected, actual) { + assert.Equal(t, fmt.Sprintf("%#v", expected), fmt.Sprintf("%#v", actual)) + } + t.Logf("✅: %s: stmt: %#v", t.Name(), stmt) + }) +} + +func TestAlterTableStmt_GetNameForDiff(t *testing.T) { + t.Parallel() + + stmt := &AlterTableStmt{Name: &ObjectName{Name: &Ident{Name: "users", QuotationMark: `"`, Raw: `"users"`}}} + + expected := `users` + actual := stmt.GetNameForDiff() + + require.Equal(t, expected, actual) +} diff --git a/pkg/ddl/spanner/ddl_table_create.go b/pkg/ddl/spanner/ddl_table_create.go new file mode 100644 index 0000000..5d399ff --- /dev/null +++ b/pkg/ddl/spanner/ddl_table_create.go @@ -0,0 +1,83 @@ +package spanner + +import ( + "strings" + + "github.com/kunitsucom/ddlctl/pkg/ddl/internal" +) + +// MEMO: https://cloud.google.com/spanner/docs/reference/standard-sql/data-definition-language#create_table + +var _ Stmt = (*CreateTableStmt)(nil) + +type CreateTableStmt struct { + Comment string + Indent string + IfNotExists bool + Name *ObjectName + Columns []*Column + Constraints Constraints + Options Options +} + +func (s *CreateTableStmt) GetNameForDiff() string { + return s.Name.StringForDiff() +} + +//nolint:cyclop +func (s *CreateTableStmt) String() string { + var str string + if s.Comment != "" { + comments := strings.Split(s.Comment, "\n") + for i := range comments { + if comments[i] != "" { + str += CommentPrefix + comments[i] + "\n" + } + } + } + str += "CREATE TABLE " + if s.IfNotExists { + str += "IF NOT EXISTS " + } + str += s.Name.String() + " (\n" + lastIndex := len(s.Columns) - 1 + hasConstraint := len(s.Constraints) > 0 + for i, v := range s.Columns { + str += Indent + str += v.String() + if i != lastIndex || hasConstraint { + str += ",\n" + } else { + str += "\n" + } + } + if len(s.Constraints) > 0 { + lastConstraint := len(s.Constraints) - 1 + for i, v := range s.Constraints { + str += Indent + str += v.String() + if i != lastConstraint { + str += ",\n" + } else { + str += "\n" + } + } + } + str += ")" + if len(s.Options) > 0 { + str += " " + lastIndex := len(s.Options) - 1 + for i, v := range s.Options { + str += v.String() + if i != lastIndex { + str += ",\n" + } + } + } + + str += ";\n" + return str +} + +func (*CreateTableStmt) isStmt() {} +func (s *CreateTableStmt) GoString() string { return internal.GoString(*s) } diff --git a/pkg/ddl/spanner/ddl_table_create_test.go b/pkg/ddl/spanner/ddl_table_create_test.go new file mode 100644 index 0000000..b76b831 --- /dev/null +++ b/pkg/ddl/spanner/ddl_table_create_test.go @@ -0,0 +1,67 @@ +package spanner + +import ( + "testing" + + "github.com/kunitsucom/util.go/testing/assert" +) + +func TestCreateTableStmt_String(t *testing.T) { + t.Parallel() + + t.Run("success,", func(t *testing.T) { + t.Parallel() + + stmt := &CreateTableStmt{ + Comment: "test comment content", + Indent: " ", + Name: &ObjectName{Name: &Ident{Name: "test", Raw: "test"}}, + Columns: []*Column{ + {Name: &Ident{Name: "id", Raw: "id"}, DataType: &DataType{Name: "INT64"}}, + {Name: &Ident{Name: "name", Raw: "name"}, DataType: &DataType{Name: "STRING", Expr: &Expr{Idents: []*Ident{NewRawIdent("255")}}}}, + {Name: &Ident{Name: "created_at", Raw: "created_at"}, DataType: &DataType{Name: "TIMESTAMP"}, NotNull: true, Options: &Expr{Idents: []*Ident{ + NewRawIdent("("), + NewRawIdent("allow_commit_timestamp"), + NewRawIdent("="), + NewRawIdent("true"), + NewRawIdent(","), + NewRawIdent("option_name"), + NewRawIdent("="), + NewRawIdent("null"), + NewRawIdent(")"), + }}}, + }, + Options: []*Option{ + {Name: "PRIMARY KEY", Value: &Expr{Idents: []*Ident{NewRawIdent("("), NewRawIdent("id"), NewRawIdent(")")}}}, + }, + } + expected := `-- test comment content +CREATE TABLE test ( + id INT64, + name STRING(255), + created_at TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true, option_name = null) +) PRIMARY KEY (id); +` + + actual := stmt.String() + assert.Equal(t, expected, actual) + + t.Logf("✅: %s: stmt: %#v", t.Name(), stmt) + }) +} + +func TestCreateTableStmt_GetNameForDiff(t *testing.T) { + t.Parallel() + + t.Run("success,", func(t *testing.T) { + t.Parallel() + + stmt := &CreateTableStmt{Name: &ObjectName{Name: &Ident{Name: "test", QuotationMark: `"`, Raw: `"test"`}}} + expected := "test" + actual := stmt.GetNameForDiff() + + assert.Equal(t, expected, actual) + + t.Logf("✅: %s: stmt: %#v", t.Name(), stmt) + }) +} diff --git a/pkg/ddl/spanner/ddl_table_drop.go b/pkg/ddl/spanner/ddl_table_drop.go new file mode 100644 index 0000000..1f979ae --- /dev/null +++ b/pkg/ddl/spanner/ddl_table_drop.go @@ -0,0 +1,42 @@ +package spanner + +import ( + "strings" + + "github.com/kunitsucom/ddlctl/pkg/ddl/internal" +) + +// MEMO: https://www.postgresql.jp/docs/11/sql-createtable.html + +var _ Stmt = (*DropTableStmt)(nil) + +type DropTableStmt struct { + Comment string + IfExists bool + Name *ObjectName +} + +func (s *DropTableStmt) GetNameForDiff() string { + return s.Name.StringForDiff() +} + +func (s *DropTableStmt) String() string { + var str string + if s.Comment != "" { + comments := strings.Split(s.Comment, "\n") + for i := range comments { + if comments[i] != "" { + str += CommentPrefix + comments[i] + "\n" + } + } + } + str += "DROP TABLE " + if s.IfExists { + str += "IF EXISTS " + } + str += s.Name.String() + ";\n" + return str +} + +func (*DropTableStmt) isStmt() {} +func (s *DropTableStmt) GoString() string { return internal.GoString(*s) } diff --git a/pkg/ddl/spanner/ddl_table_drop_test.go b/pkg/ddl/spanner/ddl_table_drop_test.go new file mode 100644 index 0000000..13d9bf8 --- /dev/null +++ b/pkg/ddl/spanner/ddl_table_drop_test.go @@ -0,0 +1,44 @@ +package spanner + +import ( + "testing" + + "github.com/kunitsucom/util.go/testing/require" +) + +func TestDropTableStmt_GetNameForDiff(t *testing.T) { + t.Parallel() + + t.Run("success,", func(t *testing.T) { + t.Parallel() + + stmt := &DropTableStmt{Name: &ObjectName{Name: &Ident{Name: "test", QuotationMark: `"`, Raw: `"test"`}}} + expected := "test" + actual := stmt.GetNameForDiff() + + require.Equal(t, expected, actual) + + t.Logf("✅: %s: stmt: %#v", t.Name(), stmt) + }) +} + +func TestDropTableStmt_String(t *testing.T) { + t.Parallel() + + t.Run("success,", func(t *testing.T) { + t.Parallel() + + stmt := &DropTableStmt{ + Comment: "test comment content", + IfExists: true, + Name: &ObjectName{Name: &Ident{Name: "test", Raw: "test"}}, + } + expected := `-- test comment content +DROP TABLE IF EXISTS test;` + "\n" + actual := stmt.String() + + require.Equal(t, expected, actual) + + t.Logf("✅: %s: stmt: %#v", t.Name(), stmt) + }) +} diff --git a/pkg/ddl/spanner/ddl_table_test.go b/pkg/ddl/spanner/ddl_table_test.go new file mode 100644 index 0000000..103e799 --- /dev/null +++ b/pkg/ddl/spanner/ddl_table_test.go @@ -0,0 +1,233 @@ +package spanner + +import ( + "testing" + + "github.com/kunitsucom/util.go/testing/require" +) + +func Test_isConstraint(t *testing.T) { + t.Parallel() + + (&ForeignKeyConstraint{}).isConstraint() + (&CheckConstraint{}).isConstraint() +} + +func TestConstraints_Append(t *testing.T) { + t.Parallel() + + t.Run("success,Constraints,Append", func(t *testing.T) { + t.Parallel() + + constraints := Constraints{} + constraint := &CheckConstraint{ + Name: NewRawIdent(`"users_age_check"`), + Expr: &Expr{Idents: []*Ident{ + {Name: "(", QuotationMark: ``, Raw: `(`}, + {Name: "age", QuotationMark: `"`, Raw: `"age"`}, + {Name: ">=", QuotationMark: ``, Raw: `>=`}, + {Name: "0", QuotationMark: ``, Raw: `0`}, + {Name: ")", QuotationMark: ``, Raw: `)`}, + }}, + } + constraints = constraints.Append(constraint) + constraints = constraints.Append(constraint) + expected := Constraints{constraint} + actual := constraints + require.Equal(t, expected, actual) + }) +} + +func TestForeignKeyConstraint(t *testing.T) { + t.Parallel() + t.Run("success,ForeignKeyConstraint", func(t *testing.T) { + t.Parallel() + + foreignKeyConstraint := &ForeignKeyConstraint{ + Name: &Ident{Name: "fk_users_groups", QuotationMark: `"`, Raw: `"fk_users_groups"`}, + Columns: []*ColumnIdent{{Ident: &Ident{Name: "group_id", QuotationMark: `"`, Raw: `"group_id"`}}}, + Ref: &Ident{Name: "groups", QuotationMark: `"`, Raw: `"groups"`}, + RefColumns: []*ColumnIdent{{Ident: &Ident{Name: "id", QuotationMark: `"`, Raw: `"id"`}}}, + } + + expected := `CONSTRAINT "fk_users_groups" FOREIGN KEY ("group_id") REFERENCES "groups" ("id")` + actual := foreignKeyConstraint.String() + require.Equal(t, expected, actual) + + t.Logf("✅: %s: foreignKeyConstraint: %#v", t.Name(), foreignKeyConstraint) + }) +} + +func TestCheckConstraint(t *testing.T) { + t.Parallel() + t.Run("success,CheckConstraint", func(t *testing.T) { + t.Parallel() + + checkConstraint := &CheckConstraint{ + Name: &Ident{Name: "users_check_age", QuotationMark: `"`, Raw: `"users_check_age"`}, + Expr: &Expr{Idents: []*Ident{{Name: "(", QuotationMark: ``, Raw: `(`}, {Name: "age", QuotationMark: `"`, Raw: `"age"`}, {Name: ">=", QuotationMark: ``, Raw: `>=`}, {Name: "0", QuotationMark: ``, Raw: `0`}, {Name: ")", QuotationMark: ``, Raw: `)`}}}, + } + + expected := `CONSTRAINT "users_check_age" CHECK ("age" >= 0)` + actual := checkConstraint.String() + require.Equal(t, expected, actual) + + t.Logf("✅: %s: checkConstraint: %#v", t.Name(), checkConstraint) + }) +} + +func TestObjectName_StringForDiff(t *testing.T) { + t.Parallel() + + t.Run("success,ObjectName", func(t *testing.T) { + t.Parallel() + + objectName := &ObjectName{Schema: &Ident{Name: "public", QuotationMark: `"`, Raw: `"public"`}, Name: &Ident{Name: "users", QuotationMark: `"`, Raw: `"users"`}} + expected := "public.users" + actual := objectName.StringForDiff() + require.Equal(t, expected, actual) + + t.Logf("✅: %s: objectName: %#v", t.Name(), objectName) + }) + t.Run("success,ObjectName,empty", func(t *testing.T) { + t.Parallel() + + objectName := (*ObjectName)(nil) + expected := "" + actual := objectName.StringForDiff() + require.Equal(t, expected, actual) + + t.Logf("✅: %s: objectName: %#v", t.Name(), objectName) + }) +} + +func TestExpr_String(t *testing.T) { + t.Parallel() + + t.Run("success,String,nil", func(t *testing.T) { + t.Parallel() + + d := (*Default)(nil) + expected := "" + actual := d.String() + require.Equal(t, expected, actual) + }) + t.Run("success,String,nilnil", func(t *testing.T) { + t.Parallel() + + d := &Default{} + expected := "" + actual := d.String() + require.Equal(t, expected, actual) + }) + t.Run("success,PlainString,nilnil", func(t *testing.T) { + t.Parallel() + + d := &Default{} + expected := "" + actual := d.StringForDiff() + require.Equal(t, expected, actual) + }) + t.Run("success,DEFAULT_VALUE", func(t *testing.T) { + t.Parallel() + + d := &Default{Value: &Expr{[]*Ident{{Name: "now()", Raw: "now()"}}}} + expected := "DEFAULT now()" + actual := d.String() + require.Equal(t, expected, actual) + + t.Logf("✅: %s: d: %#v", t.Name(), d) + }) + t.Run("success,DEFAULT_VALUE,empty", func(t *testing.T) { + t.Parallel() + + d := (*Expr)(nil) + expected := "" + actual := d.String() + require.Equal(t, expected, actual) + }) + t.Run("success,DEFAULT_EXPR", func(t *testing.T) { + t.Parallel() + + d := &Default{Value: &Expr{[]*Ident{{Name: "(", Raw: "("}, {Name: "age", Raw: "age"}, {Name: ">=", Raw: ">="}, {Name: "0", Raw: "0"}, {Name: ")", Raw: ")"}}}} + expected := "DEFAULT (age >= 0)" + actual := d.String() + require.Equal(t, expected, actual) + + t.Logf("✅: %s: d: %#v", t.Name(), d) + }) +} + +func TestColumn(t *testing.T) { + t.Parallel() + + t.Run("success,Column", func(t *testing.T) { + t.Parallel() + + column := &Column{ + Name: &Ident{Name: "id", QuotationMark: `"`, Raw: `"id"`}, + DataType: &DataType{Name: "INTEGER"}, + } + + expected := `"id" INTEGER` + actual := column.String() + require.Equal(t, expected, actual) + + t.Logf("✅: %s: column: %#v", t.Name(), column) + }) +} + +func TestOption(t *testing.T) { + t.Parallel() + + t.Run("success,Option", func(t *testing.T) { + t.Parallel() + + option := &Option{Name: "PRIMARY KEY", Value: &Expr{Idents: []*Ident{NewRawIdent("("), NewRawIdent(`"id1"`), NewRawIdent(`,`), NewRawIdent(`"id2"`), NewRawIdent(")")}}} + + expected := `PRIMARY KEY ("id1", "id2")` + actual := option.String() + require.Equal(t, expected, actual) + + expectedForDiff := `PRIMARY KEY ( id1 , id2 )` + actualForDiff := option.StringForDiff() + require.Equal(t, expectedForDiff, actualForDiff) + + t.Logf("✅: %s: option: %#v", t.Name(), option) + }) + + t.Run("success,Options", func(t *testing.T) { + t.Parallel() + + options := Options{ + &Option{Name: "PRIMARY KEY", Value: &Expr{Idents: []*Ident{NewRawIdent("("), NewRawIdent(`"id1"`), NewRawIdent(`,`), NewRawIdent(`"id2"`), NewRawIdent(")")}}}, + &Option{Name: "PRIMARY KEY", Value: &Expr{Idents: []*Ident{NewRawIdent("("), NewRawIdent(`"id1"`), NewRawIdent(`,`), NewRawIdent(`"id2"`), NewRawIdent(")")}}}, + } + + expected := `PRIMARY KEY ("id1", "id2"), +PRIMARY KEY ("id1", "id2")` + actual := options.String() + require.Equal(t, expected, actual) + + expectedForDiff := `PRIMARY KEY ( id1 , id2 ), PRIMARY KEY ( id1 , id2 )` + actualForDiff := options.StringForDiff() + require.Equal(t, expectedForDiff, actualForDiff) + + t.Logf("✅: %s: option: %#v", t.Name(), options) + }) + + t.Run("success,Option,empty", func(t *testing.T) { + t.Parallel() + + option := &Option{} + expectedString := "" + actualString := option.String() + require.Equal(t, expectedString, actualString) + + expectedStringForDiff := "" + actualStringForDiff := option.StringForDiff() + require.Equal(t, expectedStringForDiff, actualStringForDiff) + + t.Logf("✅: %s: option: %#v", t.Name(), option) + }) +} diff --git a/pkg/ddl/spanner/ddl_test.go b/pkg/ddl/spanner/ddl_test.go new file mode 100644 index 0000000..e3c3576 --- /dev/null +++ b/pkg/ddl/spanner/ddl_test.go @@ -0,0 +1,107 @@ +package spanner + +import ( + "testing" + + "github.com/kunitsucom/util.go/testing/require" +) + +func Test_isStmt(t *testing.T) { + t.Parallel() + + (&CreateTableStmt{}).isStmt() + (&DropTableStmt{}).isStmt() + (&AlterTableStmt{}).isStmt() + (&CreateIndexStmt{}).isStmt() + (&DropIndexStmt{}).isStmt() +} + +func TestIdent_String(t *testing.T) { + t.Parallel() + + t.Run("success", func(t *testing.T) { + t.Parallel() + + ident := &Ident{Name: "users", QuotationMark: `"`, Raw: `"users"`} + expected := ident.Raw + actual := ident.String() + + require.Equal(t, expected, actual) + + t.Logf("✅: %s: ident: %#v", t.Name(), ident) + }) + + t.Run("success,empty", func(t *testing.T) { + t.Parallel() + + ident := (*Ident)(nil) + expected := "" + actual := ident.String() + + require.Equal(t, expected, actual) + + t.Logf("✅: %s: ident: %#v", t.Name(), ident) + }) +} + +func TestIdent_StringForDiff(t *testing.T) { + t.Parallel() + + t.Run("success", func(t *testing.T) { + t.Parallel() + ident := &Ident{Name: "users", QuotationMark: `"`, Raw: `"users"`} + expected := ident.Name + actual := ident.StringForDiff() + + require.Equal(t, expected, actual) + }) + + t.Run("success,empty", func(t *testing.T) { + t.Parallel() + ident := (*Ident)(nil) + expected := "" + actual := ident.StringForDiff() + + require.Equal(t, expected, actual) + }) +} + +func TestDataType_StringForDiff(t *testing.T) { + t.Parallel() + + t.Run("success", func(t *testing.T) { + t.Parallel() + dataType := &DataType{Name: "INT64", Type: TOKEN_INT64, Expr: &Expr{Idents: []*Ident{}}} + expected := string(TOKEN_INT64) + actual := dataType.StringForDiff() + + require.Equal(t, expected, actual) + }) + + t.Run("success,nil", func(t *testing.T) { + t.Parallel() + dataType := (*DataType)(nil) + expected := "" + actual := dataType.StringForDiff() + + require.Equal(t, expected, actual) + }) + + t.Run("success,TOKEN_ILLEGAL", func(t *testing.T) { + t.Parallel() + dataType := &DataType{Name: "unknown", Type: TOKEN_ILLEGAL, Expr: &Expr{Idents: []*Ident{}}} + expected := string(TOKEN_ILLEGAL) + actual := dataType.StringForDiff() + + require.Equal(t, expected, actual) + }) + + t.Run("success,empty", func(t *testing.T) { + t.Parallel() + dataType := &DataType{Name: "unknown", Type: "", Expr: &Expr{Idents: []*Ident{}}} + expected := string(TOKEN_ILLEGAL) + actual := dataType.StringForDiff() + + require.Equal(t, expected, actual) + }) +} diff --git a/pkg/ddl/spanner/diff.go b/pkg/ddl/spanner/diff.go new file mode 100644 index 0000000..e960104 --- /dev/null +++ b/pkg/ddl/spanner/diff.go @@ -0,0 +1,126 @@ +package spanner + +import ( + "reflect" + + errorz "github.com/kunitsucom/util.go/errors" + "github.com/kunitsucom/util.go/exp/diff/simplediff" + + apperr "github.com/kunitsucom/ddlctl/pkg/apperr" + + "github.com/kunitsucom/ddlctl/pkg/ddl" +) + +//nolint:funlen,cyclop,gocognit +func Diff(before, after *DDL) (*DDL, error) { + result := &DDL{} + + switch { + case before == nil && after != nil: + result.Stmts = append(result.Stmts, after.Stmts...) + return result, nil + case before != nil && after == nil: + for _, stmt := range before.Stmts { + switch s := stmt.(type) { + case *CreateTableStmt: + result.Stmts = append(result.Stmts, &DropTableStmt{ + Name: s.Name, + }) + case *CreateIndexStmt: + result.Stmts = append(result.Stmts, &DropIndexStmt{ + Name: s.Name, + }) + default: + return nil, apperr.Errorf("%s: %T: %w", s.GetNameForDiff(), s, ddl.ErrNotSupported) + } + } + return result, nil + case (before == nil && after == nil) || reflect.DeepEqual(before, after) || before.String() == after.String(): + return nil, ddl.ErrNoDifference + } + + // DROP TABLE table_name; + for _, stmt := range onlyLeftStmt(before, after) { + switch beforeStmt := stmt.(type) { + case *CreateTableStmt: + result.Stmts = append(result.Stmts, &DropTableStmt{ + Name: beforeStmt.Name, + }) + case *CreateIndexStmt: + result.Stmts = append(result.Stmts, &DropIndexStmt{ + Name: beforeStmt.Name, + }) + default: + return nil, apperr.Errorf("%s: %T: %w", beforeStmt.GetNameForDiff(), beforeStmt, ddl.ErrNotSupported) + } + } + + // CREATE TABLE table_name + for _, stmt := range onlyLeftStmt(after, before) { + switch afterStmt := stmt.(type) { + case *CreateTableStmt: + result.Stmts = append(result.Stmts, afterStmt) + case *CreateIndexStmt: + result.Stmts = append(result.Stmts, afterStmt) + default: + return nil, apperr.Errorf("%s: %T: %w", afterStmt.GetNameForDiff(), afterStmt, ddl.ErrNotSupported) + } + } + + // ALTER TABLE table_name ... + // DROP INDEX index_name; CREATE INDEX index_name ... + for _, beforeStmt := range before.Stmts { + switch beforeStmt := beforeStmt.(type) { //nolint:gocritic + case *CreateTableStmt: + if afterStmt := findStmtByTypeAndName(beforeStmt, after.Stmts); afterStmt != nil { + afterStmt := afterStmt.(*CreateTableStmt) //nolint:forcetypeassert + alterStmt, err := DiffCreateTable(beforeStmt, afterStmt) + if err == nil { + result.Stmts = append(result.Stmts, alterStmt.Stmts...) + } + errorz.PanicOrIgnore(err, ddl.ErrNoDifference) // MEMO: If before and after table_name is match, DiffCreateTable does not return error except ddl.ErrNoDifference. + continue + } + case *CreateIndexStmt: + if afterStmt := findStmtByTypeAndName(beforeStmt, after.Stmts); afterStmt != nil { + afterStmt := afterStmt.(*CreateIndexStmt) //nolint:forcetypeassert + if beforeStmt.StringForDiff() != afterStmt.StringForDiff() { + result.Stmts = append(result.Stmts, + &DropIndexStmt{ + Comment: simplediff.Diff(beforeStmt.StringForDiff(), afterStmt.StringForDiff()).String(), + Name: beforeStmt.Name, + }, + afterStmt, + ) + } + } + } + } + + if len(result.Stmts) == 0 { + return nil, ddl.ErrNoDifference + } + + return result, nil +} + +func onlyLeftStmt(left, right *DDL) []Stmt { + result := make([]Stmt, 0) + + for _, stmt := range left.Stmts { + if findStmtByTypeAndName(stmt, right.Stmts) == nil { + result = append(result, stmt) + } + } + + return result +} + +func findStmtByTypeAndName(stmt Stmt, stmts []Stmt) Stmt { //nolint:ireturn + for _, s := range stmts { + if reflect.TypeOf(stmt) == reflect.TypeOf(s) && stmt.GetNameForDiff() == s.GetNameForDiff() { + return s + } + } + return nil +} diff --git a/pkg/ddl/spanner/diff_create_table.go b/pkg/ddl/spanner/diff_create_table.go new file mode 100644 index 0000000..4f53ec0 --- /dev/null +++ b/pkg/ddl/spanner/diff_create_table.go @@ -0,0 +1,278 @@ +package spanner + +import ( + "reflect" + + "github.com/kunitsucom/util.go/exp/diff/simplediff" + + apperr "github.com/kunitsucom/ddlctl/pkg/apperr" + "github.com/kunitsucom/ddlctl/pkg/ddl" +) + +type DiffCreateTableConfig struct { + UseAlterTableAddConstraintNotValid bool +} + +type DiffCreateTableOption interface { + apply(c *DiffCreateTableConfig) +} + +func DiffCreateTableUseAlterTableAddConstraintNotValid(notValid bool) DiffCreateTableOption { //nolint:ireturn + return &diffCreateTableConfigUseConstraintNotValid{ + useAlterTableAddConstraintNotValid: notValid, + } +} + +type diffCreateTableConfigUseConstraintNotValid struct { + useAlterTableAddConstraintNotValid bool +} + +func (o *diffCreateTableConfigUseConstraintNotValid) apply(c *DiffCreateTableConfig) { + c.UseAlterTableAddConstraintNotValid = o.useAlterTableAddConstraintNotValid +} + +//nolint:funlen,cyclop +func DiffCreateTable(before, after *CreateTableStmt, opts ...DiffCreateTableOption) (*DDL, error) { + config := &DiffCreateTableConfig{} + + for _, opt := range opts { + opt.apply(config) + } + + result := &DDL{} + + switch { + case before == nil && after != nil: + // CREATE TABLE table_name + result.Stmts = append(result.Stmts, after) + return result, nil + case before != nil && after == nil: + // DROP TABLE table_name; + result.Stmts = append(result.Stmts, &DropTableStmt{ + Name: before.Name, + }) + return result, nil + case before.Options.StringForDiff() != after.Options.StringForDiff(): + result.Stmts = append(result.Stmts, + &DropTableStmt{ + Comment: simplediff.Diff(before.Options.String(), after.Options.String()).String(), + Name: before.Name, + }, + after, + ) + return result, nil + case (before == nil && after == nil) || reflect.DeepEqual(before, after) || before.String() == after.String(): + return nil, apperr.Errorf("before: %s, after: %s: %w", before.GetNameForDiff(), after.GetNameForDiff(), ddl.ErrNoDifference) + } + + if before.Name.StringForDiff() != after.Name.StringForDiff() { + // ALTER TABLE table_name RENAME TO new_table_name; + rename := &RenameTable{ + NewName: after.Name, + } + if rename.NewName.Schema == nil { + rename.NewName.Schema = before.Name.Schema + } + result.Stmts = append(result.Stmts, &AlterTableStmt{ + Comment: simplediff.Diff(before.Name.StringForDiff(), after.Name.StringForDiff()).String(), + Name: before.Name, + Action: rename, + }) + } + + for _, beforeConstraint := range before.Constraints { + afterConstraint := findConstraintByName(beforeConstraint.GetName().Name, after.Constraints) + if afterConstraint == nil { + // ALTER TABLE table_name DROP CONSTRAINT constraint_name; + result.Stmts = append(result.Stmts, &AlterTableStmt{ + Comment: simplediff.Diff(beforeConstraint.String(), "").String(), + Name: after.Name, // ALTER TABLE RENAME TO で変更された後の可能性があるため after.Name を使用する + Action: &DropConstraint{ + Name: beforeConstraint.GetName(), + }, + }) + continue + } + } + + config.diffCreateTableColumn(result, before, after) + + for _, beforeConstraint := range before.Constraints { + afterConstraint := findConstraintByName(beforeConstraint.GetName().Name, after.Constraints) + if afterConstraint != nil { + if beforeConstraint.StringForDiff() != afterConstraint.StringForDiff() { + // ALTER TABLE table_name DROP CONSTRAINT constraint_name; + // ALTER TABLE table_name ADD CONSTRAINT constraint_name constraint; + result.Stmts = append( + result.Stmts, + &AlterTableStmt{ + Comment: simplediff.Diff(beforeConstraint.String(), "").String(), + Name: after.Name, // ALTER TABLE RENAME TO で変更された後の可能性があるため after.Name を使用する + Action: &DropConstraint{ + Name: beforeConstraint.GetName(), + }, + }, + &AlterTableStmt{ + Comment: simplediff.Diff("", afterConstraint.String()).String(), + Name: after.Name, + Action: &AddConstraint{ + Constraint: afterConstraint, + NotValid: config.UseAlterTableAddConstraintNotValid, + }, + }, + ) + } + continue + } + } + + for _, afterConstraint := range onlyLeftConstraint(after.Constraints, before.Constraints) { + // ALTER TABLE table_name ADD CONSTRAINT constraint_name constraint; + result.Stmts = append(result.Stmts, &AlterTableStmt{ + Comment: simplediff.Diff("", afterConstraint.String()).String(), + Name: after.Name, + Action: &AddConstraint{ + Constraint: afterConstraint, + NotValid: config.UseAlterTableAddConstraintNotValid, + }, + }) + } + + if len(result.Stmts) == 0 { + return nil, apperr.Errorf("before: %s, after: %s: %w", before.GetNameForDiff(), after.GetNameForDiff(), ddl.ErrNoDifference) + } + + return result, nil +} + +//nolint:funlen,cyclop +func (config *DiffCreateTableConfig) diffCreateTableColumn(ddls *DDL, before, after *CreateTableStmt) { + for _, beforeColumn := range before.Columns { + afterColumn := findColumnByName(beforeColumn.Name.Name, after.Columns) + if afterColumn == nil { + // ALTER TABLE table_name DROP COLUMN column_name; + ddls.Stmts = append(ddls.Stmts, &AlterTableStmt{ + Comment: simplediff.Diff(beforeColumn.String(), "").String(), + Name: after.Name, // ALTER TABLE RENAME TO で変更された後の可能性があるため after.Name を使用する + Action: &DropColumn{ + Name: beforeColumn.Name, + }, + }) + continue + } + + if beforeColumn.DataType.StringForDiff() != afterColumn.DataType.StringForDiff() || + beforeColumn.NotNull && !afterColumn.NotNull || + !beforeColumn.NotNull && afterColumn.NotNull { + // ALTER TABLE table_name ALTER COLUMN column_name data_type NOT NULL; + ddls.Stmts = append(ddls.Stmts, &AlterTableStmt{ + Comment: simplediff.Diff(beforeColumn.String(), afterColumn.String()).String(), + Name: after.Name, + Action: &AlterColumn{ + Name: afterColumn.Name, + Action: &AlterColumnDataType{ + DataType: afterColumn.DataType, + NotNull: afterColumn.NotNull, + }, + }, + }) + } + + switch { + case beforeColumn.Default != nil && afterColumn.Default == nil: + // ALTER TABLE table_name ALTER COLUMN column_name DROP DEFAULT; + ddls.Stmts = append(ddls.Stmts, &AlterTableStmt{ + Comment: simplediff.Diff(beforeColumn.String(), afterColumn.String()).String(), + Name: after.Name, + Action: &AlterColumn{ + Name: afterColumn.Name, + Action: &AlterColumnDropDefault{}, + }, + }) + case afterColumn.Default != nil && beforeColumn.Default.StringForDiff() != afterColumn.Default.StringForDiff(): + // ALTER TABLE table_name ALTER COLUMN column_name SET DEFAULT default_value; + ddls.Stmts = append(ddls.Stmts, &AlterTableStmt{ + Comment: simplediff.Diff(beforeColumn.String(), afterColumn.String()).String(), + Name: after.Name, + Action: &AlterColumn{ + Name: afterColumn.Name, + Action: &AlterColumnSetDefault{Default: afterColumn.Default}, + }, + }) + } + + switch { + case beforeColumn.Options != nil && afterColumn.Options == nil: + // ALTER TABLE table_name ALTER COLUMN column_name DROP OPTIONS; + ddls.Stmts = append(ddls.Stmts, &AlterTableStmt{ + Comment: simplediff.Diff(beforeColumn.String(), afterColumn.String()).String(), + Name: after.Name, + Action: &AlterColumn{ + Name: afterColumn.Name, + Action: &AlterColumnDropOptions{}, + }, + }) + case afterColumn.Options != nil && beforeColumn.Options.StringForDiff() != afterColumn.Options.StringForDiff(): + // ALTER TABLE table_name ALTER COLUMN column_name SET OPTIONS (option_name = option_value); + ddls.Stmts = append(ddls.Stmts, &AlterTableStmt{ + Comment: simplediff.Diff(beforeColumn.String(), afterColumn.String()).String(), + Name: after.Name, + Action: &AlterColumn{ + Name: afterColumn.Name, + Action: &AlterColumnSetOptions{Options: afterColumn.Options}, + }, + }) + } + } + + for _, afterColumn := range onlyLeftColumn(after.Columns, before.Columns) { + // ALTER TABLE table_name ADD COLUMN column_name data_type; + ddls.Stmts = append(ddls.Stmts, &AlterTableStmt{ + Comment: simplediff.Diff("", afterColumn.String()).String(), + Name: after.Name, + Action: &AddColumn{ + Column: afterColumn, + }, + }) + } +} + +func onlyLeftColumn(left, right []*Column) []*Column { + onlyLeftColumns := make([]*Column, 0) + for _, leftColumn := range left { + foundColumnByRight := findColumnByName(leftColumn.Name.Name, right) + if foundColumnByRight == nil { + onlyLeftColumns = append(onlyLeftColumns, leftColumn) + } + } + return onlyLeftColumns +} + +func findColumnByName(name string, columns []*Column) *Column { + for _, column := range columns { + if column.Name.Name == name { + return column + } + } + return nil +} + +func onlyLeftConstraint(left, right Constraints) []Constraint { + onlyLeftConstraints := make(Constraints, 0) + for _, leftConstraint := range left { + foundConstraintByRight := findConstraintByName(leftConstraint.GetName().Name, right) + if foundConstraintByRight == nil { + onlyLeftConstraints = onlyLeftConstraints.Append(leftConstraint) + } + } + return onlyLeftConstraints +} + +func findConstraintByName(name string, constraints []Constraint) Constraint { //nolint:ireturn + for _, constraint := range constraints { + if constraint.GetName().Name == name { + return constraint + } + } + return nil +} diff --git a/pkg/ddl/spanner/diff_create_table_test.go b/pkg/ddl/spanner/diff_create_table_test.go new file mode 100644 index 0000000..f8691e0 --- /dev/null +++ b/pkg/ddl/spanner/diff_create_table_test.go @@ -0,0 +1,561 @@ +package spanner + +import ( + "testing" + + "github.com/kunitsucom/util.go/testing/assert" + "github.com/kunitsucom/util.go/testing/require" + + "github.com/kunitsucom/ddlctl/pkg/ddl" +) + +//nolint:paralleltest,tparallel +func TestDiffCreateTable(t *testing.T) { + t.Run("failure,ddl.ErrNoDifference", func(t *testing.T) { + t.Parallel() + + before := `CREATE TABLE "users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL, "name" STRING(255) NOT NULL, description STRING, CONSTRAINT users_group_id_fkey FOREIGN KEY (group_id) REFERENCES "groups" ("id")) PRIMARY KEY ("id");` + beforeDDL, err := NewParser(NewLexer(before)).Parse() + require.NoError(t, err) + + after := `CREATE TABLE "users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL, "name" STRING(255) NOT NULL, description STRING, CONSTRAINT users_group_id_fkey FOREIGN KEY (group_id) REFERENCES "groups" ("id")) PRIMARY KEY ("id");` + + afterDDL, err := NewParser(NewLexer(after)).Parse() + require.NoError(t, err) + + actual, err := DiffCreateTable( + beforeDDL.Stmts[0].(*CreateTableStmt), + afterDDL.Stmts[0].(*CreateTableStmt), + DiffCreateTableUseAlterTableAddConstraintNotValid(false), + ) + + assert.ErrorIs(t, err, ddl.ErrNoDifference) + assert.Nil(t, actual) + + t.Logf("✅: %s: actual: %%#v: \n%#v", t.Name(), actual) + t.Logf("✅: %s: actual: %%s: \n%s", t.Name(), actual) + }) + + t.Run("failure,ddl.ErrNoDifference,SameContent", func(t *testing.T) { + t.Parallel() + + before := `CREATE TABLE "users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL, "name" STRING(255) NOT NULL, description STRING, CONSTRAINT users_group_id_fkey FOREIGN KEY (group_id) REFERENCES "groups" ("id")) PRIMARY KEY ("id");` + beforeDDL, err := NewParser(NewLexer(before)).Parse() + require.NoError(t, err) + + after := `CREATE TABLE users (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL, name STRING(255) NOT NULL, description STRING, CONSTRAINT users_group_id_fkey FOREIGN KEY (group_id) REFERENCES groups (id)) PRIMARY KEY (id);` + + afterDDL, err := NewParser(NewLexer(after)).Parse() + require.NoError(t, err) + + actual, err := DiffCreateTable( + beforeDDL.Stmts[0].(*CreateTableStmt), + afterDDL.Stmts[0].(*CreateTableStmt), + DiffCreateTableUseAlterTableAddConstraintNotValid(false), + ) + + assert.ErrorIs(t, err, ddl.ErrNoDifference) + assert.Nil(t, actual) + + t.Logf("✅: %s: actual: %%#v: \n%#v", t.Name(), actual) + t.Logf("✅: %s: actual: %%s: \n%s", t.Name(), actual) + }) + + t.Run("success,ADD_COLUMN", func(t *testing.T) { + t.Parallel() + + before := `CREATE TABLE "users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL, "name" STRING(255) NOT NULL, description STRING, CONSTRAINT users_group_id_fkey FOREIGN KEY (group_id) REFERENCES "groups" ("id")) PRIMARY KEY ("id");` + beforeDDL, err := NewParser(NewLexer(before)).Parse() + require.NoError(t, err) + + after := `CREATE TABLE "users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL, "name" STRING(255) NOT NULL, "age" INT64 DEFAULT 0 NOT NULL CHECK ("age" >= 0), description STRING, CONSTRAINT users_group_id_fkey FOREIGN KEY (group_id) REFERENCES "groups" ("id")) PRIMARY KEY ("id");` + + afterDDL, err := NewParser(NewLexer(after)).Parse() + require.NoError(t, err) + + actual, err := DiffCreateTable( + beforeDDL.Stmts[0].(*CreateTableStmt), + afterDDL.Stmts[0].(*CreateTableStmt), + DiffCreateTableUseAlterTableAddConstraintNotValid(false), + ) + + expectedStr := `-- - +-- +"age" INT64 NOT NULL DEFAULT 0 +ALTER TABLE "users" ADD COLUMN "age" INT64 NOT NULL DEFAULT 0; +-- - +-- +CONSTRAINT users_age_check CHECK ("age" >= 0) +ALTER TABLE "users" ADD CONSTRAINT users_age_check CHECK ("age" >= 0); +` + + assert.NoError(t, err) + assert.Equal(t, expectedStr, actual.String()) + + t.Logf("✅: %s: actual: %%#v:\n%#v", t.Name(), actual) + }) + + t.Run("success,DROP_COLUMN", func(t *testing.T) { + t.Parallel() + + before := `CREATE TABLE "users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL REFERENCES "groups" ("id"), "name" STRING(255) NOT NULL, "age" INT64 DEFAULT 0 NOT NULL CHECK ("age" >= 0), description STRING) PRIMARY KEY ("id");` + beforeDDL, err := NewParser(NewLexer(before)).Parse() + require.NoError(t, err) + + after := `CREATE TABLE "users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL REFERENCES "groups" ("id"), "name" STRING(255) NOT NULL, description STRING) PRIMARY KEY ("id");` + + afterDDL, err := NewParser(NewLexer(after)).Parse() + require.NoError(t, err) + + actual, err := DiffCreateTable( + beforeDDL.Stmts[0].(*CreateTableStmt), + afterDDL.Stmts[0].(*CreateTableStmt), + DiffCreateTableUseAlterTableAddConstraintNotValid(false), + ) + + expectedStr := `-- -CONSTRAINT users_age_check CHECK ("age" >= 0) +-- + +ALTER TABLE "users" DROP CONSTRAINT users_age_check; +-- -"age" INT64 NOT NULL DEFAULT 0 +-- + +ALTER TABLE "users" DROP COLUMN "age"; +` + + assert.NoError(t, err) + assert.Equal(t, expectedStr, actual.String()) + + t.Logf("✅: %s: actual: %%#v:\n%#v", t.Name(), actual) + }) + + t.Run("success,ALTER_COLUMN_SET_DATA_TYPE", func(t *testing.T) { + t.Parallel() + + before := `CREATE TABLE "users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL REFERENCES "groups" ("id"), "name" STRING(255) NOT NULL, "age" INT64 DEFAULT 0 CHECK ("age" >= 0), description STRING) PRIMARY KEY ("id");` + beforeDDL, err := NewParser(NewLexer(before)).Parse() + require.NoError(t, err) + + after := `CREATE TABLE "users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL REFERENCES "groups" ("id"), "name" STRING NOT NULL, "age" INT64 DEFAULT 0 CHECK ("age" >= 0), description STRING) PRIMARY KEY ("id");` + + afterDDL, err := NewParser(NewLexer(after)).Parse() + require.NoError(t, err) + + actual, err := DiffCreateTable( + beforeDDL.Stmts[0].(*CreateTableStmt), + afterDDL.Stmts[0].(*CreateTableStmt), + DiffCreateTableUseAlterTableAddConstraintNotValid(false), + ) + + expectedStr := `-- -"name" STRING(255) NOT NULL +-- +"name" STRING NOT NULL +ALTER TABLE "users" ALTER COLUMN "name" STRING NOT NULL; +` + + assert.NoError(t, err) + assert.Equal(t, expectedStr, actual.String()) + + t.Logf("✅: %s: actual: %%#v:\n%#v", t.Name(), actual) + }) + + t.Run("success,ALTER_COLUMN_DROP_DEFAULT", func(t *testing.T) { + before := `CREATE TABLE "users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL, "name" STRING(255) NOT NULL, "age" INT64 DEFAULT 0 CHECK ("age" >= 0), description STRING, CONSTRAINT users_group_id_fkey FOREIGN KEY (group_id) REFERENCES "groups" ("id")) PRIMARY KEY ("id");` + beforeDDL, err := NewParser(NewLexer(before)).Parse() + require.NoError(t, err) + + after := `CREATE TABLE "users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL, "name" STRING(255) NOT NULL, "age" INT64 CHECK ("age" >= 0), description STRING, CONSTRAINT users_group_id_fkey FOREIGN KEY (group_id) REFERENCES "groups" ("id")) PRIMARY KEY ("id");` + afterDDL, err := NewParser(NewLexer(after)).Parse() + require.NoError(t, err) + + expectedStr := `-- -"age" INT64 DEFAULT 0 +-- +"age" INT64 +ALTER TABLE "users" ALTER COLUMN "age" DROP DEFAULT; +` + + actual, err := DiffCreateTable( + beforeDDL.Stmts[0].(*CreateTableStmt), + afterDDL.Stmts[0].(*CreateTableStmt), + DiffCreateTableUseAlterTableAddConstraintNotValid(false), + ) + assert.NoError(t, err) + assert.Equal(t, expectedStr, actual.String()) + + t.Logf("✅: %s: actual: %%#v:\n%#v", t.Name(), actual) + }) + + t.Run("success,ALTER_COLUMN_SET_DEFAULT", func(t *testing.T) { + before := `CREATE TABLE "users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL, "name" STRING(255) NOT NULL, "age" INT64 CHECK ("age" >= 0), description STRING, CONSTRAINT users_group_id_fkey FOREIGN KEY (group_id) REFERENCES "groups" ("id")) PRIMARY KEY ("id");` + beforeDDL, err := NewParser(NewLexer(before)).Parse() + require.NoError(t, err) + + after := `CREATE TABLE "users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL, "name" STRING(255) NOT NULL, "age" INT64 DEFAULT 0 CHECK ("age" <> 0), description STRING, CONSTRAINT users_group_id_fkey FOREIGN KEY (group_id) REFERENCES "groups" ("id")) PRIMARY KEY (id);` + afterDDL, err := NewParser(NewLexer(after)).Parse() + require.NoError(t, err) + + expectedStr := `-- -"age" INT64 +-- +"age" INT64 DEFAULT 0 +ALTER TABLE "users" ALTER COLUMN "age" SET DEFAULT 0; +-- -CONSTRAINT users_age_check CHECK ("age" >= 0) +-- + +ALTER TABLE "users" DROP CONSTRAINT users_age_check; +-- - +-- +CONSTRAINT users_age_check CHECK ("age" <> 0) +ALTER TABLE "users" ADD CONSTRAINT users_age_check CHECK ("age" <> 0); +` + + actual, err := DiffCreateTable( + beforeDDL.Stmts[0].(*CreateTableStmt), + afterDDL.Stmts[0].(*CreateTableStmt), + DiffCreateTableUseAlterTableAddConstraintNotValid(false), + ) + assert.NoError(t, err) + assert.Equal(t, expectedStr, actual.String()) + + t.Logf("✅: %s: actual: %%#v:\n%#v", t.Name(), actual) + }) + + t.Run("success,ALTER_TABLE_RENAME_TO", func(t *testing.T) { + t.Parallel() + + before := `CREATE TABLE "public.users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL REFERENCES "groups" ("id"), "name" STRING(255) NOT NULL, "age" INT64 DEFAULT 0 CHECK ("age" >= 0), description STRING) PRIMARY KEY ("id");` + beforeDDL, err := NewParser(NewLexer(before)).Parse() + require.NoError(t, err) + + after := `CREATE TABLE "app_users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL REFERENCES "groups" ("id"), "name" STRING(255) NOT NULL, "age" INT64 DEFAULT 0 CHECK ("age" >= 0), description STRING) PRIMARY KEY ("id");` + afterDDL, err := NewParser(NewLexer(after)).Parse() + require.NoError(t, err) + + expectedStr := `-- -public.users +-- +public.app_users +ALTER TABLE "public.users" RENAME TO "public.app_users"; +-- -CONSTRAINT users_group_id_fkey FOREIGN KEY (group_id) REFERENCES "groups" ("id") +-- + +ALTER TABLE "public.app_users" DROP CONSTRAINT users_group_id_fkey; +-- -CONSTRAINT users_age_check CHECK ("age" >= 0) +-- + +ALTER TABLE "public.app_users" DROP CONSTRAINT users_age_check; +-- - +-- +CONSTRAINT app_users_group_id_fkey FOREIGN KEY (group_id) REFERENCES "groups" ("id") +ALTER TABLE "public.app_users" ADD CONSTRAINT app_users_group_id_fkey FOREIGN KEY (group_id) REFERENCES "groups" ("id"); +-- - +-- +CONSTRAINT app_users_age_check CHECK ("age" >= 0) +ALTER TABLE "public.app_users" ADD CONSTRAINT app_users_age_check CHECK ("age" >= 0); +` + + actual, err := DiffCreateTable( + beforeDDL.Stmts[0].(*CreateTableStmt), + afterDDL.Stmts[0].(*CreateTableStmt), + DiffCreateTableUseAlterTableAddConstraintNotValid(false), + ) + assert.NoError(t, err) + assert.Equal(t, expectedStr, actual.String()) + + t.Logf("✅: %s: actual: %%#v: \n%#v", t.Name(), actual) + t.Logf("✅: %s: actual: %%s: \n%s", t.Name(), actual) + }) + + t.Run("success,SET_NOT_NULL", func(t *testing.T) { + t.Parallel() + + before := `CREATE TABLE "users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL REFERENCES "groups" ("id"), "name" STRING(255) NOT NULL, "age" INT64 DEFAULT 0 CHECK ("age" >= 0), description STRING) PRIMARY KEY ("id");` + beforeDDL, err := NewParser(NewLexer(before)).Parse() + require.NoError(t, err) + + after := `CREATE TABLE "users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL REFERENCES "groups" ("id"), "name" STRING(255) NOT NULL, "age" INT64 DEFAULT 0 NOT NULL CHECK ("age" >= 0), description STRING) PRIMARY KEY ("id");` + afterDDL, err := NewParser(NewLexer(after)).Parse() + require.NoError(t, err) + + expectedStr := `-- -"age" INT64 DEFAULT 0 +-- +"age" INT64 NOT NULL DEFAULT 0 +ALTER TABLE "users" ALTER COLUMN "age" INT64 NOT NULL; +` + + actual, err := DiffCreateTable( + beforeDDL.Stmts[0].(*CreateTableStmt), + afterDDL.Stmts[0].(*CreateTableStmt), + DiffCreateTableUseAlterTableAddConstraintNotValid(false), + ) + assert.NoError(t, err) + assert.Equal(t, expectedStr, actual.String()) + + t.Logf("✅: %s: actual: %%#v:\n%#v", t.Name(), actual) + }) + + t.Run("success,DROP_NOT_NULL", func(t *testing.T) { + t.Parallel() + + before := `CREATE TABLE "users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL REFERENCES "groups" ("id"), "name" STRING(255) NOT NULL, "age" INT64 DEFAULT 0 NOT NULL CHECK ("age" >= 0), description STRING) PRIMARY KEY ("id");` + beforeDDL, err := NewParser(NewLexer(before)).Parse() + require.NoError(t, err) + + after := `CREATE TABLE "users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL REFERENCES "groups" ("id"), "name" STRING(255) NOT NULL, "age" INT64 DEFAULT 0 CHECK ("age" >= 0), description STRING) PRIMARY KEY ("id");` + afterDDL, err := NewParser(NewLexer(after)).Parse() + require.NoError(t, err) + + expectedStr := `-- -"age" INT64 NOT NULL DEFAULT 0 +-- +"age" INT64 DEFAULT 0 +ALTER TABLE "users" ALTER COLUMN "age" INT64; +` + + actual, err := DiffCreateTable( + beforeDDL.Stmts[0].(*CreateTableStmt), + afterDDL.Stmts[0].(*CreateTableStmt), + DiffCreateTableUseAlterTableAddConstraintNotValid(false), + ) + assert.NoError(t, err) + assert.Equal(t, expectedStr, actual.String()) + + t.Logf("✅: %s: actual: %%#v:\n%#v", t.Name(), actual) + }) + + t.Run("success,ALTER_PRIMARY_KEY", func(t *testing.T) { + t.Parallel() + + before := `CREATE TABLE "users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL REFERENCES "groups" ("id"), "name" STRING(255) NOT NULL, "age" INT64 DEFAULT 0 NOT NULL CHECK ("age" >= 0), description STRING) PRIMARY KEY ("id");` + beforeDDL, err := NewParser(NewLexer(before)).Parse() + require.NoError(t, err) + + after := `CREATE TABLE "users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL REFERENCES "groups" ("id"), "name" STRING(255) NOT NULL, "age" INT64 DEFAULT 0 NOT NULL CHECK ("age" >= 0), description STRING) PRIMARY KEY ("id", name);` + afterDDL, err := NewParser(NewLexer(after)).Parse() + require.NoError(t, err) + + actual, err := DiffCreateTable( + beforeDDL.Stmts[0].(*CreateTableStmt), + afterDDL.Stmts[0].(*CreateTableStmt), + DiffCreateTableUseAlterTableAddConstraintNotValid(false), + ) + assert.NoError(t, err) + expected := `-- -PRIMARY KEY ("id") +-- +PRIMARY KEY ("id", name) +DROP TABLE "users"; +CREATE TABLE "users" ( + id STRING(36) NOT NULL, + group_id STRING(36) NOT NULL, + "name" STRING(255) NOT NULL, + "age" INT64 NOT NULL DEFAULT 0, + description STRING, + CONSTRAINT users_group_id_fkey FOREIGN KEY (group_id) REFERENCES "groups" ("id"), + CONSTRAINT users_age_check CHECK ("age" >= 0) +) PRIMARY KEY ("id", name); +` + + assert.Equal(t, expected, actual.String()) + }) + + t.Run("success,DROP_ADD_FOREIGN_KEY", func(t *testing.T) { + t.Parallel() + + before := `CREATE TABLE "users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL, "name" STRING(255) NOT NULL, "age" INT64 DEFAULT 0 NOT NULL CHECK ("age" >= 0), description STRING, CONSTRAINT users_group_id_fkey FOREIGN KEY (group_id) REFERENCES "groups" ("id")) PRIMARY KEY ("id");` + beforeDDL, err := NewParser(NewLexer(before)).Parse() + require.NoError(t, err) + + after := `CREATE TABLE "users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL, "name" STRING(255) NOT NULL, "age" INT64 DEFAULT 0 NOT NULL CHECK ("age" >= 0), description STRING, CONSTRAINT users_group_id_fkey FOREIGN KEY (group_id, name) REFERENCES "groups" ("id", name)) PRIMARY KEY ("id");` + afterDDL, err := NewParser(NewLexer(after)).Parse() + require.NoError(t, err) + + expectedStr := `-- -CONSTRAINT users_group_id_fkey FOREIGN KEY (group_id) REFERENCES "groups" ("id") +-- + +ALTER TABLE "users" DROP CONSTRAINT users_group_id_fkey; +-- - +-- +CONSTRAINT users_group_id_fkey FOREIGN KEY (group_id, name) REFERENCES "groups" ("id", name) +ALTER TABLE "users" ADD CONSTRAINT users_group_id_fkey FOREIGN KEY (group_id, name) REFERENCES "groups" ("id", name); +` + + actual, err := DiffCreateTable( + beforeDDL.Stmts[0].(*CreateTableStmt), + afterDDL.Stmts[0].(*CreateTableStmt), + DiffCreateTableUseAlterTableAddConstraintNotValid(false), + ) + assert.NoError(t, err) + assert.Equal(t, expectedStr, actual.String()) + + t.Logf("✅: %s: actual: %%#v:\n%#v", t.Name(), actual) + }) + + t.Run("success,DROP_ADD_UNIQUE", func(t *testing.T) { + t.Parallel() + + before := `CREATE TABLE "users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL, "name" STRING(255) NOT NULL, "age" INT64 DEFAULT 0 NOT NULL CHECK ("age" >= 0), description STRING, CONSTRAINT users_group_id_fkey FOREIGN KEY (group_id) REFERENCES "groups" ("id")) PRIMARY KEY ("id");` + beforeDDL, err := NewParser(NewLexer(before)).Parse() + require.NoError(t, err) + + after := `CREATE TABLE "users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL, "name" STRING(255) NOT NULL, "age" INT64 DEFAULT 0 NOT NULL CHECK ("age" >= 0), description STRING, CONSTRAINT users_group_id_fkey_2 FOREIGN KEY (group_id) REFERENCES "groups" ("id")) PRIMARY KEY ("id");` + afterDDL, err := NewParser(NewLexer(after)).Parse() + require.NoError(t, err) + + expectedStr := `-- -CONSTRAINT users_group_id_fkey FOREIGN KEY (group_id) REFERENCES "groups" ("id") +-- + +ALTER TABLE "users" DROP CONSTRAINT users_group_id_fkey; +-- - +-- +CONSTRAINT users_group_id_fkey_2 FOREIGN KEY (group_id) REFERENCES "groups" ("id") +ALTER TABLE "users" ADD CONSTRAINT users_group_id_fkey_2 FOREIGN KEY (group_id) REFERENCES "groups" ("id"); +` + + actual, err := DiffCreateTable( + beforeDDL.Stmts[0].(*CreateTableStmt), + afterDDL.Stmts[0].(*CreateTableStmt), + DiffCreateTableUseAlterTableAddConstraintNotValid(false), + ) + assert.NoError(t, err) + assert.Equal(t, expectedStr, actual.String()) + + t.Logf("✅: %s: actual: %%#v:\n%#v", t.Name(), actual) + }) + + t.Run("success,ALTER_COLUMN_SET_DEFAULT_OVERWRITE", func(t *testing.T) { + t.Parallel() + + before := `CREATE TABLE "users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL, "name" STRING(255) NOT NULL, "age" INT64 DEFAULT 0 NOT NULL CHECK ("age" >= 0), description STRING) PRIMARY KEY ("id");` + beforeDDL, err := NewParser(NewLexer(before)).Parse() + require.NoError(t, err) + + after := `CREATE TABLE "users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL, "name" STRING(255) NOT NULL, "age" INT64 DEFAULT ( (0 + 3) - 1 * 4 / 2 ) NOT NULL CHECK ("age" >= 0), description STRING) PRIMARY KEY ("id");` + afterDDL, err := NewParser(NewLexer(after)).Parse() + require.NoError(t, err) + + expectedStr := `-- -"age" INT64 NOT NULL DEFAULT 0 +-- +"age" INT64 NOT NULL DEFAULT ((0 + 3) - 1 * 4 / 2) +ALTER TABLE "users" ALTER COLUMN "age" SET DEFAULT ((0 + 3) - 1 * 4 / 2); +` + + actual, err := DiffCreateTable( + beforeDDL.Stmts[0].(*CreateTableStmt), + afterDDL.Stmts[0].(*CreateTableStmt), + DiffCreateTableUseAlterTableAddConstraintNotValid(false), + ) + assert.NoError(t, err) + assert.Equal(t, expectedStr, actual.String()) + + t.Logf("✅: %s: actual: %%#v:\n%#v", t.Name(), actual) + }) + + t.Run("success,ALTER_COLUMN_SET_DEFAULT_complex", func(t *testing.T) { + t.Parallel() + + before := `CREATE TABLE complex_defaults ( + id INT64, + created_at TIMESTAMP OPTIONS (allow_commit_timestamp=true, option_name=null), + updated_at TIMESTAMP, + unique_code STRING, + status STRING DEFAULT ('pending'), + random_number INT64 DEFAULT (FLOOR(RANDOM() * 100)), + json_data JSON DEFAULT ('{}'), + calculated_value INT64 DEFAULT (SELECT COUNT(*) FROM another_table) +) PRIMARY KEY (id); +` + beforeDDL, err := NewParser(NewLexer(before)).Parse() + require.NoError(t, err) + + after := `CREATE TABLE complex_defaults ( + id INT64, + created_at TIMESTAMP, + updated_at TIMESTAMP OPTIONS (allow_commit_timestamp=true, option_name=null), + unique_code STRING DEFAULT (GENERATE_UUID()), + status STRING DEFAULT ('pending'), + random_number INT64 DEFAULT (FLOOR(RANDOM() * 100)), + json_data JSON DEFAULT ('{}'), + calculated_value INT64 DEFAULT (SELECT COUNT(*) FROM another_table) +) PRIMARY KEY (id); +` + afterDDL, err := NewParser(NewLexer(after)).Parse() + require.NoError(t, err) + + expectedStr := `-- -created_at TIMESTAMP OPTIONS (allow_commit_timestamp = TRUE, option_name = NULL) +-- +created_at TIMESTAMP +ALTER TABLE complex_defaults ALTER COLUMN created_at DROP OPTIONS; +-- -updated_at TIMESTAMP +-- +updated_at TIMESTAMP OPTIONS (allow_commit_timestamp = TRUE, option_name = NULL) +ALTER TABLE complex_defaults ALTER COLUMN updated_at SET OPTIONS (allow_commit_timestamp = TRUE, option_name = NULL); +-- -unique_code STRING +-- +unique_code STRING DEFAULT (GENERATE_UUID()) +ALTER TABLE complex_defaults ALTER COLUMN unique_code SET DEFAULT (GENERATE_UUID()); +` + + actual, err := DiffCreateTable( + beforeDDL.Stmts[0].(*CreateTableStmt), + afterDDL.Stmts[0].(*CreateTableStmt), + DiffCreateTableUseAlterTableAddConstraintNotValid(false), + ) + assert.NoError(t, err) + assert.Equal(t, expectedStr, actual.String()) + + t.Logf("✅: %s: actual: %%#v:\n%#v", t.Name(), actual) + }) + + t.Run("success,DiffCreateTableUseAlterTableAddConstraintNotValid", func(t *testing.T) { + t.Parallel() + + beforeDDL, err := NewParser(NewLexer(`CREATE TABLE "users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL REFERENCES "groups" ("id"), "name" STRING(255) NOT NULL, "age" INT64 DEFAULT 0, description STRING) PRIMARY KEY ("id");`)).Parse() + require.NoError(t, err) + + afterDDL, err := NewParser(NewLexer(`CREATE TABLE "users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL REFERENCES "groups" ("id"), "name" STRING(255) NOT NULL, "age" INT64 DEFAULT 0 CHECK ("age" >= 0), description STRING) PRIMARY KEY ("id");`)).Parse() + require.NoError(t, err) + + expected := `-- - +-- +CONSTRAINT users_age_check CHECK ("age" >= 0) +ALTER TABLE "users" ADD CONSTRAINT users_age_check CHECK ("age" >= 0) NOT VALID; +` + actual, err := DiffCreateTable( + beforeDDL.Stmts[0].(*CreateTableStmt), + afterDDL.Stmts[0].(*CreateTableStmt), + DiffCreateTableUseAlterTableAddConstraintNotValid(true), + ) + + assert.NoError(t, err) + assert.Equal(t, expected, actual.String()) + + t.Logf("✅: %s: actual: %%#v: \n%#v", t.Name(), actual) + t.Logf("✅: %s: actual: %%s: \n%s", t.Name(), actual) + }) + + t.Run("success,CREATE_TABLE", func(t *testing.T) { + t.Parallel() + + afterDDL, err := NewParser(NewLexer(`CREATE TABLE "users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL REFERENCES "groups" ("id"), "name" STRING(255) NOT NULL, "age" INT64 DEFAULT 0 CHECK ("age" >= 0), description STRING) PRIMARY KEY ("id");`)).Parse() + require.NoError(t, err) + + expected := `CREATE TABLE "users" ( + id STRING(36) NOT NULL, + group_id STRING(36) NOT NULL, + "name" STRING(255) NOT NULL, + "age" INT64 DEFAULT 0, + description STRING, + CONSTRAINT users_group_id_fkey FOREIGN KEY (group_id) REFERENCES "groups" ("id"), + CONSTRAINT users_age_check CHECK ("age" >= 0) +) PRIMARY KEY ("id"); +` + actual, err := DiffCreateTable( + nil, + afterDDL.Stmts[0].(*CreateTableStmt), + DiffCreateTableUseAlterTableAddConstraintNotValid(true), + ) + + assert.NoError(t, err) + assert.Equal(t, expected, actual.String()) + + t.Logf("✅: %s: actual: %%#v: \n%#v", t.Name(), actual) + t.Logf("✅: %s: actual: %%s: \n%s", t.Name(), actual) + }) + + t.Run("success,DROP_TABLE", func(t *testing.T) { + t.Parallel() + + before := `CREATE TABLE "users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL REFERENCES "groups" ("id"), "name" STRING(255) NOT NULL, "age" INT64 DEFAULT 0 CHECK ("age" >= 0), description STRING) PRIMARY KEY ("id");` + + beforeDDL, err := NewParser(NewLexer(before)).Parse() + require.NoError(t, err) + + ddls, err := DiffCreateTable( + beforeDDL.Stmts[0].(*CreateTableStmt), + nil, + DiffCreateTableUseAlterTableAddConstraintNotValid(true), + ) + + assert.NoError(t, err) + assert.Equal(t, &DDL{ + Stmts: []Stmt{ + &DropTableStmt{ + Name: &ObjectName{Name: &Ident{Name: "users", QuotationMark: `"`, Raw: `"users"`}}, + }, + }, + }, ddls) + + t.Logf("✅: %s:\n%s", t.Name(), ddls) + }) +} diff --git a/pkg/ddl/spanner/diff_test.go b/pkg/ddl/spanner/diff_test.go new file mode 100644 index 0000000..214b961 --- /dev/null +++ b/pkg/ddl/spanner/diff_test.go @@ -0,0 +1,449 @@ +package spanner + +import ( + "fmt" + "testing" + + "github.com/kunitsucom/util.go/testing/assert" + "github.com/kunitsucom/util.go/testing/require" + + "github.com/kunitsucom/ddlctl/pkg/ddl" +) + +func TestDiff(t *testing.T) { + t.Parallel() + + t.Run("failure,ddl.ErrNoDifference", func(t *testing.T) { + t.Parallel() + + before := &DDL{} + after := &DDL{} + _, err := Diff(before, after) + require.ErrorIs(t, err, ddl.ErrNoDifference) + }) + + t.Run("failure,ddl.ErrNotSupported,DropTableStmt", func(t *testing.T) { + t.Parallel() + + { + before := &DDL{ + Stmts: []Stmt{ + &DropTableStmt{Name: &ObjectName{Name: &Ident{Name: "table_name", Raw: "table_name"}}}, + }, + } + after := (*DDL)(nil) + _, err := Diff(before, after) + require.ErrorIs(t, err, ddl.ErrNotSupported) + } + { + before := &DDL{ + Stmts: []Stmt{ + &DropTableStmt{Name: &ObjectName{Name: &Ident{Name: "table_name", Raw: "table_name"}}}, + }, + } + after := &DDL{} + _, err := Diff(before, after) + require.ErrorIs(t, err, ddl.ErrNotSupported) + } + { + before := &DDL{} + after := &DDL{ + Stmts: []Stmt{ + &DropTableStmt{Name: &ObjectName{Name: &Ident{Name: "table_name", Raw: "table_name"}}}, + }, + } + _, err := Diff(before, after) + require.ErrorIs(t, err, ddl.ErrNotSupported) + } + }) + + t.Run("success,after", func(t *testing.T) { + t.Parallel() + + before := (*DDL)(nil) + after := &DDL{ + Stmts: []Stmt{ + &CreateTableStmt{ + Name: &ObjectName{Name: &Ident{Name: "table_name", Raw: "table_name"}}, + Columns: []*Column{ + { + Name: &Ident{Name: "column_name", Raw: "column_name"}, + DataType: &DataType{ + Name: "STRING", + }, + NotNull: true, + }, + }, + Constraints: []Constraint{ + &CheckConstraint{ + Name: NewRawIdent("table_name_check_column_name"), + Expr: &Expr{ + Idents: []*Ident{ + NewRawIdent("("), + NewRawIdent("column_name"), + NewRawIdent("!="), + NewRawIdent("''"), + NewRawIdent(")"), + }, + }, + }, + }, + Options: []*Option{ + {Name: "PRIMARY KEY", Value: &Expr{Idents: []*Ident{NewRawIdent("("), NewRawIdent("column_name"), NewRawIdent(")")}}}, + }, + }, + }, + } + expected := `CREATE TABLE table_name ( + column_name STRING NOT NULL, + CONSTRAINT table_name_check_column_name CHECK (column_name != '') +) PRIMARY KEY (column_name); +` + actual, err := Diff(before, after) + require.NoError(t, err) + assert.Equal(t, expected, actual.String()) + + t.Logf("✅: %s: actual: %%#v: \n%#v", t.Name(), actual) + t.Logf("✅: %s: actual: %%s: \n%s", t.Name(), actual) + }) + + t.Run("success,before,nil,Table", func(t *testing.T) { + t.Parallel() + + before := &DDL{ + Stmts: []Stmt{ + &CreateTableStmt{ + Name: &ObjectName{Schema: &Ident{Name: "public", Raw: "public"}, Name: &Ident{Name: "table_name", Raw: "table_name"}}, + Columns: []*Column{ + { + Name: &Ident{Name: "column_name", Raw: "column_name"}, + }, + }, + }, + }, + } + after := (*DDL)(nil) + + expected := `DROP TABLE public.table_name; +` + actual, err := Diff(before, after) + require.NoError(t, err) + + assert.Equal(t, expected, actual.String()) + + t.Logf("✅: %s: actual: %%#v: \n%#v", t.Name(), actual) + t.Logf("✅: %s: actual: %%s: \n%s", t.Name(), actual) + }) + + t.Run("success,before,Table", func(t *testing.T) { + t.Parallel() + + before := &DDL{ + Stmts: []Stmt{ + &CreateTableStmt{ + Name: &ObjectName{Name: &Ident{Name: "table_name", Raw: "table_name"}}, + Columns: []*Column{ + { + Name: &Ident{Name: "column_name", Raw: "column_name"}, + }, + }, + }, + }, + } + after := &DDL{} + + expected := `DROP TABLE table_name; +` + actual, err := Diff(before, after) + require.NoError(t, err) + assert.Equal(t, expected, actual.String()) + + t.Logf("✅: %s: actual: %%#v: \n%#v", t.Name(), actual) + t.Logf("✅: %s: actual: %%s: \n%s", t.Name(), actual) + }) + + t.Run("success,before,nil,Index", func(t *testing.T) { + t.Parallel() + + before := &DDL{ + Stmts: []Stmt{ + &CreateIndexStmt{ + Name: &ObjectName{Name: &Ident{Name: "table_name_idx_column_name", Raw: "table_name_idx_column_name"}}, + Columns: []*ColumnIdent{ + { + Ident: &Ident{Name: "column_name", Raw: "column_name"}, + }, + }, + }, + }, + } + after := (*DDL)(nil) + actual, err := Diff(before, after) + require.NoError(t, err) + expected := `DROP INDEX table_name_idx_column_name; +` + assert.Equal(t, expected, actual.String()) + + t.Logf("✅: %s: actual: %%#v: \n%#v", t.Name(), actual) + t.Logf("✅: %s: actual: %%s: \n%s", t.Name(), actual) + }) + + t.Run("success,before,Index", func(t *testing.T) { + t.Parallel() + + before := &DDL{ + Stmts: []Stmt{ + &CreateIndexStmt{ + Name: &ObjectName{Name: &Ident{Name: "table_name_idx_column_name", Raw: "table_name_idx_column_name"}}, + Columns: []*ColumnIdent{ + { + Ident: &Ident{Name: "column_name", Raw: "column_name"}, + }, + }, + }, + }, + } + after := &DDL{} + actual, err := Diff(before, after) + require.NoError(t, err) + expected := `DROP INDEX table_name_idx_column_name; +` + assert.Equal(t, expected, actual.String()) + + t.Logf("✅: %s: actual: %%#v: \n%#v", t.Name(), actual) + t.Logf("✅: %s: actual: %%s: \n%s", t.Name(), actual) + }) + + t.Run("success,before,Table", func(t *testing.T) { + t.Parallel() + + before := &DDL{} + after := &DDL{ + Stmts: []Stmt{ + &CreateTableStmt{ + Name: &ObjectName{Name: &Ident{Name: "table_name", Raw: "table_name"}}, + Columns: []*Column{ + { + Name: &Ident{Name: "column_name", Raw: "column_name"}, + DataType: &DataType{ + Name: "STRING", + }, + NotNull: true, + }, + }, + Constraints: []Constraint{ + &CheckConstraint{ + Name: NewRawIdent("table_name_check_column_name"), + Expr: &Expr{ + Idents: []*Ident{ + NewRawIdent("("), + NewRawIdent("column_name"), + NewRawIdent("!="), + NewRawIdent("''"), + NewRawIdent(")"), + }, + }, + }, + }, + Options: []*Option{ + {Name: "PRIMARY KEY", Value: &Expr{Idents: []*Ident{NewRawIdent("("), NewRawIdent("column_name"), NewRawIdent(")")}}}, + }, + }, + }, + } + + expected := `CREATE TABLE table_name ( + column_name STRING NOT NULL, + CONSTRAINT table_name_check_column_name CHECK (column_name != '') +) PRIMARY KEY (column_name); +` + actual, err := Diff(before, after) + require.NoError(t, err) + + assert.Equal(t, expected, actual.String()) + + t.Logf("✅: %s: actual: %%#v: \n%#v", t.Name(), actual) + t.Logf("✅: %s: actual: %%s: \n%s", t.Name(), actual) + }) + + t.Run("success,before,Index", func(t *testing.T) { + t.Parallel() + + before := &DDL{} + after := &DDL{ + Stmts: []Stmt{ + &CreateIndexStmt{ + Name: &ObjectName{Name: &Ident{Name: "table_name_idx_column_name", Raw: "table_name_idx_column_name"}}, + TableName: &ObjectName{Name: &Ident{Name: "table_name", Raw: "table_name"}}, + Columns: []*ColumnIdent{ + { + Ident: &Ident{Name: "column_name", Raw: "column_name"}, + }, + }, + }, + }, + } + actual, err := Diff(before, after) + require.NoError(t, err) + if !assert.Equal(t, after, actual) { + assert.Equal(t, fmt.Sprintf("%#v", after), fmt.Sprintf("%#v", actual)) + } + assert.Equal(t, `CREATE INDEX table_name_idx_column_name ON table_name (column_name); +`, actual.String()) + }) + + t.Run("success,before,after,Table", func(t *testing.T) { + t.Parallel() + + before, err := NewParser(NewLexer(`CREATE TABLE public.users ( + user_id STRING(36) NOT NULL, + username STRING(256) NOT NULL, + is_verified BOOL NOT NULL DEFAULT (false), + created_at TIMESTAMP NOT NULL DEFAULT (CURRENT_TIMESTAMP()), +) PRIMARY KEY (user_id); +`)).Parse() + require.NoError(t, err) + + after, err := NewParser(NewLexer(`CREATE TABLE public.users ( + user_id STRING(36) NOT NULL, + username STRING(256) NOT NULL, + is_verified BOOL NOT NULL DEFAULT (false), + created_at TIMESTAMP NOT NULL DEFAULT (CURRENT_TIMESTAMP()), + updated_at TIMESTAMP NOT NULL DEFAULT (CURRENT_TIMESTAMP()), +) PRIMARY KEY (user_id); +`)).Parse() + require.NoError(t, err) + + expected := `-- - +-- +updated_at TIMESTAMP NOT NULL DEFAULT (CURRENT_TIMESTAMP()) +ALTER TABLE public.users ADD COLUMN updated_at TIMESTAMP NOT NULL DEFAULT (CURRENT_TIMESTAMP()); +` + actual, err := Diff(before, after) + require.NoError(t, err) + + assert.Equal(t, expected, actual.String()) + + t.Logf("✅: %s: actual: %%#v: \n%#v", t.Name(), actual) + t.Logf("✅: %s: actual: %%s: \n%s", t.Name(), actual) + }) + + t.Run("success,before,after,Table,Asc", func(t *testing.T) { + t.Parallel() + + before, err := NewParser(NewLexer(`CREATE TABLE users ( + user_id STRING(36) NOT NULL, + username STRING(256) NOT NULL, + is_verified BOOL NOT NULL DEFAULT (false), + created_at TIMESTAMP NOT NULL DEFAULT (CURRENT_TIMESTAMP()), + updated_at TIMESTAMP NOT NULL DEFAULT (CURRENT_TIMESTAMP()), +) PRIMARY KEY (user_id); +CREATE INDEX users_idx_by_username ON users (username DESC); +`)).Parse() + require.NoError(t, err) + + after, err := NewParser(NewLexer(`CREATE TABLE users ( + user_id STRING(36) NOT NULL, + username STRING(256) NOT NULL, + is_verified BOOL NOT NULL DEFAULT (false), + created_at TIMESTAMP NOT NULL DEFAULT (CURRENT_TIMESTAMP()), + updated_at TIMESTAMP NOT NULL DEFAULT (CURRENT_TIMESTAMP()), +) PRIMARY KEY (user_id); +CREATE INDEX users_idx_by_username ON users (username ASC); +`)).Parse() + require.NoError(t, err) + + expected := `-- -CREATE INDEX users_idx_by_username ON users (username DESC); +-- +CREATE INDEX users_idx_by_username ON users (username ASC); +-- +DROP INDEX users_idx_by_username; +CREATE INDEX users_idx_by_username ON users (username ASC); +` + actual, err := Diff(before, after) + require.NoError(t, err) + assert.Equal(t, expected, actual.String()) + + t.Logf("✅: %s: actual: %%#v: \n%#v", t.Name(), actual) + t.Logf("✅: %s: actual: %%s: \n%s", t.Name(), actual) + }) + + t.Run("success,before,after,Index", func(t *testing.T) { + t.Parallel() + + before, err := NewParser(NewLexer(`CREATE UNIQUE INDEX IF NOT EXISTS public.users_idx_by_username ON public.users (username DESC);`)).Parse() + require.NoError(t, err) + + after, err := NewParser(NewLexer(`CREATE UNIQUE INDEX IF NOT EXISTS public.users_idx_by_username ON public.users (username ASC, age ASC);`)).Parse() + require.NoError(t, err) + + expected := `-- -CREATE UNIQUE INDEX public.users_idx_by_username ON public.users (username DESC); +-- +CREATE UNIQUE INDEX public.users_idx_by_username ON public.users (username ASC, age ASC); +-- +DROP INDEX public.users_idx_by_username; +CREATE UNIQUE INDEX IF NOT EXISTS public.users_idx_by_username ON public.users (username ASC, age ASC); +` + actual, err := Diff(before, after) + require.NoError(t, err) + + assert.Equal(t, expected, actual.String()) + + t.Logf("✅: %s: actual: %%#v: \n%#v", t.Name(), actual) + t.Logf("✅: %s: actual: %%s: \n%s", t.Name(), actual) + }) + + t.Run("success,STRING(10)->STRING(11)", func(t *testing.T) { + t.Parallel() + + before, err := NewParser(NewLexer(`CREATE TABLE public.users ( username STRING(10) NOT NULL );`)).Parse() + require.NoError(t, err) + + after, err := NewParser(NewLexer(`CREATE TABLE public.users ( username STRING(11) NOT NULL );`)).Parse() + require.NoError(t, err) + + expected := `-- -username STRING(10) NOT NULL +-- +username STRING(11) NOT NULL +ALTER TABLE public.users ALTER COLUMN username STRING(11) NOT NULL; +` + actual, err := Diff(before, after) + require.NoError(t, err) + + if !assert.Equal(t, expected, actual.String()) { + t.Errorf("❌: %s: stmt: %%#v: \n%#v", t.Name(), actual) + } + }) + + t.Run("success,SET_DEFAULT_TRUE_FALSE", func(t *testing.T) { + t.Parallel() + + before, err := NewParser(NewLexer(`CREATE TABLE public.passwords ( user_id STRING(36) NOT NULL, password STRING NOT NULL, is_verified BOOL NOT NULL DEFAULT (false), is_expired BOOL NOT NULL DEFAULT (true) );`)).Parse() + require.NoError(t, err) + + after, err := NewParser(NewLexer(`CREATE TABLE public.passwords ( user_id STRING(36) NOT NULL, password STRING NOT NULL, is_verified BOOL NOT NULL DEFAULT (FALSE), is_expired BOOL NOT NULL DEFAULT (TRUE) );`)).Parse() + require.NoError(t, err) + + expected := `` + actual, err := Diff(before, after) + assert.ErrorIs(t, err, ddl.ErrNoDifference) + + if !assert.Equal(t, expected, actual.String()) { + t.Errorf("❌: %s: stmt: %%#v: \n%#v", t.Name(), actual) + } + }) + + t.Run("success,ddl.ErrNoDifference,SameContent", func(t *testing.T) { + t.Parallel() + + before, err := NewParser(NewLexer(`CREATE TABLE passwords ( user_id STRING(36) NOT NULL, password STRING NOT NULL, is_verified BOOL NOT NULL DEFAULT (false), is_expired BOOL NOT NULL DEFAULT (true) );`)).Parse() + require.NoError(t, err) + + after, err := NewParser(NewLexer(`CREATE TABLE "passwords" ( "user_id" STRING(36) NOT NULL, "password" STRING NOT NULL, "is_verified" BOOL NOT NULL DEFAULT (FALSE), "is_expired" BOOL NOT NULL DEFAULT (TRUE) );`)).Parse() + require.NoError(t, err) + + expected := `` + actual, err := Diff(before, after) + assert.ErrorIs(t, err, ddl.ErrNoDifference) + + if !assert.Equal(t, expected, actual.String()) { + t.Errorf("❌: %s: stmt: %%#v: \n%#v", t.Name(), actual) + } + }) +} diff --git a/pkg/ddl/spanner/lexar-gen.sh b/pkg/ddl/spanner/lexar-gen.sh new file mode 100644 index 0000000..1d64aff --- /dev/null +++ b/pkg/ddl/spanner/lexar-gen.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + + echo ' // START CASES DO NOT EDIT' + echo ' switch token {' + grep -E "^\tTOKEN_[A-Za-z0-9_]+ +TokenType += +[\"\`][A-Za-z0-9_]+[\"\`]" "${1:?}" | while read -r LINE; do + const=$(awk '{print $1}' <<<"${LINE:-}") + literal=$(awk '{print $4}' <<<"${LINE:-}") + case "${literal:?}" in + '"IDENT"') + echo -e "\tdefault:" + echo -e "\t\treturn ${const:?}" + ;; + '"OPEN_PAREN"' | '"CLOSE_PAREN"' | '"COMMA"' | '"SEMICOLON"' | '"ILLEGAL"' | '"EOF"') + continue + ;; + *) + echo -e "\tcase ${literal:?}:" + echo -e "\t\treturn ${const:?}" + ;; + esac + done + echo ' }' + echo ' // END CASES DO NOT EDIT' diff --git a/pkg/ddl/spanner/lexar.go b/pkg/ddl/spanner/lexar.go new file mode 100644 index 0000000..aa172ba --- /dev/null +++ b/pkg/ddl/spanner/lexar.go @@ -0,0 +1,407 @@ +package spanner + +import ( + "strings" +) + +// MEMO: https://www.postgresql.jp/docs/11/datatype.html + +// Token はSQL文のトークンを表す型です。 +type Token struct { + Type TokenType + Literal Literal +} + +type Literal struct { + Str string +} + +func (l *Literal) String() string { + return l.Str +} + +func (l *Literal) StringForDiff() string { + return l.Str +} + +type TokenType string + +func (t TokenType) String() string { + return string(t) +} + +//nolint:revive,stylecheck +const ( + // SPECIAL TOKENS. + TOKEN_ILLEGAL TokenType = "ILLEGAL" + TOKEN_EOF TokenType = "EOF" + + // SPECIAL CHARACTERS. + TOKEN_OPEN_PAREN TokenType = "OPEN_PAREN" // ( + TOKEN_CLOSE_PAREN TokenType = "CLOSE_PAREN" // ) + TOKEN_COMMA TokenType = "COMMA" // , + TOKEN_SEMICOLON TokenType = "SEMICOLON" // ; + TOKEN_EQUAL TokenType = "EQUAL" // = + TOKEN_GREATER TokenType = "GREATER" // > + TOKEN_LESS TokenType = "LESS" // < + TOKEN_PLUS TokenType = "PLUS" // + + TOKEN_MINUS TokenType = "MINUS" // - + TOKEN_ASTERISK TokenType = "ASTERISK" // * + TOKEN_SLASH TokenType = "SLASH" // / + TOKEN_STRING_CONCAT TokenType = "STRING_CONCAT" //nolint:gosec // || + TOKEN_TYPECAST TokenType = "TYPECAST" // :: + TOKEN_TYPE_ANNOTATION TokenType = "TYPE_ANNOTATION" // ::: //diff:ignore-line-postgres-cockroach + + // VERB. + TOKEN_CREATE TokenType = "CREATE" + TOKEN_ALTER TokenType = "ALTER" + TOKEN_DROP TokenType = "DROP" + TOKEN_RENAME TokenType = "RENAME" + TOKEN_TRUNCATE TokenType = "TRUNCATE" + TOKEN_DELETE TokenType = "DELETE" + TOKEN_UPDATE TokenType = "UPDATE" + + // OBJECT. + TOKEN_TABLE TokenType = "TABLE" + TOKEN_INDEX TokenType = "INDEX" + TOKEN_VIEW TokenType = "VIEW" + + // OTHER. + TOKEN_IF TokenType = "IF" + TOKEN_EXISTS TokenType = "EXISTS" + TOKEN_USING TokenType = "USING" + TOKEN_ON TokenType = "ON" + TOKEN_TO TokenType = "TO" + TOKEN_WITH TokenType = "WITH" + + // DATA TYPE. + TOKEN_BOOL TokenType = "BOOL" //diff:ignore-line-postgres-cockroach + TOKEN_INT64 TokenType = "INT64" //diff:ignore-line-postgres-cockroach + TOKEN_FLOAT64 TokenType = "FLOAT64" + TOKEN_NUMERIC TokenType = "NUMERIC" + TOKEN_JSON TokenType = "JSON" + TOKEN_STRING TokenType = "STRING" //diff:ignore-line-postgres-cockroach + TOKEN_BYTES TokenType = "BYTES" + TOKEN_TIMESTAMP TokenType = "TIMESTAMP" + TOKEN_DATE TokenType = "DATE" + TOKEN_ARRAY TokenType = "ARRAY" + TOKEN_STRUCT TokenType = "STRUCT" + + // COLUMN. + TOKEN_DEFAULT TokenType = "DEFAULT" + TOKEN_NOT TokenType = "NOT" + TOKEN_ASC TokenType = "ASC" + TOKEN_DESC TokenType = "DESC" + TOKEN_OPTIONS TokenType = "OPTIONS" + TOKEN_INTERLEAVE TokenType = "INTERLEAVE" + TOKEN_IN TokenType = "IN" + TOKEN_PARENT TokenType = "PARENT" + TOKEN_CASCADE TokenType = "CASCADE" + TOKEN_NO TokenType = "NO" + TOKEN_ACTION TokenType = "ACTION" + + // CONSTRAINT. + TOKEN_CONSTRAINT TokenType = "CONSTRAINT" + TOKEN_PRIMARY TokenType = "PRIMARY" + TOKEN_KEY TokenType = "KEY" + TOKEN_FOREIGN TokenType = "FOREIGN" + TOKEN_REFERENCES TokenType = "REFERENCES" + TOKEN_UNIQUE TokenType = "UNIQUE" + TOKEN_CHECK TokenType = "CHECK" + + // FUNCTION. + TOKEN_NULLIF TokenType = "NULLIF" + + // VALUE. + TOKEN_NULL TokenType = "NULL" + TOKEN_TRUE TokenType = "TRUE" + TOKEN_FALSE TokenType = "FALSE" + + // LITERAL. + TOKEN_LITERAL TokenType = "LITERAL" + + // IDENTIFIER. + TOKEN_IDENT TokenType = "IDENT" +) + +//nolint:funlen,cyclop,gocognit,gocyclo +func lookupIdent(ident string) TokenType { + token := strings.ToUpper(ident) + // MEMO: bash lexar-gen.sh lexar.go | pbcopy + // START CASES DO NOT EDIT + switch token { + case "EQUAL": + return TOKEN_EQUAL + case "GREATER": + return TOKEN_GREATER + case "LESS": + return TOKEN_LESS + case "CREATE": + return TOKEN_CREATE + case "ALTER": + return TOKEN_ALTER + case "DROP": + return TOKEN_DROP + case "RENAME": + return TOKEN_RENAME + case "TRUNCATE": + return TOKEN_TRUNCATE + case "DELETE": + return TOKEN_DELETE + case "UPDATE": + return TOKEN_UPDATE + case "TABLE": + return TOKEN_TABLE + case "INDEX": + return TOKEN_INDEX + case "VIEW": + return TOKEN_VIEW + case "IF": + return TOKEN_IF + case "EXISTS": + return TOKEN_EXISTS + case "USING": + return TOKEN_USING + case "ON": + return TOKEN_ON + case "TO": + return TOKEN_TO + case "WITH": + return TOKEN_WITH + case "BOOL": + return TOKEN_BOOL + case "INT64": + return TOKEN_INT64 + case "FLOAT64": + return TOKEN_FLOAT64 + case "NUMERIC": + return TOKEN_NUMERIC + case "JSON": + return TOKEN_JSON + case "STRING": + return TOKEN_STRING + case "BYTES": + return TOKEN_BYTES + case "TIMESTAMP": + return TOKEN_TIMESTAMP + case "DATE": + return TOKEN_DATE + case "ARRAY": + return TOKEN_ARRAY + case "STRUCT": + return TOKEN_STRUCT + case "DEFAULT": + return TOKEN_DEFAULT + case "NOT": + return TOKEN_NOT + case "ASC": + return TOKEN_ASC + case "DESC": + return TOKEN_DESC + case "OPTIONS": + return TOKEN_OPTIONS + case "INTERLEAVE": + return TOKEN_INTERLEAVE + case "IN": + return TOKEN_IN + case "PARENT": + return TOKEN_PARENT + case "CASCADE": + return TOKEN_CASCADE + case "NO": + return TOKEN_NO + case "ACTION": + return TOKEN_ACTION + case "CONSTRAINT": + return TOKEN_CONSTRAINT + case "PRIMARY": + return TOKEN_PRIMARY + case "KEY": + return TOKEN_KEY + case "FOREIGN": + return TOKEN_FOREIGN + case "REFERENCES": + return TOKEN_REFERENCES + case "UNIQUE": + return TOKEN_UNIQUE + case "CHECK": + return TOKEN_CHECK + case "NULLIF": + return TOKEN_NULLIF + case "NULL": + return TOKEN_NULL + case "TRUE": + return TOKEN_TRUE + case "FALSE": + return TOKEN_FALSE + default: + return TOKEN_IDENT + } + // END CASES DO NOT EDIT +} + +// Lexer はSQL文をトークンに分割するレキサーです。 +type Lexer struct { + input string + position int // 現在の位置 + readPosition int // 次の位置 + ch byte // 現在の文字 +} + +// NewLexer は新しいLexerを生成します。 +func NewLexer(input string) *Lexer { + l := &Lexer{input: input} + + // 1文字読み込む + l.readChar() + + return l +} + +// readChar は入力から次の文字を読み込みます。 +func (l *Lexer) readChar() { + if l.readPosition >= len(l.input) { + // 終端に達したら0を返す + l.ch = 0 + } else { + // 1文字読み込む + l.ch = l.input[l.readPosition] + } + l.position = l.readPosition + l.readPosition++ +} + +// NextToken は次のトークンを返します。 +// +//nolint:funlen,cyclop +func (l *Lexer) NextToken() Token { + var tok Token + + l.skipWhitespace() + + if l.ch == '-' && l.peekChar() == '-' { + l.skipComment() + return l.NextToken() + } + + switch l.ch { + case '"', '\'', '`': + tok.Type = TOKEN_IDENT + tok.Literal = Literal{Str: l.readQuotedLiteral(l.ch)} + // MEMO: backup + // case '|': + // if l.peekChar() == '|' { + // ch := l.ch + // l.readChar() + // literal := string(ch) + string(l.ch) + // tok = Token{Type: TOKEN_STRING_CONCAT, Literal: Literal{Str: literal}} + // } else { + // tok = newToken(TOKEN_ILLEGAL, l.ch) + // } + // case ':': + // if l.peekChar() == ':' { + // l.readChar() + // if l.peekChar() == ':' { //diff:ignore-line-postgres-cockroach + // l.readChar() //diff:ignore-line-postgres-cockroach + // tok = Token{Type: TOKEN_TYPE_ANNOTATION, Literal: Literal{Str: ":::"}} //diff:ignore-line-postgres-cockroach + // } else { //diff:ignore-line-postgres-cockroach + // tok = Token{Type: TOKEN_TYPECAST, Literal: Literal{Str: "::"}} + // } //diff:ignore-line-postgres-cockroach + // } else { + // tok = newToken(TOKEN_ILLEGAL, l.ch) + // } + case '(': + tok = newToken(TOKEN_OPEN_PAREN, l.ch) + case ')': + tok = newToken(TOKEN_CLOSE_PAREN, l.ch) + case ',': + tok = newToken(TOKEN_COMMA, l.ch) + case ';': + tok = newToken(TOKEN_SEMICOLON, l.ch) + case '=': + tok = newToken(TOKEN_EQUAL, l.ch) + case '>': + tok = newToken(TOKEN_GREATER, l.ch) + case '<': + tok = newToken(TOKEN_LESS, l.ch) + case '+': + tok = newToken(TOKEN_PLUS, l.ch) + case '-': + tok = newToken(TOKEN_MINUS, l.ch) + case '*': + tok = newToken(TOKEN_ASTERISK, l.ch) + case '/': + tok = newToken(TOKEN_SLASH, l.ch) + case 0: + tok.Literal = Literal{} + tok.Type = TOKEN_EOF + default: + if isLiteral(l.ch) { + lit := l.readIdentifier() + tok.Type = lookupIdent(lit) + tok.Literal = Literal{Str: lit} + return tok + } + tok = newToken(TOKEN_ILLEGAL, l.ch) + } + + l.readChar() + return tok +} + +// readQuotedLiteral はクォーテーションで囲まれた文字列を読み込みます。 +func (l *Lexer) readQuotedLiteral(quote byte) string { + // position := l.position + 1 // クォーテーションの次の文字から開始 + position := l.position // クォーテーションの文字から開始 + for { + l.readChar() + if l.ch == quote || l.ch == 0 { + break + } + } + return l.input[position : l.position+1] +} + +// peekChar は次の文字を覗き見ますが、現在の位置は進めません。 +func (l *Lexer) peekChar() byte { + if l.readPosition >= len(l.input) { + return 0 + } + return l.input[l.readPosition] +} + +func newToken(tokenType TokenType, ch byte) Token { + return Token{Type: tokenType, Literal: Literal{Str: string(ch)}} +} + +func (l *Lexer) readIdentifier() string { + position := l.position + for isLiteral(l.ch) { + l.readChar() + } + str := l.input[position:l.position] + + return str +} + +func isLiteral(ch byte) bool { + return 'A' <= ch && ch <= 'Z' || + 'a' <= ch && ch <= 'z' || + '0' <= ch && ch <= '9' || + ch == '_' || + ch == '.' +} + +func (l *Lexer) skipWhitespace() (skipped bool) { + for l.ch == ' ' || l.ch == '\t' || l.ch == '\n' || l.ch == '\r' { + skipped = true || skipped + l.readChar() + } + return skipped +} + +func (l *Lexer) skipComment() { + for l.ch != '\n' && l.ch != 0 { + l.readChar() + } +} diff --git a/pkg/ddl/spanner/lexar_test.go b/pkg/ddl/spanner/lexar_test.go new file mode 100644 index 0000000..35bf6cf --- /dev/null +++ b/pkg/ddl/spanner/lexar_test.go @@ -0,0 +1,286 @@ +package spanner + +import ( + "math" + "testing" + + "github.com/kunitsucom/util.go/testing/require" +) + +func Test_lookupIdent(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want TokenType + }{ + {name: "success,EQUAL", input: "EQUAL", want: TOKEN_EQUAL}, + {name: "success,GREATER", input: "GREATER", want: TOKEN_GREATER}, + {name: "success,LESS", input: "LESS", want: TOKEN_LESS}, + {name: "success,CREATE", input: "CREATE", want: TOKEN_CREATE}, + {name: "success,ALTER", input: "ALTER", want: TOKEN_ALTER}, + {name: "success,DROP", input: "DROP", want: TOKEN_DROP}, + {name: "success,RENAME", input: "RENAME", want: TOKEN_RENAME}, + {name: "success,CREATE", input: "CREATE", want: TOKEN_CREATE}, + {name: "success,ALTER", input: "ALTER", want: TOKEN_ALTER}, + {name: "success,DROP", input: "DROP", want: TOKEN_DROP}, + {name: "success,RENAME", input: "RENAME", want: TOKEN_RENAME}, + {name: "success,TRUNCATE", input: "TRUNCATE", want: TOKEN_TRUNCATE}, + {name: "success,DELETE", input: "DELETE", want: TOKEN_DELETE}, + {name: "success,UPDATE", input: "UPDATE", want: TOKEN_UPDATE}, + {name: "success,TABLE", input: "TABLE", want: TOKEN_TABLE}, + {name: "success,INDEX", input: "INDEX", want: TOKEN_INDEX}, + {name: "success,VIEW", input: "VIEW", want: TOKEN_VIEW}, + {name: "success,IF", input: "IF", want: TOKEN_IF}, + {name: "success,EXISTS", input: "EXISTS", want: TOKEN_EXISTS}, + {name: "success,ON", input: "ON", want: TOKEN_ON}, + {name: "success,TO", input: "TO", want: TOKEN_TO}, + {name: "success,WITH", input: "WITH", want: TOKEN_WITH}, + {name: "success,BOOL", input: "BOOL", want: TOKEN_BOOL}, + {name: "success,NUMERIC", input: "NUMERIC", want: TOKEN_NUMERIC}, + {name: "success,FLOAT64", input: "FLOAT64", want: TOKEN_FLOAT64}, + {name: "success,JSON", input: "JSON", want: TOKEN_JSON}, + {name: "success,STRING", input: "STRING", want: TOKEN_STRING}, + {name: "success,BYTES", input: "BYTES", want: TOKEN_BYTES}, + {name: "success,TIMESTAMP", input: "TIMESTAMP", want: TOKEN_TIMESTAMP}, + {name: "success,DATE", input: "DATE", want: TOKEN_DATE}, + {name: "success,ARRAY", input: "ARRAY", want: TOKEN_ARRAY}, + {name: "success,STRUCT", input: "STRUCT", want: TOKEN_STRUCT}, + {name: "success,DEFAULT", input: "DEFAULT", want: TOKEN_DEFAULT}, + {name: "success,NOT", input: "NOT", want: TOKEN_NOT}, + {name: "success,NULL", input: "NULL", want: TOKEN_NULL}, + {name: "success,ASC", input: "ASC", want: TOKEN_ASC}, + {name: "success,DESC", input: "DESC", want: TOKEN_DESC}, + {name: "success,CASCADE", input: "CASCADE", want: TOKEN_CASCADE}, + {name: "success,CONSTRAINT", input: "CONSTRAINT", want: TOKEN_CONSTRAINT}, + {name: "success,PRIMARY", input: "PRIMARY", want: TOKEN_PRIMARY}, + {name: "success,KEY", input: "KEY", want: TOKEN_KEY}, + {name: "success,FOREIGN", input: "FOREIGN", want: TOKEN_FOREIGN}, + {name: "success,REFERENCES", input: "REFERENCES", want: TOKEN_REFERENCES}, + {name: "success,UNIQUE", input: "UNIQUE", want: TOKEN_UNIQUE}, + {name: "success,CHECK", input: "CHECK", want: TOKEN_CHECK}, + {name: "success,NULLIF", input: "NULLIF", want: TOKEN_NULLIF}, + {name: "success,IDENT", input: "users", want: TOKEN_IDENT}, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := lookupIdent(tt.input) + + if !require.Equal(t, tt.want, got) { + t.FailNow() + } + }) + } +} + +func TestLex(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want []Token + }{ + { + name: "success,CREATE_TABLE", + input: `CREATE TABLE IF NOT EXISTS "users" ( + "user_id" STRING(36) NOT NULL, + "name" STRING(255) NOT NULL, + "email" STRING(255) NOT NULL, + "password" STRING(255) NOT NULL, + "created_at" TIMESTAMP NOT NULL, + "updated_at" TIMESTAMP NOT NULL, + PRIMARY KEY ("user_id"), + UNIQUE ("email") +);`, + want: []Token{ + {Type: TOKEN_CREATE, Literal: Literal{Str: "CREATE"}}, + {Type: TOKEN_TABLE, Literal: Literal{Str: "TABLE"}}, + {Type: TOKEN_IF, Literal: Literal{Str: "IF"}}, + {Type: TOKEN_NOT, Literal: Literal{Str: "NOT"}}, + {Type: TOKEN_EXISTS, Literal: Literal{Str: "EXISTS"}}, + {Type: TOKEN_IDENT, Literal: Literal{Str: `"users"`}}, + {Type: TOKEN_OPEN_PAREN, Literal: Literal{Str: "("}}, + {Type: TOKEN_IDENT, Literal: Literal{Str: `"user_id"`}}, + {Type: TOKEN_STRING, Literal: Literal{Str: "STRING"}}, + {Type: TOKEN_OPEN_PAREN, Literal: Literal{Str: "("}}, + {Type: TOKEN_IDENT, Literal: Literal{Str: "36"}}, + {Type: TOKEN_CLOSE_PAREN, Literal: Literal{Str: ")"}}, + {Type: TOKEN_NOT, Literal: Literal{Str: "NOT"}}, + {Type: TOKEN_NULL, Literal: Literal{Str: "NULL"}}, + {Type: TOKEN_COMMA, Literal: Literal{Str: ","}}, + {Type: TOKEN_IDENT, Literal: Literal{Str: `"name"`}}, + {Type: TOKEN_STRING, Literal: Literal{Str: "STRING"}}, + {Type: TOKEN_OPEN_PAREN, Literal: Literal{Str: "("}}, + {Type: TOKEN_IDENT, Literal: Literal{Str: "255"}}, + {Type: TOKEN_CLOSE_PAREN, Literal: Literal{Str: ")"}}, + {Type: TOKEN_NOT, Literal: Literal{Str: "NOT"}}, + {Type: TOKEN_NULL, Literal: Literal{Str: "NULL"}}, + {Type: TOKEN_COMMA, Literal: Literal{Str: ","}}, + {Type: TOKEN_IDENT, Literal: Literal{Str: `"email"`}}, + {Type: TOKEN_STRING, Literal: Literal{Str: "STRING"}}, + {Type: TOKEN_OPEN_PAREN, Literal: Literal{Str: "("}}, + {Type: TOKEN_IDENT, Literal: Literal{Str: "255"}}, + {Type: TOKEN_CLOSE_PAREN, Literal: Literal{Str: ")"}}, + {Type: TOKEN_NOT, Literal: Literal{Str: "NOT"}}, + {Type: TOKEN_NULL, Literal: Literal{Str: "NULL"}}, + {Type: TOKEN_COMMA, Literal: Literal{Str: ","}}, + {Type: TOKEN_IDENT, Literal: Literal{Str: `"password"`}}, + {Type: TOKEN_STRING, Literal: Literal{Str: "STRING"}}, + {Type: TOKEN_OPEN_PAREN, Literal: Literal{Str: "("}}, + {Type: TOKEN_IDENT, Literal: Literal{Str: "255"}}, + {Type: TOKEN_CLOSE_PAREN, Literal: Literal{Str: ")"}}, + {Type: TOKEN_NOT, Literal: Literal{Str: "NOT"}}, + {Type: TOKEN_NULL, Literal: Literal{Str: "NULL"}}, + {Type: TOKEN_COMMA, Literal: Literal{Str: ","}}, + {Type: TOKEN_IDENT, Literal: Literal{Str: `"created_at"`}}, + {Type: TOKEN_TIMESTAMP, Literal: Literal{Str: "TIMESTAMP"}}, + {Type: TOKEN_NOT, Literal: Literal{Str: "NOT"}}, + {Type: TOKEN_NULL, Literal: Literal{Str: "NULL"}}, + {Type: TOKEN_COMMA, Literal: Literal{Str: ","}}, + {Type: TOKEN_IDENT, Literal: Literal{Str: `"updated_at"`}}, + {Type: TOKEN_TIMESTAMP, Literal: Literal{Str: "TIMESTAMP"}}, + {Type: TOKEN_NOT, Literal: Literal{Str: "NOT"}}, + {Type: TOKEN_NULL, Literal: Literal{Str: "NULL"}}, + {Type: TOKEN_COMMA, Literal: Literal{Str: ","}}, + {Type: TOKEN_PRIMARY, Literal: Literal{Str: "PRIMARY"}}, + {Type: TOKEN_KEY, Literal: Literal{Str: "KEY"}}, + {Type: TOKEN_OPEN_PAREN, Literal: Literal{Str: "("}}, + {Type: TOKEN_IDENT, Literal: Literal{Str: `"user_id"`}}, + {Type: TOKEN_CLOSE_PAREN, Literal: Literal{Str: ")"}}, + {Type: TOKEN_COMMA, Literal: Literal{Str: ","}}, + {Type: TOKEN_UNIQUE, Literal: Literal{Str: "UNIQUE"}}, + {Type: TOKEN_OPEN_PAREN, Literal: Literal{Str: "("}}, + {Type: TOKEN_IDENT, Literal: Literal{Str: `"email"`}}, + {Type: TOKEN_CLOSE_PAREN, Literal: Literal{Str: ")"}}, + {Type: TOKEN_CLOSE_PAREN, Literal: Literal{Str: ")"}}, + {Type: TOKEN_SEMICOLON, Literal: Literal{Str: ";"}}, + }, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + l := NewLexer(tt.input) + got := make([]Token, 0) + for { + tok := l.NextToken() + if tok.Type == TOKEN_EOF { + break + } + got = append(got, tok) + } + + if !require.Equal(t, tt.want, got) { + t.FailNow() + } + + for i := range got { + if !require.Equal(t, got[i].Type, tt.want[i].Type) { + t.Fail() + } + + if !require.Equal(t, got[i].Literal, tt.want[i].Literal) { + t.Fail() + } + } + }) + } +} + +func TestLexer_NextToken(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want Token + }{ + { + name: "failure,|", + input: `|`, + want: Token{ + Type: TOKEN_ILLEGAL, + Literal: Literal{Str: "|"}, + }, + }, + { + name: "failure,:", + input: `:`, + want: Token{ + Type: TOKEN_ILLEGAL, + Literal: Literal{Str: ":"}, + }, + }, + { + name: "failure,!", + input: `!`, + want: Token{ + Type: TOKEN_ILLEGAL, + Literal: Literal{Str: "!"}, + }, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + l := NewLexer(tt.input) + got := l.NextToken() + + if !require.Equal(t, tt.want, got) { + t.FailNow() + } + }) + } +} + +func TestLexer_peekChar(t *testing.T) { + t.Parallel() + + t.Run("success,peekChar", func(t *testing.T) { + t.Parallel() + + l := NewLexer("") + l.readPosition = math.MaxInt64 + expected := byte(0) + actual := l.peekChar() + + require.Equal(t, expected, actual) + }) +} + +func TestLiteral(t *testing.T) { + t.Parallel() + + t.Run("success,String", func(t *testing.T) { + t.Parallel() + + literal := Literal{Str: "users"} + expected := literal.Str + actual := literal.String() + + require.Equal(t, expected, actual) + }) + + t.Run("success,PlainString", func(t *testing.T) { + t.Parallel() + + literal := Literal{Str: "users"} + expected := literal.Str + actual := literal.StringForDiff() + + require.Equal(t, expected, actual) + }) +} diff --git a/pkg/ddl/spanner/parser.go b/pkg/ddl/spanner/parser.go new file mode 100644 index 0000000..f7bf59f --- /dev/null +++ b/pkg/ddl/spanner/parser.go @@ -0,0 +1,774 @@ +package spanner + +// MEMO: https://www.postgresql.org/docs/current/ddl-constraints.html +// MEMO: https://www.postgresql.jp/docs/11/ddl-constraints.html + +import ( + "fmt" + "runtime" + "strings" + + filepathz "github.com/kunitsucom/util.go/path/filepath" + stringz "github.com/kunitsucom/util.go/strings" + + apperr "github.com/kunitsucom/ddlctl/pkg/apperr" + "github.com/kunitsucom/ddlctl/pkg/ddl" + "github.com/kunitsucom/ddlctl/pkg/ddl/logs" +) + +//nolint:gochecknoglobals +var quotationMarks = []string{`"`, "`", "'"} + +func NewRawIdent(raw string) *Ident { + for _, q := range quotationMarks { + if strings.HasPrefix(raw, q) && strings.HasSuffix(raw, q) { + return &Ident{ + Name: strings.Trim(raw, q), + QuotationMark: q, + Raw: raw, + } + } + } + + return &Ident{ + Name: raw, + QuotationMark: "", + Raw: raw, + } +} + +func NewIdent(name, quotationMark, raw string) *Ident { + return &Ident{ + Name: name, + QuotationMark: quotationMark, + Raw: raw, + } +} + +// Parser はSQL文を解析するパーサーです。 +type Parser struct { + l *Lexer + currentToken Token + peekToken Token +} + +// NewParser は新しいParserを生成します。 +func NewParser(l *Lexer) *Parser { + p := &Parser{ + l: l, + } + + return p +} + +// nextToken は次のトークンを読み込みます。 +func (p *Parser) nextToken() { + p.currentToken = p.peekToken + p.peekToken = p.l.NextToken() + + _, file, line, _ := runtime.Caller(1) + logs.TraceLog.Printf("🪲: nextToken: caller=%s:%d currentToken: %#v, peekToken: %#v", filepathz.Short(file), line, p.currentToken, p.peekToken) +} + +// Parse はSQL文を解析します。 +func (p *Parser) Parse() (*DDL, error) { //nolint:ireturn + p.nextToken() // current = "" + p.nextToken() // current = CREATE or ALTER or ... + + d := &DDL{} + +LabelDDL: + for { + switch p.currentToken.Type { //nolint:exhaustive + case TOKEN_CREATE: + stmt, err := p.parseCreateStatement() + if err != nil { + return nil, apperr.Errorf("parseCreateStatement: %w", err) + } + d.Stmts = append(d.Stmts, stmt) + case TOKEN_CLOSE_PAREN: + // do nothing + case TOKEN_SEMICOLON: + // do nothing + case TOKEN_EOF: + break LabelDDL + default: + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + } + + p.nextToken() + } + return d, nil +} + +func (p *Parser) parseCreateStatement() (Stmt, error) { //nolint:ireturn + p.nextToken() // current = TABLE or INDEX or ... + + switch p.currentToken.Type { //nolint:exhaustive + case TOKEN_TABLE: + return p.parseCreateTableStmt() + case TOKEN_INDEX, TOKEN_UNIQUE: + return p.parseCreateIndexStmt() + default: + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + } +} + +//nolint:cyclop,funlen,gocognit,gocyclo +func (p *Parser) parseCreateTableStmt() (*CreateTableStmt, error) { + createTableStmt := &CreateTableStmt{ + Indent: Indent, + } + + if p.isPeekToken(TOKEN_IF) { + p.nextToken() // current = IF + if err := p.checkPeekToken(TOKEN_NOT); err != nil { + return nil, apperr.Errorf("checkPeekToken: %w", err) + } + p.nextToken() // current = NOT + if err := p.checkPeekToken(TOKEN_EXISTS); err != nil { + return nil, apperr.Errorf("checkPeekToken: %w", err) + } + p.nextToken() // current = EXISTS + createTableStmt.IfNotExists = true + } + + p.nextToken() // current = table_name + if err := p.checkCurrentToken(TOKEN_IDENT); err != nil { + return nil, apperr.Errorf("checkCurrentToken: %w", err) + } + + createTableStmt.Name = NewObjectName(p.currentToken.Literal.Str) + errFmtPrefix := fmt.Sprintf("table_name=%s: ", createTableStmt.Name.StringForDiff()) + + p.nextToken() // current = ( + + if err := p.checkCurrentToken(TOKEN_OPEN_PAREN); err != nil { + return nil, apperr.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) + } + + p.nextToken() // current = column_name + +LabelColumns: + for { + switch { //nolint:exhaustive + case p.isCurrentToken(TOKEN_IDENT): + column, constraints, err := p.parseColumn(createTableStmt.Name.Name) + if err != nil { + return nil, apperr.Errorf(errFmtPrefix+"parseColumn: %w", err) + } + createTableStmt.Columns = append(createTableStmt.Columns, column) + if len(constraints) > 0 { + for _, c := range constraints { + createTableStmt.Constraints = createTableStmt.Constraints.Append(c) + } + } + case isConstraint(p.currentToken.Type): + constraint, err := p.parseTableConstraint(createTableStmt.Name.Name) + if err != nil { + return nil, apperr.Errorf(errFmtPrefix+"parseConstraint: %w", err) + } + createTableStmt.Constraints = createTableStmt.Constraints.Append(constraint) + case p.isCurrentToken(TOKEN_COMMA): + p.nextToken() + continue + case p.isCurrentToken(TOKEN_CLOSE_PAREN): + p.nextToken() + break LabelColumns + default: + return nil, apperr.Errorf(errFmtPrefix+"currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + } + } + +LabelTableOptions: + for { + switch p.currentToken.Type { //nolint:exhaustive + case TOKEN_PRIMARY: + opt := &Option{} + p.nextToken() // current = KEY + if err := p.checkCurrentToken(TOKEN_KEY); err != nil { + return nil, apperr.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) + } + opt.Name = "PRIMARY KEY" + p.nextToken() // current = `(` + idents, err := p.parseExpr() + if err != nil { + return nil, apperr.Errorf(errFmtPrefix+"parseExpr: %w", err) + } + opt.Value = opt.Value.Append(idents...) + createTableStmt.Options = append(createTableStmt.Options, opt) + continue + case TOKEN_INTERLEAVE: + opt := &Option{} + p.nextToken() // current = IN + if err := p.checkCurrentToken(TOKEN_IN); err != nil { + return nil, apperr.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) + } + p.nextToken() // current = PARENT + if err := p.checkCurrentToken(TOKEN_PARENT); err != nil { + return nil, apperr.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) + } + opt.Name = "INTERLEAVE IN PARENT" + p.nextToken() // current = table_name + if err := p.checkCurrentToken(TOKEN_IDENT); err != nil { + return nil, apperr.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) + } + opt.Value = opt.Value.Append(NewRawIdent(p.currentToken.Literal.String())) + if p.isPeekToken(TOKEN_ON) { + p.nextToken() // current = ON + p.nextToken() // current = DELETE + if err := p.checkCurrentToken(TOKEN_DELETE); err != nil { + return nil, apperr.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) + } + onAction := "ON DELETE" + p.nextToken() // current = CASCADE or NO + switch p.currentToken.Type { //nolint:exhaustive + case TOKEN_CASCADE: + onAction += " CASCADE" + case TOKEN_NO: + p.nextToken() // current = ACTION + if err := p.checkCurrentToken(TOKEN_ACTION); err != nil { + return nil, apperr.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) + } + onAction += " NO ACTION" + default: + return nil, apperr.Errorf(errFmtPrefix+"currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + } + opt.Value = opt.Value.Append(NewRawIdent(onAction)) + } + createTableStmt.Options = append(createTableStmt.Options, opt) + case TOKEN_COMMA: + // do nothing + case TOKEN_SEMICOLON, TOKEN_EOF: + break LabelTableOptions + default: + return nil, apperr.Errorf(errFmtPrefix+"peekToken=%#v: %w", p.peekToken, ddl.ErrUnexpectedToken) + } + p.nextToken() + } + + return createTableStmt, nil +} + +//nolint:cyclop,funlen +func (p *Parser) parseCreateIndexStmt() (*CreateIndexStmt, error) { + createIndexStmt := &CreateIndexStmt{} + + if p.isCurrentToken(TOKEN_UNIQUE) { + createIndexStmt.Unique = true + p.nextToken() // current = INDEX + } + + if p.isPeekToken(TOKEN_IF) { + p.nextToken() // current = IF + if err := p.checkPeekToken(TOKEN_NOT); err != nil { + return nil, apperr.Errorf("checkPeekToken: %w", err) + } + p.nextToken() // current = NOT + if err := p.checkPeekToken(TOKEN_EXISTS); err != nil { + return nil, apperr.Errorf("checkPeekToken: %w", err) + } + p.nextToken() // current = EXISTS + createIndexStmt.IfNotExists = true + } + + p.nextToken() // current = index_name + if err := p.checkCurrentToken(TOKEN_IDENT); err != nil { + return nil, apperr.Errorf("checkCurrentToken: %w", err) + } + + createIndexStmt.Name = NewObjectName(p.currentToken.Literal.Str) + errFmtPrefix := fmt.Sprintf("index_name=%s: ", createIndexStmt.Name.StringForDiff()) + + p.nextToken() // current = ON + + if err := p.checkCurrentToken(TOKEN_ON); err != nil { + return nil, apperr.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) + } + + p.nextToken() // current = table_name + + if err := p.checkCurrentToken(TOKEN_IDENT); err != nil { + return nil, apperr.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) + } + + createIndexStmt.TableName = NewObjectName(p.currentToken.Literal.Str) + + p.nextToken() // current = USING or ( + + if p.isCurrentToken(TOKEN_USING) { + p.nextToken() // current = using_def + createIndexStmt.Using = append(createIndexStmt.Using, NewIdent(p.currentToken.Literal.Str, "", p.currentToken.Literal.Str)) + p.nextToken() // current = ( + } + + if err := p.checkCurrentToken(TOKEN_OPEN_PAREN); err != nil { + return nil, apperr.Errorf(errFmtPrefix+"checkCurrentToken: %w", err) + } + + idents, err := p.parseColumnIdents() + if err != nil { + return nil, apperr.Errorf(errFmtPrefix+"parseColumnIdents: %w", err) + } + + createIndexStmt.Columns = idents + + return createIndexStmt, nil +} + +//nolint:funlen,cyclop +func (p *Parser) parseColumn(tableName *Ident) (*Column, []Constraint, error) { + column := &Column{} + constraints := make(Constraints, 0) + + if err := p.checkCurrentToken(TOKEN_IDENT); err != nil { + return nil, nil, apperr.Errorf("checkCurrentToken: %w", err) + } + + column.Name = NewRawIdent(p.currentToken.Literal.Str) + errFmtPrefix := fmt.Sprintf("column_name=%s: ", column.Name.StringForDiff()) + + p.nextToken() // current = DATA_TYPE + + switch { //nolint:exhaustive + case isDataType(p.currentToken.Type): + dataType, err := p.parseDataType() + if err != nil { + return nil, nil, apperr.Errorf(errFmtPrefix+"parseDataType: %w", err) + } + column.DataType = dataType + + p.nextToken() // current = DEFAULT or NOT or NULL or PRIMARY or UNIQUE or COMMA or ... + LabelDefaultNotNull: + for { + switch p.currentToken.Type { //nolint:exhaustive + case TOKEN_NOT: + if err := p.checkPeekToken(TOKEN_NULL); err != nil { + return nil, nil, apperr.Errorf(errFmtPrefix+"checkPeekToken: %w", err) + } + p.nextToken() // current = NULL + column.NotNull = true + case TOKEN_NULL: + column.NotNull = false + case TOKEN_DEFAULT: // current = DEFAULT + p.nextToken() // current = value + def, err := p.parseColumnDefault() + if err != nil { + return nil, nil, apperr.Errorf(errFmtPrefix+"parseColumnDefault: %w", err) + } + column.Default = def + continue + default: + break LabelDefaultNotNull + } + + p.nextToken() + } + + cs, err := p.parseColumnConstraints(tableName, column) + if err != nil { + return nil, nil, apperr.Errorf(errFmtPrefix+"parseColumnConstraints: %w", err) + } + if len(cs) > 0 { + for _, c := range cs { + constraints = constraints.Append(c) + } + } + + if p.isCurrentToken(TOKEN_OPTIONS) { + p.nextToken() // current = ( + idents, err := p.parseExpr() + if err != nil { + return nil, nil, apperr.Errorf(errFmtPrefix+"parseExpr: %w", err) + } + column.Options = column.Options.Append(idents...) + } + default: + return nil, nil, apperr.Errorf(errFmtPrefix+"currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + } + + return column, constraints, nil +} + +//nolint:cyclop +func (p *Parser) parseColumnDefault() (*Default, error) { + def := &Default{} + +LabelDefault: + for { + switch p.currentToken.Type { //nolint:exhaustive + case TOKEN_IDENT: + def.Value = def.Value.Append(NewRawIdent(p.currentToken.Literal.String())) + case TOKEN_OPEN_PAREN: + ids, err := p.parseExpr() + if err != nil { + return nil, apperr.Errorf("parseExpr: %w", err) + } + def.Value = def.Value.Append(ids...) + continue + case TOKEN_NOT, TOKEN_NULL, TOKEN_COMMA, TOKEN_CLOSE_PAREN: + break LabelDefault + default: + if isReservedValue(p.currentToken.Type) { + def.Value = def.Value.Append(NewIdent(string(p.currentToken.Type), "", p.currentToken.Literal.String())) + p.nextToken() + continue + } + // MEMO: backup + // TODO: check if this is necessary + // if isOperator(p.currentToken.Type) { + // def.Value = def.Value.Append(NewRawIdent(p.currentToken.Literal.Str)) + // p.nextToken() + // continue + // } + // if isDataType(p.currentToken.Type) { + // def.Value.Idents = append(def.Value.Idents, NewRawIdent(p.currentToken.Literal.Str)) + // p.nextToken() + // continue + // } + if isConstraint(p.currentToken.Type) { + break LabelDefault + } + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + } + + p.nextToken() + } + + return def, nil +} + +//nolint:cyclop +func (p *Parser) parseExpr() ([]*Ident, error) { + idents := make([]*Ident, 0) + + if err := p.checkCurrentToken(TOKEN_OPEN_PAREN); err != nil { + return nil, apperr.Errorf("checkCurrentToken: %w", err) + } + idents = append(idents, NewRawIdent(p.currentToken.Literal.Str)) + p.nextToken() // current = IDENT + +LabelExpr: + for { + switch p.currentToken.Type { //nolint:exhaustive + case TOKEN_OPEN_PAREN: + ids, err := p.parseExpr() + if err != nil { + return nil, apperr.Errorf("parseExpr: %w", err) + } + idents = append(idents, ids...) + continue + case TOKEN_CLOSE_PAREN: + idents = append(idents, NewRawIdent(p.currentToken.Literal.Str)) + p.nextToken() + break LabelExpr + case TOKEN_EQUAL, TOKEN_GREATER, TOKEN_LESS: + value := p.currentToken.Literal.Str + switch p.peekToken.Type { //nolint:exhaustive + case TOKEN_EQUAL, TOKEN_GREATER, TOKEN_LESS: + value += p.peekToken.Literal.Str + p.nextToken() + } + idents = append(idents, NewRawIdent(value)) + case TOKEN_EOF: + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + default: + if isReservedValue(p.currentToken.Type) { + idents = append(idents, NewRawIdent(p.currentToken.Type.String())) + } else { + idents = append(idents, NewRawIdent(p.currentToken.Literal.Str)) + } + } + + p.nextToken() + } + + return idents, nil +} + +//nolint:cyclop,funlen,gocognit +func (p *Parser) parseColumnConstraints(tableName *Ident, column *Column) ([]Constraint, error) { + constraints := make(Constraints, 0) + +LabelConstraints: + for { + switch p.currentToken.Type { //nolint:exhaustive + case TOKEN_REFERENCES: + if err := p.checkPeekToken(TOKEN_IDENT); err != nil { + return nil, apperr.Errorf("checkPeekToken: %w", err) + } + p.nextToken() // current = table_name + constraint := &ForeignKeyConstraint{ + Name: NewRawIdent(fmt.Sprintf("%s_%s_fkey", tableName.StringForDiff(), column.Name.StringForDiff())), + Ref: NewRawIdent(p.currentToken.Literal.Str), + Columns: []*ColumnIdent{{Ident: column.Name}}, + } + p.nextToken() // current = ( + idents, err := p.parseColumnIdents() + if err != nil { + return nil, apperr.Errorf("parseColumnIdents: %w", err) + } + constraint.RefColumns = idents + constraints = constraints.Append(constraint) + case TOKEN_CHECK: + if err := p.checkPeekToken(TOKEN_OPEN_PAREN); err != nil { + return nil, apperr.Errorf("checkPeekToken: %w", err) + } + p.nextToken() // current = ( + constraint := &CheckConstraint{ + Name: NewRawIdent(fmt.Sprintf("%s_%s_check", tableName.StringForDiff(), column.Name.StringForDiff())), + } + idents, err := p.parseExpr() + if err != nil { + return nil, apperr.Errorf("parseExpr: %w", err) + } + constraint.Expr = constraint.Expr.Append(idents...) + constraints = constraints.Append(constraint) + case TOKEN_OPTIONS, TOKEN_IDENT, TOKEN_COMMA, TOKEN_CLOSE_PAREN: + break LabelConstraints + default: + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + } + + p.nextToken() + } + + return constraints, nil +} + +//nolint:funlen,cyclop,gocognit +func (p *Parser) parseTableConstraint(tableName *Ident) (Constraint, error) { //nolint:ireturn + var constraintName *Ident + if p.isCurrentToken(TOKEN_CONSTRAINT) { + p.nextToken() // current = constraint_name + if p.currentToken.Type != TOKEN_IDENT { + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + } + constraintName = NewRawIdent(p.currentToken.Literal.Str) + p.nextToken() // current = PRIMARY or CHECK //diff:ignore-line-postgres-cockroach + } + + switch p.currentToken.Type { //nolint:exhaustive + case TOKEN_FOREIGN: + if err := p.checkPeekToken(TOKEN_KEY); err != nil { + return nil, apperr.Errorf("checkPeekToken: %w", err) + } + p.nextToken() // current = KEY + if err := p.checkPeekToken(TOKEN_OPEN_PAREN); err != nil { + return nil, apperr.Errorf("checkPeekToken: %w", err) + } + p.nextToken() // current = ( + idents, err := p.parseColumnIdents() + if err != nil { + return nil, apperr.Errorf("parseColumnIdents: %w", err) + } + if err := p.checkCurrentToken(TOKEN_REFERENCES); err != nil { + return nil, apperr.Errorf("checkPeekToken: %w", err) + } + p.nextToken() // current = ref_table_name + if err := p.checkCurrentToken(TOKEN_IDENT); err != nil { + return nil, apperr.Errorf("checkCurrentToken: %w", err) + } + refName := NewRawIdent(p.currentToken.Literal.Str) + + p.nextToken() // current = ( + identsRef, err := p.parseColumnIdents() + if err != nil { + return nil, apperr.Errorf("parseColumnIdents: %w", err) + } + if constraintName == nil { + name := tableName.StringForDiff() + for _, ident := range idents { + name += fmt.Sprintf("_%s", ident.StringForDiff()) + } + name += "_fkey" + constraintName = NewRawIdent(name) + } + return &ForeignKeyConstraint{ + Name: constraintName, + Columns: idents, + Ref: refName, + RefColumns: identsRef, + }, nil + case TOKEN_CHECK: + constraint := &CheckConstraint{ + Name: constraintName, + } + if err := p.checkPeekToken(TOKEN_OPEN_PAREN); err != nil { + return nil, apperr.Errorf("checkPeekToken: %w", err) + } + p.nextToken() // current = ( + idents, err := p.parseExpr() + if err != nil { + return nil, apperr.Errorf("parseExpr: %w", err) + } + constraint.Name = constraintName + constraint.Expr = constraint.Expr.Append(idents...) + return constraint, nil + default: + return nil, apperr.Errorf("currentToken=%s: %w", p.currentToken.Type, ddl.ErrUnexpectedToken) + } +} + +//nolint:cyclop,funlen +func (p *Parser) parseDataType() (*DataType, error) { + dataType := &DataType{ + Name: p.currentToken.Literal.String(), + Type: p.currentToken.Type, + } + + // TODO: support ARRAY, STRUCT + + if p.isPeekToken(TOKEN_OPEN_PAREN) { + p.nextToken() // current = ( + idents, err := p.parseIdents() + if err != nil { + return nil, apperr.Errorf("parseIdents: %w", err) + } + dataType.Expr = dataType.Expr.Append(idents...) + } + + return dataType, nil +} + +func (p *Parser) parseColumnIdents() ([]*ColumnIdent, error) { + idents := make([]*ColumnIdent, 0) + +LabelIdents: + for { + switch p.currentToken.Type { //nolint:exhaustive + case TOKEN_OPEN_PAREN: + // do nothing + case TOKEN_IDENT: + ident := &ColumnIdent{Ident: NewRawIdent(p.currentToken.Literal.Str)} + switch p.peekToken.Type { //nolint:exhaustive //diff:ignore-line-postgres-cockroach + case TOKEN_ASC: //diff:ignore-line-postgres-cockroach + ident.Order = &Order{Desc: false} //diff:ignore-line-postgres-cockroach + p.nextToken() // current = ASC //diff:ignore-line-postgres-cockroach + case TOKEN_DESC: //diff:ignore-line-postgres-cockroach + ident.Order = &Order{Desc: true} //diff:ignore-line-postgres-cockroach + p.nextToken() // current = DESC //diff:ignore-line-postgres-cockroach + } //diff:ignore-line-postgres-cockroach + idents = append(idents, ident) + case TOKEN_COMMA: + // do nothing + case TOKEN_CLOSE_PAREN: + p.nextToken() + break LabelIdents + default: + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + } + p.nextToken() + } + + return idents, nil +} + +func (p *Parser) parseIdents() ([]*Ident, error) { + idents := make([]*Ident, 0) + +LabelIdents: + for { + switch p.currentToken.Type { //nolint:exhaustive + case TOKEN_OPEN_PAREN: + // do nothing + case TOKEN_CLOSE_PAREN: + break LabelIdents + case TOKEN_EOF, TOKEN_ILLEGAL: + return nil, apperr.Errorf("currentToken=%#v: %w", p.currentToken, ddl.ErrUnexpectedToken) + default: + idents = append(idents, NewRawIdent(p.currentToken.Literal.String())) + } + p.nextToken() + } + + return idents, nil +} + +// MEMO: backup +// TODO: check if this is necessary +// func isOperator(tokenType TokenType) bool { +// switch tokenType { //nolint:exhaustive +// case TOKEN_EQUAL, TOKEN_GREATER, TOKEN_LESS, +// TOKEN_PLUS, TOKEN_MINUS, TOKEN_ASTERISK, TOKEN_SLASH, +// TOKEN_TYPE_ANNOTATION, //diff:ignore-line-postgres-cockroach +// TOKEN_STRING_CONCAT, TOKEN_TYPECAST: +// return true +// default: +// return false +// } +// } + +func isReservedValue(tokenType TokenType) bool { + switch tokenType { //nolint:exhaustive + case TOKEN_NULL, TOKEN_TRUE, TOKEN_FALSE: + return true + default: + return false + } +} + +func isDataType(tokenType TokenType) bool { + switch tokenType { //nolint:exhaustive + case TOKEN_BOOL, + TOKEN_INT64, + TOKEN_NUMERIC, + TOKEN_FLOAT64, + TOKEN_JSON, + TOKEN_STRING, + TOKEN_TIMESTAMP: + return true + default: + return false + } +} + +func isConstraint(tokenType TokenType) bool { + switch tokenType { //nolint:exhaustive + case TOKEN_CONSTRAINT, + TOKEN_INDEX, + TOKEN_PRIMARY, TOKEN_KEY, + TOKEN_FOREIGN, TOKEN_REFERENCES, + TOKEN_UNIQUE, + TOKEN_CHECK: + return true + default: + return false + } +} + +func (p *Parser) isCurrentToken(expectedTypes ...TokenType) bool { + for _, expected := range expectedTypes { + if expected == p.currentToken.Type { + return true + } + } + return false +} + +func (p *Parser) checkCurrentToken(expectedTypes ...TokenType) error { + for _, expected := range expectedTypes { + if expected == p.currentToken.Type { + return nil + } + } + return apperr.Errorf("currentToken: expected=%s, but got=%#v: %w", stringz.JoinStringers(",", expectedTypes...), p.currentToken, ddl.ErrUnexpectedToken) +} + +func (p *Parser) isPeekToken(expectedTypes ...TokenType) bool { + for _, expected := range expectedTypes { + if expected == p.peekToken.Type { + return true + } + } + return false +} + +func (p *Parser) checkPeekToken(expectedTypes ...TokenType) error { + for _, expected := range expectedTypes { + if expected == p.peekToken.Type { + return nil + } + } + return apperr.Errorf("peekToken: expected=%s, but got=%#v: %w", stringz.JoinStringers(",", expectedTypes...), p.peekToken, ddl.ErrUnexpectedToken) +} diff --git a/pkg/ddl/spanner/parser_test.go b/pkg/ddl/spanner/parser_test.go new file mode 100644 index 0000000..4eba6d5 --- /dev/null +++ b/pkg/ddl/spanner/parser_test.go @@ -0,0 +1,504 @@ +//nolint:testpackage +package spanner + +import ( + "log" + "os" + "testing" + + "github.com/kunitsucom/util.go/testing/assert" + "github.com/kunitsucom/util.go/testing/require" + + "github.com/kunitsucom/ddlctl/pkg/ddl" + "github.com/kunitsucom/ddlctl/pkg/ddl/logs" +) + +//nolint:paralleltest,tparallel +func TestParser_Parse(t *testing.T) { + backup := logs.TraceLog + t.Cleanup(func() { + logs.TraceLog = backup + }) + logs.TraceLog = log.New(os.Stderr, "TRACE: ", log.LstdFlags|log.Lshortfile) + + t.Run("success,CREATE_TABLE", func(t *testing.T) { + // t.Parallel() + + l := NewLexer(`CREATE TABLE "groups" ("id" STRING(36) NOT NULL, description STRING) PRIMARY KEY ("id"); CREATE TABLE "users" (id STRING(36) NOT NULL, group_id STRING(36) NOT NULL, "name" STRING(255) NOT NULL, "age" INT64 DEFAULT 0, description STRING, CONSTRAINT users_age_check CHECK ("age" >= 0), CONSTRAINT users_group_id_fkey FOREIGN KEY (group_id) REFERENCES "groups" ("id")) PRIMARY KEY ("id"), INTERLEAVE IN PARENT names ON DELETE NO ACTION;`) + p := NewParser(l) + actual, err := p.Parse() + require.NoError(t, err) + + const expected = `CREATE TABLE "groups" ( + "id" STRING(36) NOT NULL, + description STRING +) PRIMARY KEY ("id"); +CREATE TABLE "users" ( + id STRING(36) NOT NULL, + group_id STRING(36) NOT NULL, + "name" STRING(255) NOT NULL, + "age" INT64 DEFAULT 0, + description STRING, + CONSTRAINT users_age_check CHECK ("age" >= 0), + CONSTRAINT users_group_id_fkey FOREIGN KEY (group_id) REFERENCES "groups" ("id") +) PRIMARY KEY ("id"), +INTERLEAVE IN PARENT names ON DELETE NO ACTION; +` + + if !assert.Equal(t, expected, actual.String()) { + t.Fail() + } + + t.Logf("✅: %s: actual: %%#v: \n%#v", t.Name(), actual) + t.Logf("✅: %s: actual: %%s: \n%s", t.Name(), actual) + }) + + t.Run("success,complex_defaults", func(t *testing.T) { + // t.Parallel() + + l := NewLexer(`-- table: complex_defaults +CREATE TABLE IF NOT EXISTS complex_defaults ( + -- id is the primary key. + id INT64, + created_at TIMESTAMP DEFAULT (CURRENT_TIMESTAMP()), + updated_at TIMESTAMP DEFAULT (CURRENT_TIMESTAMP()), + unique_code STRING DEFAULT (GENERATE_UUID()), + status STRING DEFAULT ('pending'), + random_number INT64 DEFAULT (FLOOR(RANDOM() * 100)), + json_data JSON DEFAULT ('{}'), + calculated_value INT64 DEFAULT (SELECT COUNT(*) FROM another_table) +) PRIMARY KEY (id); +`) + p := NewParser(l) + actual, err := p.Parse() + require.NoError(t, err) + + const expected = `CREATE TABLE IF NOT EXISTS complex_defaults ( + id INT64, + created_at TIMESTAMP DEFAULT (CURRENT_TIMESTAMP()), + updated_at TIMESTAMP DEFAULT (CURRENT_TIMESTAMP()), + unique_code STRING DEFAULT (GENERATE_UUID()), + status STRING DEFAULT ('pending'), + random_number INT64 DEFAULT (FLOOR(RANDOM() * 100)), + json_data JSON DEFAULT ('{}'), + calculated_value INT64 DEFAULT (SELECT COUNT(*) FROM another_table) +) PRIMARY KEY (id); +` + + if !assert.Equal(t, expected, actual.String()) { + t.Fail() + } + + t.Logf("✅: %s: actual: %%#v: \n%#v", t.Name(), actual) + t.Logf("✅: %s: actual: %%s: \n%s", t.Name(), actual) + }) + + failureTests := []struct { + name string + input string + wantErr error + }{ + { + name: "failure,invalid", + input: `)invalid`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_INVALID", + input: `CREATE INVALID;`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_INVALID", + input: `CREATE TABLE;`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_IF_INVALID", + input: `CREATE TABLE IF;`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_IF_NOT_INVALID", + input: `CREATE TABLE IF NOT;`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_INVALID", + input: `CREATE TABLE "users";`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_INVALID", + input: `CREATE TABLE "users" ("id";`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_data_type_INVALID", + input: `CREATE TABLE "users" ("id" STRING(36);`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_data_type_INVALID", + input: `CREATE TABLE "users" ("id" STRING(36), enabled BOOL DEFAULT (FALSE);`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_data_type_INVALID", + input: `CREATE TABLE "users" ("id" STRING(36), enabled BOOL DEFAULT (TRUE AND FALSE);`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_data_type_OPTIONS_INVALID", + input: `CREATE TABLE "users" ("id" STRING(36), created_at TIMESTAMP OPTIONS (allow_commit_timestamp = true;`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_CONSTRAINT_INVALID", + input: `CREATE TABLE "users" ("id" STRING(36), CONSTRAINT "invalid" NOT;`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_INVALID", + input: `CREATE TABLE "users" ("id" STRING(36))(;`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_COMMA_INVALID", + input: `CREATE TABLE "users" ("id" TIMESTAMP CREATE`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_DATA_TYPE_INVALID", + input: `CREATE TABLE "users" ("id" VARYING();`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_INVALID_NOT", + input: `CREATE TABLE "users" ("id" STRING(36) NULL NOT;`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_INVALID_DEFAULT", + input: `CREATE TABLE "users" ("id" STRING(36) DEFAULT ("id")`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_INVALID_DEFAULT_OPEN_PAREN", + input: `CREATE TABLE "users" ("id" STRING(36) DEFAULT ("id",`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_INVALID_PRIMARY_KEY", + input: `CREATE TABLE "users" ("id" STRING(36) PRIMARY NOT`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_INVALID_REFERENCES", + input: `CREATE TABLE "users" ("id" STRING(36) REFERENCES NOT`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_INVALID_REFERENCES_IDENTS", + input: `CREATE TABLE "users" ("id" STRING(36) REFERENCES "groups" (NOT`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_INVALID_CHECK", + input: `CREATE TABLE "users" ("id" STRING(36) CHECK NOT`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_CHECK_INVALID_IDENTS", + input: `CREATE TABLE "users" ("id" STRING(36) CHECK (NOT`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_CONSTRAINT_INVALID_IDENT", + input: `CREATE TABLE "users" ("id" STRING(36), CONSTRAINT NOT`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_CONSTRAINT_CHECK_INVALID", + input: `CREATE TABLE "users" ("id" STRING(36), CONSTRAINT constraint_name CHECK`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_CONSTRAINT_CHECK_OPEN_INVALID", + input: `CREATE TABLE "users" ("id" STRING(36), CONSTRAINT constraint_name CHECK (`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_CONSTRAINT_INVALID_PRIMARY", + input: `CREATE TABLE "users" ("id" STRING(36)) PRIMARY NOT`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_OPTION_PRIMARY_KEY_INVALID", + input: `CREATE TABLE "users" ("id" STRING(36)) PRIMARY KEY NOT`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_OPTION_PRIMARY_KEY_OPEN_PAREN_INVALID", + input: `CREATE TABLE "users" ("id" STRING(36)) PRIMARY KEY (NOT`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_OPTION_PRIMARY_KEY_OPEN_PAREN_column_name_INVALID", + input: `CREATE TABLE "users" ("id" STRING(36)) PRIMARY KEY (id`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_OPTION_PRIMARY_KEY_INTERLEAVE_INVALID", + input: `CREATE TABLE "users" ("id" STRING(36)) PRIMARY KEY (id), INTERLEAVE;`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_OPTION_PRIMARY_KEY_INTERLEAVE_IN_INVALID", + input: `CREATE TABLE "users" ("id" STRING(36)) PRIMARY KEY (id), INTERLEAVE IN;`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_OPTION_PRIMARY_KEY_INTERLEAVE_IN_PARENT_INVALID", + input: `CREATE TABLE "users" ("id" STRING(36)) PRIMARY KEY (id), INTERLEAVE IN PARENT;`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_OPTION_PRIMARY_KEY_INTERLEAVE_IN_PARENT_ON_INVALID", + input: `CREATE TABLE "users" ("id" STRING(36)) PRIMARY KEY (id), INTERLEAVE IN PARENT table_name ON;`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_OPTION_PRIMARY_KEY_INTERLEAVE_IN_PARENT_ON_DELETE_INVALID", + input: `CREATE TABLE "users" ("id" STRING(36)) PRIMARY KEY (id), INTERLEAVE IN PARENT table_name ON DELETE;`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_OPTION_PRIMARY_KEY_INTERLEAVE_IN_PARENT_ON_DELETE_CASCADE_INVALID", + input: `CREATE TABLE "users" ("id" STRING(36)) PRIMARY KEY (id), INTERLEAVE IN PARENT table_name ON DELETE CASCADE NOT;`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_OPTION_PRIMARY_KEY_INTERLEAVE_IN_PARENT_ON_DELETE_NO_INVALID", + input: `CREATE TABLE "users" ("id" STRING(36)) PRIMARY KEY (id), INTERLEAVE IN PARENT table_name ON DELETE NO;`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_CONSTRAINT_INVALID_FOREIGN", + input: `CREATE TABLE "users" ("id" STRING(36), FOREIGN NOT`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_CONSTRAINT_INVALID_FOREIGN_KEY", + input: `CREATE TABLE "users" ("id" STRING(36), FOREIGN KEY NOT`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_CONSTRAINT_INVALID_FOREIGN_KEY_OPEN_PAREN", + input: `CREATE TABLE "users" ("id" STRING(36), FOREIGN KEY (NOT`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_CONSTRAINT_FOREIGN_KEY_IDENTS_INVALID", + input: `CREATE TABLE "users" ("id" STRING(36), FOREIGN KEY ("group_id") NOT`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_CONSTRAINT_FOREIGN_KEY_IDENTS_REFERENCES_INVALID", + input: `CREATE TABLE "users" ("id" STRING(36), FOREIGN KEY ("group_id") REFERENCES `, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_CONSTRAINT_FOREIGN_KEY_IDENTS_REFERENCES_INVALID_IDENTS", + input: `CREATE TABLE "users" ("id" STRING(36), FOREIGN KEY ("group_id") REFERENCES "groups" NOT`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_CONSTRAINT_FOREIGN_KEY_IDENTS_REFERENCES_INVALID_CLOSE_PAREN", + input: `CREATE TABLE "users" ("id" STRING(36), FOREIGN KEY ("group_id") REFERENCES "groups" ("id")`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_CONSTRAINT_UNIQUE_INVALID", + input: `CREATE TABLE "users" ("id" STRING(36), UNIQUE NOT`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_CONSTRAINT_UNIQUE_INVALID", + input: `CREATE TABLE "users" ("id" STRING(36), UNIQUE NOT`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_CONSTRAINT_UNIQUE_INDEX_INVALID", + input: `CREATE TABLE "users" ("id" STRING(36), UNIQUE INDEX users_idx_name NOT`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_CONSTRAINT_UNIQUE_INDEX_COLUMN_INVALID", + input: `CREATE TABLE "users" ("id" STRING(36), UNIQUE INDEX users_idx_name (NOT`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_CONSTRAINT_UNIQUE_INVALID", + input: `CREATE TABLE "users" ("id" STRING(36), UNIQUE INDEX NOT`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_TABLE_table_name_column_name_CONSTRAINT_UNIQUE_IDENTS_INVALID", + input: `CREATE TABLE "users" ("id" STRING(36), name STRING, UNIQUE ("id", name)`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_INDEX_INVALID", + input: `CREATE INDEX NOT`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_INDEX_IF_INVALID", + input: `CREATE INDEX IF;`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_INDEX_IF_NOT_INVALID", + input: `CREATE INDEX IF NOT;`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_INDEX_IF_NOT_EXISTS_INVALID", + input: `CREATE INDEX IF NOT EXISTS;`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_INDEX_index_name_INVALID", + input: `CREATE INDEX users_idx_username NOT`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_INDEX_index_name_ON_INVALID", + input: `CREATE INDEX users_idx_username ON NOT`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_INDEX_index_name_ON_table_name_INVALID", + input: `CREATE INDEX users_idx_username ON users NOT`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_INDEX_index_name_ON_table_name_USING_INVALID", + input: `CREATE INDEX users_idx_username ON users USING NOT`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_INDEX_index_name_ON_table_name_USING_method_INVALID", + input: `CREATE INDEX users_idx_username ON users USING btree NOT`, + wantErr: ddl.ErrUnexpectedToken, + }, + { + name: "failure,CREATE_INDEX_index_name_ON_table_name_USING_method_OPEN_INVALID", + input: `CREATE INDEX users_idx_username ON users USING btree (NOT)`, + wantErr: ddl.ErrUnexpectedToken, + }, + } + + for _, tt := range failureTests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + _, err := NewParser(NewLexer(tt.input)).Parse() + require.ErrorIs(t, err, tt.wantErr) + }) + } + + t.Run("success,TOKEN_SEMICOLON", func(t *testing.T) { + _, err := NewParser(NewLexer(`;`)).Parse() + require.NoError(t, err) + }) +} + +func TestParser_parseColumn(t *testing.T) { + t.Parallel() + + t.Run("success,TOKEN_COMMA", func(t *testing.T) { + t.Parallel() + + p := NewParser(NewLexer("( id STRING(36),")) + p.nextToken() + p.nextToken() + p.nextToken() + _, _, err := p.parseColumn(&Ident{Name: "table_name", QuotationMark: `"`, Raw: `"table_name"`}) + require.NoError(t, err) + }) + + t.Run("failure,invalid", func(t *testing.T) { + t.Parallel() + + _, _, err := NewParser(NewLexer(`NOT`)).parseColumn(&Ident{Name: "table_name", QuotationMark: `"`, Raw: `"table_name"`}) + require.ErrorIs(t, err, ddl.ErrUnexpectedToken) + }) + + t.Run("failure,parseDataType", func(t *testing.T) { + t.Parallel() + + p := NewParser(NewLexer("( id STRING(")) + p.nextToken() + p.nextToken() + p.nextToken() + _, _, err := p.parseColumn(&Ident{Name: "table_name", QuotationMark: `"`, Raw: `"table_name"`}) + require.ErrorIs(t, err, ddl.ErrUnexpectedToken) + }) +} + +func TestParser_parseColumnDefault(t *testing.T) { + t.Parallel() + + t.Run("success,isReservedValue", func(t *testing.T) { + t.Parallel() + + p := NewParser(NewLexer(`DEFAULT TRUE,`)) + p.nextToken() + p.nextToken() + p.nextToken() + _, err := p.parseColumnDefault() + require.NoError(t, err) + }) +} + +func TestParser_parseExpr(t *testing.T) { + t.Parallel() + + t.Run("failure,invalid", func(t *testing.T) { + t.Parallel() + + p := NewParser(NewLexer(`NOT`)) + p.nextToken() + p.nextToken() + _, err := p.parseExpr() + require.ErrorIs(t, err, ddl.ErrUnexpectedToken) + }) + + t.Run("failure,invalid2", func(t *testing.T) { + t.Parallel() + + p := NewParser(NewLexer(`((NOT`)) + p.nextToken() + p.nextToken() + _, err := p.parseExpr() + require.ErrorIs(t, err, ddl.ErrUnexpectedToken) + }) +} + +func TestParser_parseDataType(t *testing.T) { + t.Parallel() + + t.Run("failure,invalid_paren_content", func(t *testing.T) { + t.Parallel() + + p := NewParser(NewLexer(`STRING(`)) + p.nextToken() + p.nextToken() + _, err := p.parseDataType() + require.ErrorIs(t, err, ddl.ErrUnexpectedToken) + }) +} diff --git a/pkg/ddlctl/ddlctl.go b/pkg/ddlctl/ddlctl.go index acfcf48..374e547 100644 --- a/pkg/ddlctl/ddlctl.go +++ b/pkg/ddlctl/ddlctl.go @@ -6,17 +6,13 @@ import ( "fmt" "os" - errorz "github.com/kunitsucom/util.go/errors" cliz "github.com/kunitsucom/util.go/exp/cli" "github.com/kunitsucom/util.go/version" + apperr "github.com/kunitsucom/ddlctl/pkg/apperr" "github.com/kunitsucom/ddlctl/pkg/internal/consts" ) -const ( - _spanner = "spanner" // TODO: remove after spanner ddl diff implemented -) - //nolint:gochecknoglobals var ( optLanguage = &cliz.StringOption{ @@ -138,7 +134,7 @@ func DDLCtl(ctx context.Context) error { RunFunc: func(ctx context.Context, args []string) error { cmd, err := cliz.FromContext(ctx) if err != nil { - return errorz.Errorf("cliz.FromContext: %w", err) + return apperr.Errorf("cliz.FromContext: %w", err) } cmd.ShowUsage() @@ -150,7 +146,8 @@ func DDLCtl(ctx context.Context) error { if errors.Is(err, cliz.ErrHelp) { return nil } - return errorz.Errorf("cmd.Run: %w", err) + + return apperr.Errorf("cmd.Run: %w", err) } return nil diff --git a/pkg/ddlctl/ddlctl_apply.go b/pkg/ddlctl/ddlctl_apply.go index 169667d..10924fc 100644 --- a/pkg/ddlctl/ddlctl_apply.go +++ b/pkg/ddlctl/ddlctl_apply.go @@ -3,44 +3,53 @@ package ddlctl import ( "bufio" "context" + "errors" "fmt" "os" "strings" sqlz "github.com/kunitsucom/util.go/database/sql" - errorz "github.com/kunitsucom/util.go/errors" + stringz "github.com/kunitsucom/util.go/strings" + apperr "github.com/kunitsucom/ddlctl/pkg/apperr" + "github.com/kunitsucom/ddlctl/pkg/ddl" crdbddl "github.com/kunitsucom/ddlctl/pkg/ddl/cockroachdb" - apperr "github.com/kunitsucom/ddlctl/pkg/errors" + myddl "github.com/kunitsucom/ddlctl/pkg/ddl/mysql" + spanddl "github.com/kunitsucom/ddlctl/pkg/ddl/spanner" "github.com/kunitsucom/ddlctl/pkg/internal/config" "github.com/kunitsucom/ddlctl/pkg/internal/consts" ) -//nolint:cyclop,funlen +//nolint:cyclop,funlen,gocognit func Apply(ctx context.Context, args []string) (err error) { if _, err := config.Load(ctx); err != nil { - return errorz.Errorf("config.Load: %w", err) + return apperr.Errorf("config.Load: %w", err) } if len(args) != 2 { - return errorz.Errorf("args=%v: %w", args, apperr.ErrTwoArgumentsRequired) + return apperr.Errorf("args=%v: %w", args, apperr.ErrTwoArgumentsRequired) } dsn, ddlSrc := args[0], args[1] + dialect := config.Dialect() - left, err := resolve(ctx, config.Dialect(), dsn) + left, err := resolve(ctx, dialect, dsn) if err != nil { - return errorz.Errorf("resolve: %w", err) + return apperr.Errorf("resolve: %w", err) } - right, err := resolve(ctx, config.Dialect(), ddlSrc) + right, err := resolve(ctx, dialect, ddlSrc) if err != nil { - return errorz.Errorf("resolve: %w", err) + return apperr.Errorf("resolve: %w", err) } buf := new(strings.Builder) - if err := DiffDDL(buf, config.Dialect(), left, right); err != nil { - return errorz.Errorf("diff: %w", err) + if err := DiffDDL(buf, dialect, left, right); err != nil { + if errors.Is(err, ddl.ErrNoDifference) { + _, _ = fmt.Fprintln(os.Stdout, ddl.ErrNoDifference.Error()) + return nil + } + return apperr.Errorf("diff: %w", err) } q := buf.String() @@ -60,23 +69,23 @@ Do you want to apply these DDL queries? Enter a value: ` if _, err := os.Stdout.WriteString(msg); err != nil { - return errorz.Errorf("os.Stdout.WriteString: %w", err) + return apperr.Errorf("os.Stdout.WriteString: %w", err) } if config.AutoApprove() { if _, err := os.Stdout.WriteString(fmt.Sprintf("yes (via --%s option)\n", consts.OptionAutoApprove)); err != nil { - return errorz.Errorf("os.Stdout.WriteString: %w", err) + return apperr.Errorf("os.Stdout.WriteString: %w", err) } } else { if err := prompt(); err != nil { - return errorz.Errorf("prompt: %w", err) + return apperr.Errorf("prompt: %w", err) } } os.Stdout.WriteString("\nexecuting...\n") driverName := func() string { - switch dialect := config.Dialect(); dialect { + switch dialect { case crdbddl.Dialect: return crdbddl.DriverName default: @@ -86,16 +95,57 @@ Enter a value: ` db, err := sqlz.OpenContext(ctx, driverName, dsn) if err != nil { - return errorz.Errorf("sqlz.OpenContext: %w", err) + return apperr.Errorf("sqlz.OpenContext: %w", err) } defer func() { if cerr := db.Close(); err == nil && cerr != nil { - err = errorz.Errorf("db.Close: %w", cerr) + err = apperr.Errorf("db.Close: %w", cerr) } }() - if _, err := db.ExecContext(ctx, q); err != nil { - return errorz.Errorf("db.ExecContext: %w", err) + switch driverName { + case myddl.DriverName: + for _, q := range strings.Split(q, ";\n") { + if len(q) == 0 { + // skip empty query + continue + } + if _, err := db.ExecContext(ctx, q); err != nil { + return apperr.Errorf("conn.ExecContext: %w", err) + } + } + case spanddl.DriverName: + conn, err := db.Conn(ctx) + if err != nil { + return apperr.Errorf("db.Conn: %w", err) + } + defer func() { + if cerr := conn.Close(); err == nil && cerr != nil { + err = apperr.Errorf("conn.Close: %w", cerr) + } + }() + if _, err := conn.ExecContext(ctx, "START BATCH DDL"); err != nil { + return apperr.Errorf("conn.ExecContext: %w", err) + } + + commentTrimmedDDL := stringz.ReadLine(q, "\n", stringz.ReadLineFuncRemoveCommentLine("--")) + for _, q := range strings.Split(commentTrimmedDDL, ";\n") { + if len(q) == 0 { + // skip empty query + continue + } + if _, err := conn.ExecContext(ctx, q); err != nil { + return apperr.Errorf("conn.ExecContext: %w", err) + } + } + + if _, err := conn.ExecContext(ctx, "RUN BATCH"); err != nil { + return apperr.Errorf("conn.ExecContext: %w", err) + } + default: + if _, err := db.ExecContext(ctx, q); err != nil { + return apperr.Errorf("db.ExecContext: %w", err) + } } os.Stdout.WriteString("done\n") @@ -112,6 +162,6 @@ func prompt() error { case "yes": return nil default: - return errorz.Errorf("input=%s: %w", input, apperr.ErrCanceled) + return apperr.Errorf("input=%s: %w", input, apperr.ErrCanceled) } } diff --git a/pkg/ddlctl/ddlctl_diff.go b/pkg/ddlctl/ddlctl_diff.go index da0bd21..6f30623 100644 --- a/pkg/ddlctl/ddlctl_diff.go +++ b/pkg/ddlctl/ddlctl_diff.go @@ -2,43 +2,48 @@ package ddlctl import ( "context" + "errors" "io" "os" "strings" - errorz "github.com/kunitsucom/util.go/errors" osz "github.com/kunitsucom/util.go/os" + apperr "github.com/kunitsucom/ddlctl/pkg/apperr" + "github.com/kunitsucom/ddlctl/pkg/ddl" crdbddl "github.com/kunitsucom/ddlctl/pkg/ddl/cockroachdb" myddl "github.com/kunitsucom/ddlctl/pkg/ddl/mysql" pgddl "github.com/kunitsucom/ddlctl/pkg/ddl/postgres" - - apperr "github.com/kunitsucom/ddlctl/pkg/errors" + spanddl "github.com/kunitsucom/ddlctl/pkg/ddl/spanner" "github.com/kunitsucom/ddlctl/pkg/internal/config" "github.com/kunitsucom/ddlctl/pkg/internal/logs" ) func Diff(ctx context.Context, args []string) error { if _, err := config.Load(ctx); err != nil { - return errorz.Errorf("config.Load: %w", err) + return apperr.Errorf("config.Load: %w", err) } if len(args) != 2 { - return errorz.Errorf("args=%v: %w", args, apperr.ErrTwoArgumentsRequired) + return apperr.Errorf("args=%v: %w", args, apperr.ErrTwoArgumentsRequired) } left, err := resolve(ctx, config.Dialect(), args[0]) if err != nil { - return errorz.Errorf("resolve: %w", err) + return apperr.Errorf("resolve: %w", err) } right, err := resolve(ctx, config.Dialect(), args[1]) if err != nil { - return errorz.Errorf("resolve: %w", err) + return apperr.Errorf("resolve: %w", err) } if err := DiffDDL(os.Stdout, config.Dialect(), left, right); err != nil { - return errorz.Errorf("diff: %w", err) + if errors.Is(err, ddl.ErrNoDifference) { + logs.Debug.Print(ddl.ErrNoDifference.Error()) + return nil + } + return apperr.Errorf("diff: %w", err) } return nil @@ -50,19 +55,19 @@ func resolve(ctx context.Context, dialect, arg string) (ddl string, err error) { case osz.IsFile(arg): // NOTE: expect SQL file ddlBytes, err := os.ReadFile(arg) if err != nil { - return "", errorz.Errorf("os.ReadFile: %w", err) + return "", apperr.Errorf("os.ReadFile: %w", err) } ddl = string(ddlBytes) case osz.Exists(arg): // NOTE: expect ddlctl generate format genDDL, err := generateDDLForDiff(ctx, arg) if err != nil { - return "", errorz.Errorf("generateDDL: %w", err) // TODO: ddlgen 形式じゃないから無理というエラーに修正する + return "", apperr.Errorf("generateDDL: %w", err) // TODO: ddlgen 形式じゃないから無理というエラーに修正する } ddl = genDDL default: // NOTE: expect DSN genDDL, err := ShowDDL(ctx, dialect, arg) if err != nil { - return "", errorz.Errorf("ShowDDL: %w", err) + return "", apperr.Errorf("ShowDDL: %w", err) } ddl = genDDL } @@ -73,18 +78,18 @@ func resolve(ctx context.Context, dialect, arg string) (ddl string, err error) { func generateDDLForDiff(ctx context.Context, src string) (string, error) { ddl, err := Parse(ctx, config.Language(), src) if err != nil { - return "", errorz.Errorf("parse: %w", err) + return "", apperr.Errorf("parse: %w", err) } b := new(strings.Builder) if err := Fprint(b, config.Dialect(), ddl); err != nil { - return "", errorz.Errorf("fprint: %w", err) + return "", apperr.Errorf("fprint: %w", err) } return b.String(), nil } -//nolint:cyclop,funlen +//nolint:cyclop,funlen,gocognit func DiffDDL(out io.Writer, dialect string, srcDDL string, dstDDL string) error { logs.Trace.Printf("src: %q", srcDDL) logs.Trace.Printf("dst: %q", dstDDL) @@ -93,66 +98,86 @@ func DiffDDL(out io.Writer, dialect string, srcDDL string, dstDDL string) error case myddl.Dialect: leftDDL, err := myddl.NewParser(myddl.NewLexer(srcDDL)).Parse() if err != nil { - return errorz.Errorf("myddl.NewParser: %w", err) + return apperr.Errorf("myddl.NewParser: %w", err) } rightDDL, err := myddl.NewParser(myddl.NewLexer(dstDDL)).Parse() if err != nil { - return errorz.Errorf("myddl.NewParser: %w", err) + return apperr.Errorf("myddl.NewParser: %w", err) } result, err := myddl.Diff(leftDDL, rightDDL) if err != nil { - return errorz.Errorf("myddl.Diff: %w", err) + return apperr.Errorf("myddl.Diff: %w", err) } if _, err := io.WriteString(out, result.String()); err != nil { - return errorz.Errorf("io.WriteString: %w", err) + return apperr.Errorf("io.WriteString: %w", err) } return nil case pgddl.Dialect: leftDDL, err := pgddl.NewParser(pgddl.NewLexer(srcDDL)).Parse() if err != nil { - return errorz.Errorf("pgddl.NewParser: %w", err) + return apperr.Errorf("pgddl.NewParser: %w", err) } rightDDL, err := pgddl.NewParser(pgddl.NewLexer(dstDDL)).Parse() if err != nil { - return errorz.Errorf("pgddl.NewParser: %w", err) + return apperr.Errorf("pgddl.NewParser: %w", err) } result, err := pgddl.Diff(leftDDL, rightDDL) if err != nil { - return errorz.Errorf("pgddl.Diff: %w", err) + return apperr.Errorf("pgddl.Diff: %w", err) } if _, err := io.WriteString(out, result.String()); err != nil { - return errorz.Errorf("io.WriteString: %w", err) + return apperr.Errorf("io.WriteString: %w", err) } return nil case crdbddl.Dialect: leftDDL, err := crdbddl.NewParser(crdbddl.NewLexer(srcDDL)).Parse() if err != nil { - return errorz.Errorf("pgddl.NewParser: %w", err) + return apperr.Errorf("pgddl.NewParser: %w", err) } rightDDL, err := crdbddl.NewParser(crdbddl.NewLexer(dstDDL)).Parse() if err != nil { - return errorz.Errorf("pgddl.NewParser: %w", err) + return apperr.Errorf("pgddl.NewParser: %w", err) } result, err := crdbddl.Diff(leftDDL, rightDDL) if err != nil { - return errorz.Errorf("pgddl.Diff: %w", err) + return apperr.Errorf("pgddl.Diff: %w", err) + } + + if _, err := io.WriteString(out, result.String()); err != nil { + return apperr.Errorf("io.WriteString: %w", err) + } + + return nil + case spanddl.Dialect: + leftDDL, err := spanddl.NewParser(spanddl.NewLexer(srcDDL)).Parse() + if err != nil { + return apperr.Errorf("spanddl.NewParser: %w", err) + } + rightDDL, err := spanddl.NewParser(spanddl.NewLexer(dstDDL)).Parse() + if err != nil { + return apperr.Errorf("spanddl.NewParser: %w", err) + } + + result, err := spanddl.Diff(leftDDL, rightDDL) + if err != nil { + return apperr.Errorf("spanddl.Diff: %w", err) } if _, err := io.WriteString(out, result.String()); err != nil { - return errorz.Errorf("io.WriteString: %w", err) + return apperr.Errorf("io.WriteString: %w", err) } return nil case "": - return errorz.Errorf("dialect=%s: %w", dialect, apperr.ErrDialectIsEmpty) + return apperr.Errorf("dialect=%s: %w", dialect, apperr.ErrDialectIsEmpty) default: - return errorz.Errorf("dialect=%s: %w", dialect, apperr.ErrNotSupported) + return apperr.Errorf("dialect=%s: %w", dialect, apperr.ErrNotSupported) } } diff --git a/pkg/ddlctl/ddlctl_generate.go b/pkg/ddlctl/ddlctl_generate.go index 009d56c..cae7c53 100644 --- a/pkg/ddlctl/ddlctl_generate.go +++ b/pkg/ddlctl/ddlctl_generate.go @@ -6,10 +6,8 @@ import ( "os" "path/filepath" - errorz "github.com/kunitsucom/util.go/errors" - + apperr "github.com/kunitsucom/ddlctl/pkg/apperr" crdbddl "github.com/kunitsucom/ddlctl/pkg/ddl/cockroachdb" - apperr "github.com/kunitsucom/ddlctl/pkg/errors" "github.com/kunitsucom/ddlctl/pkg/internal/config" "github.com/kunitsucom/ddlctl/pkg/internal/generator" "github.com/kunitsucom/ddlctl/pkg/internal/generator/dialect/mysql" @@ -21,7 +19,7 @@ import ( func Generate(ctx context.Context, _ []string) error { if _, err := config.Load(ctx); err != nil { - return errorz.Errorf("config.Load: %w", err) + return apperr.Errorf("config.Load: %w", err) } src := config.Source() @@ -32,7 +30,7 @@ func Generate(ctx context.Context, _ []string) error { ddl, err := Parse(ctx, language, src) if err != nil { - return errorz.Errorf("parse: %w", err) + return apperr.Errorf("parse: %w", err) } if info, err := os.Stat(config.Destination()); err == nil && info.IsDir() { @@ -42,7 +40,7 @@ func Generate(ctx context.Context, _ []string) error { f, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644) if err != nil { - return errorz.Errorf("os.OpenFile: %w", err) + return apperr.Errorf("os.OpenFile: %w", err) } if err := Fprint( @@ -54,7 +52,7 @@ func Generate(ctx context.Context, _ []string) error { Stmts: []generator.Stmt{stmt}, }, ); err != nil { - return errorz.Errorf("fprint: %w", err) + return apperr.Errorf("fprint: %w", err) } } return nil @@ -65,11 +63,11 @@ func Generate(ctx context.Context, _ []string) error { f, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644) if err != nil { - return errorz.Errorf("os.OpenFile: %w", err) + return apperr.Errorf("os.OpenFile: %w", err) } if err := Fprint(f, config.Dialect(), ddl); err != nil { - return errorz.Errorf("fprint: %w", err) + return apperr.Errorf("fprint: %w", err) } return nil } @@ -79,11 +77,11 @@ func Parse(ctx context.Context, language string, src string) (*generator.DDL, er case ddlctlgo.Language: ddl, err := ddlctlgo.Parse(ctx, src) if err != nil { - return nil, errorz.Errorf("ddlgengo.Parse: %w", err) + return nil, apperr.Errorf("ddlgengo.Parse: %w", err) } return ddl, nil default: - return nil, errorz.Errorf("language=%s: %w", language, apperr.ErrNotSupported) + return nil, apperr.Errorf("language=%s: %w", language, apperr.ErrNotSupported) } } @@ -91,22 +89,22 @@ func Fprint(w io.Writer, dialect string, ddl *generator.DDL) error { switch dialect { case spanner.Dialect: if err := spanner.Fprint(w, ddl); err != nil { - return errorz.Errorf("spanner.Fprint: %w", err) + return apperr.Errorf("spanner.Fprint: %w", err) } return nil case postgres.Dialect, crdbddl.Dialect: if err := postgres.Fprint(w, ddl); err != nil { - return errorz.Errorf("postgres.Fprint: %w", err) + return apperr.Errorf("postgres.Fprint: %w", err) } return nil case mysql.Dialect: if err := mysql.Fprint(w, ddl); err != nil { - return errorz.Errorf("mysql.Fprint: %w", err) + return apperr.Errorf("mysql.Fprint: %w", err) } return nil case "": - return errorz.Errorf("dialect=%s: %w", dialect, apperr.ErrDialectIsEmpty) + return apperr.Errorf("dialect=%s: %w", dialect, apperr.ErrDialectIsEmpty) default: - return errorz.Errorf("dialect=%s: %w", dialect, apperr.ErrNotSupported) + return apperr.Errorf("dialect=%s: %w", dialect, apperr.ErrNotSupported) } } diff --git a/pkg/ddlctl/ddlctl_show.go b/pkg/ddlctl/ddlctl_show.go index dc61c3c..d63ee6e 100644 --- a/pkg/ddlctl/ddlctl_show.go +++ b/pkg/ddlctl/ddlctl_show.go @@ -6,12 +6,12 @@ import ( "os" sqlz "github.com/kunitsucom/util.go/database/sql" - errorz "github.com/kunitsucom/util.go/errors" + apperr "github.com/kunitsucom/ddlctl/pkg/apperr" crdbddl "github.com/kunitsucom/ddlctl/pkg/ddl/cockroachdb" myddl "github.com/kunitsucom/ddlctl/pkg/ddl/mysql" pgddl "github.com/kunitsucom/ddlctl/pkg/ddl/postgres" - apperr "github.com/kunitsucom/ddlctl/pkg/errors" + spanddl "github.com/kunitsucom/ddlctl/pkg/ddl/spanner" "github.com/kunitsucom/ddlctl/pkg/internal/config" crdbshow "github.com/kunitsucom/ddlctl/pkg/show/cockroachdb" myshow "github.com/kunitsucom/ddlctl/pkg/show/mysql" @@ -21,16 +21,16 @@ import ( func Show(ctx context.Context, args []string) error { if _, err := config.Load(ctx); err != nil { - return errorz.Errorf("config.Load: %w", err) + return apperr.Errorf("config.Load: %w", err) } ddl, err := ShowDDL(ctx, config.Dialect(), args[0]) if err != nil { - return errorz.Errorf("diff: %w", err) + return apperr.Errorf("diff: %w", err) } if _, err := io.WriteString(os.Stdout, ddl); err != nil { - return errorz.Errorf("io.WriteString: %w", err) + return apperr.Errorf("io.WriteString: %w", err) } return nil @@ -49,11 +49,11 @@ func ShowDDL(ctx context.Context, dialect string, dsn string) (ddl string, err e db, err := sqlz.OpenContext(ctx, driverName, dsn) if err != nil { - return "", errorz.Errorf("sqlz.OpenContext: %w", err) + return "", apperr.Errorf("sqlz.OpenContext: %w", err) } defer func() { if cerr := db.Close(); err == nil && cerr != nil { - err = errorz.Errorf("db.Close: %w", cerr) + err = apperr.Errorf("db.Close: %w", cerr) } }() @@ -61,28 +61,28 @@ func ShowDDL(ctx context.Context, dialect string, dsn string) (ddl string, err e case myddl.Dialect: ddl, err := myshow.ShowCreateAllTables(ctx, db) if err != nil { - return "", errorz.Errorf("pgutil.ShowCreateAllTables: %w", err) + return "", apperr.Errorf("pgutil.ShowCreateAllTables: %w", err) } return ddl, nil case pgddl.Dialect: ddl, err := pgshow.ShowCreateAllTables(ctx, db) if err != nil { - return "", errorz.Errorf("pgutil.ShowCreateAllTables: %w", err) + return "", apperr.Errorf("pgutil.ShowCreateAllTables: %w", err) } return ddl, nil case crdbddl.Dialect: ddl, err := crdbshow.ShowCreateAllTables(ctx, db) if err != nil { - return "", errorz.Errorf("crdbutil.ShowCreateAllTables: %w", err) + return "", apperr.Errorf("crdbutil.ShowCreateAllTables: %w", err) } return ddl, nil - case _spanner: + case spanddl.Dialect: ddl, err := spanshow.ShowCreateAllTables(ctx, db) if err != nil { - return "", errorz.Errorf("spanshow.ShowCreateAllTables: %w", err) + return "", apperr.Errorf("spanshow.ShowCreateAllTables: %w", err) } return ddl, nil default: - return "", errorz.Errorf("dialect=%s: %w", dialect, apperr.ErrNotSupported) + return "", apperr.Errorf("dialect=%s: %w", dialect, apperr.ErrNotSupported) } } diff --git a/pkg/internal/config/config.go b/pkg/internal/config/config.go index 940646c..4dd1d05 100644 --- a/pkg/internal/config/config.go +++ b/pkg/internal/config/config.go @@ -8,6 +8,7 @@ import ( errorz "github.com/kunitsucom/util.go/errors" cliz "github.com/kunitsucom/util.go/exp/cli" + apperr "github.com/kunitsucom/ddlctl/pkg/apperr" "github.com/kunitsucom/ddlctl/pkg/internal/logs" ) @@ -38,7 +39,7 @@ var ( func MustLoad(ctx context.Context) (rollback func()) { rollback, err := Load(ctx) if err != nil { - err = errorz.Errorf("Load: %w", err) + err = apperr.Errorf("Load: %w", err) panic(err) } return rollback @@ -51,7 +52,7 @@ func Load(ctx context.Context) (rollback func(), err error) { cfg, err := load(ctx) if err != nil { - return nil, errorz.Errorf("load: %w", err) + return nil, apperr.Errorf("load: %w", err) } globalConfig = cfg @@ -71,7 +72,7 @@ func Load(ctx context.Context) (rollback func(), err error) { func load(ctx context.Context) (cfg *config, err error) { //nolint:unparam cmd, err := cliz.FromContext(ctx) if err != nil { - return nil, errorz.Errorf("cliz.FromContext: %w", err) + return nil, apperr.Errorf("cliz.FromContext: %w", err) } c := &config{ @@ -87,14 +88,16 @@ func load(ctx context.Context) (cfg *config, err error) { //nolint:unparam PKTagGo: loadPKTagGo(ctx, cmd), } - if c.Debug { - logs.Debug = logs.NewDebug() - logs.Debug.Print("debug mode enabled") - } - if c.Trace { + switch { + case c.Trace: + apperr.Errorf = errorz.Errorf //nolint:reassign logs.Trace = logs.NewTrace() logs.Debug = logs.NewDebug() logs.Trace.Print("trace mode enabled") + case c.Debug: + apperr.Errorf = errorz.Errorf //nolint:reassign + logs.Debug = logs.NewDebug() + logs.Debug.Print("debug mode enabled") } if err := json.NewEncoder(logs.Debug).Encode(c); err != nil { diff --git a/pkg/internal/generator/dialect/mysql/mysql.go b/pkg/internal/generator/dialect/mysql/mysql.go index 55caec9..d112f8a 100644 --- a/pkg/internal/generator/dialect/mysql/mysql.go +++ b/pkg/internal/generator/dialect/mysql/mysql.go @@ -3,9 +3,7 @@ package mysql import ( "io" - errorz "github.com/kunitsucom/util.go/errors" - - "github.com/kunitsucom/ddlctl/pkg/errors" + apperr "github.com/kunitsucom/ddlctl/pkg/apperr" ddlast "github.com/kunitsucom/ddlctl/pkg/internal/generator" "github.com/kunitsucom/ddlctl/pkg/internal/logs" ) @@ -31,13 +29,13 @@ func Fprint(w io.Writer, ddl *ddlast.DDL) error { case *ddlast.CreateIndexStmt: fprintCreateIndex(&buf, ddl.Indent, stmt) default: - logs.Warn.Printf("unknown statement type: %T: %v", stmt, errors.ErrNotSupported) + logs.Warn.Printf("unknown statement type: %T: %v", stmt, apperr.ErrNotSupported) continue } } if _, err := io.WriteString(w, buf); err != nil { - return errorz.Errorf("io.WriteString: %w", err) + return apperr.Errorf("io.WriteString: %w", err) } return nil } diff --git a/pkg/internal/generator/dialect/postgres/postgres.go b/pkg/internal/generator/dialect/postgres/postgres.go index 5de69f6..7126397 100644 --- a/pkg/internal/generator/dialect/postgres/postgres.go +++ b/pkg/internal/generator/dialect/postgres/postgres.go @@ -3,9 +3,7 @@ package postgres import ( "io" - errorz "github.com/kunitsucom/util.go/errors" - - "github.com/kunitsucom/ddlctl/pkg/errors" + apperr "github.com/kunitsucom/ddlctl/pkg/apperr" ddlast "github.com/kunitsucom/ddlctl/pkg/internal/generator" "github.com/kunitsucom/ddlctl/pkg/internal/logs" ) @@ -31,13 +29,13 @@ func Fprint(w io.Writer, ddl *ddlast.DDL) error { case *ddlast.CreateIndexStmt: fprintCreateIndex(&buf, ddl.Indent, stmt) default: - logs.Warn.Printf("unknown statement type: %T: %v", stmt, errors.ErrNotSupported) + logs.Warn.Printf("unknown statement type: %T: %v", stmt, apperr.ErrNotSupported) continue } } if _, err := io.WriteString(w, buf); err != nil { - return errorz.Errorf("io.WriteString: %w", err) + return apperr.Errorf("io.WriteString: %w", err) } return nil } diff --git a/pkg/internal/generator/dialect/spanner/spanner.go b/pkg/internal/generator/dialect/spanner/spanner.go index 715b3bb..b92a250 100644 --- a/pkg/internal/generator/dialect/spanner/spanner.go +++ b/pkg/internal/generator/dialect/spanner/spanner.go @@ -3,9 +3,7 @@ package spanner import ( "io" - errorz "github.com/kunitsucom/util.go/errors" - - "github.com/kunitsucom/ddlctl/pkg/errors" + apperr "github.com/kunitsucom/ddlctl/pkg/apperr" ddlast "github.com/kunitsucom/ddlctl/pkg/internal/generator" "github.com/kunitsucom/ddlctl/pkg/internal/logs" ) @@ -31,13 +29,13 @@ func Fprint(w io.Writer, ddl *ddlast.DDL) error { case *ddlast.CreateIndexStmt: fprintCreateIndex(&buf, ddl.Indent, stmt) default: - logs.Warn.Printf("unknown statement type: %T: %v", stmt, errors.ErrNotSupported) + logs.Warn.Printf("unknown statement type: %T: %v", stmt, apperr.ErrNotSupported) continue } } if _, err := io.WriteString(w, buf); err != nil { - return errorz.Errorf("io.WriteString: %w", err) + return apperr.Errorf("io.WriteString: %w", err) } return nil } diff --git a/pkg/internal/lang/go/extract_source.go b/pkg/internal/lang/go/extract_source.go index 53f380f..d35cac1 100644 --- a/pkg/internal/lang/go/extract_source.go +++ b/pkg/internal/lang/go/extract_source.go @@ -8,10 +8,9 @@ import ( "regexp" "sync" - errorz "github.com/kunitsucom/util.go/errors" filepathz "github.com/kunitsucom/util.go/path/filepath" - apperr "github.com/kunitsucom/ddlctl/pkg/errors" + apperr "github.com/kunitsucom/ddlctl/pkg/apperr" "github.com/kunitsucom/ddlctl/pkg/internal/config" "github.com/kunitsucom/ddlctl/pkg/internal/logs" ) @@ -84,7 +83,7 @@ func extractDDLSourceFromDDLTagGo(_ context.Context, fset *token.FileSet, f *goa } if len(ddlSrc) == 0 { - return nil, errorz.Errorf("ddl-tag-go=%s: %w", config.DDLTagGo(), apperr.ErrDDLTagGoAnnotationNotFoundInSource) + return nil, apperr.Errorf("ddl-tag-go=%s: %w", config.DDLTagGo(), apperr.ErrDDLTagGoAnnotationNotFoundInSource) } return ddlSrc, nil diff --git a/pkg/internal/lang/go/parse.go b/pkg/internal/lang/go/parse.go index 493e499..5b329a7 100644 --- a/pkg/internal/lang/go/parse.go +++ b/pkg/internal/lang/go/parse.go @@ -14,11 +14,10 @@ import ( "strings" "unicode" - errorz "github.com/kunitsucom/util.go/errors" filepathz "github.com/kunitsucom/util.go/path/filepath" slicez "github.com/kunitsucom/util.go/slices" - apperr "github.com/kunitsucom/ddlctl/pkg/errors" + apperr "github.com/kunitsucom/ddlctl/pkg/apperr" "github.com/kunitsucom/ddlctl/pkg/internal/config" ddlast "github.com/kunitsucom/ddlctl/pkg/internal/generator" langutil "github.com/kunitsucom/ddlctl/pkg/internal/lang/util" @@ -38,14 +37,14 @@ func Parse(ctx context.Context, src string) (*ddlast.DDL, error) { info, err := os.Stat(sourceAbs) if err != nil { - return nil, errorz.Errorf("os.Stat: %w", err) + return nil, apperr.Errorf("os.Stat: %w", err) } ddl := ddlast.NewDDL(ctx) if info.IsDir() { if err := filepath.WalkDir(sourceAbs, walkDirFn(ctx, ddl)); err != nil { - return nil, errorz.Errorf("filepath.WalkDir: %w", err) + return nil, apperr.Errorf("filepath.WalkDir: %w", err) } return ddl, nil @@ -53,7 +52,7 @@ func Parse(ctx context.Context, src string) (*ddlast.DDL, error) { stmts, err := parseFile(ctx, sourceAbs) if err != nil { - return nil, errorz.Errorf("Parse: %w", err) + return nil, apperr.Errorf("Parse: %w", err) } ddl.Stmts = append(ddl.Stmts, stmts...) @@ -79,7 +78,7 @@ func walkDirFn(ctx context.Context, ddl *ddlast.DDL) func(path string, d os.DirE logs.Debug.Printf("parseFile: %s: %v", path, err) return nil } - return errorz.Errorf("parseFile: %w", err) + return apperr.Errorf("parseFile: %w", err) } ddl.Stmts = append(ddl.Stmts, stmts...) @@ -93,12 +92,12 @@ func parseFile(ctx context.Context, filename string) ([]ddlast.Stmt, error) { fset := token.NewFileSet() f, err := parser.ParseFile(fset, filename, nil, parser.ParseComments) if err != nil { - return nil, errorz.Errorf("parser.ParseFile: %w", err) + return nil, apperr.Errorf("parser.ParseFile: %w", err) } ddlSrc, err := extractDDLSourceFromDDLTagGo(ctx, fset, f) if err != nil { - return nil, errorz.Errorf("extractDDLSourceFromDDLTagGo: %w", err) + return nil, apperr.Errorf("extractDDLSourceFromDDLTagGo: %w", err) } dumpDDLSource(fset, ddlSrc) diff --git a/pkg/internal/lang/go/parse_test.go b/pkg/internal/lang/go/parse_test.go index ed285d0..7c174c9 100644 --- a/pkg/internal/lang/go/parse_test.go +++ b/pkg/internal/lang/go/parse_test.go @@ -12,7 +12,7 @@ import ( "github.com/kunitsucom/util.go/testing/assert" "github.com/kunitsucom/util.go/testing/require" - apperr "github.com/kunitsucom/ddlctl/pkg/errors" + "github.com/kunitsucom/ddlctl/pkg/apperr" "github.com/kunitsucom/ddlctl/pkg/internal/config" "github.com/kunitsucom/ddlctl/pkg/internal/fixture" ddlast "github.com/kunitsucom/ddlctl/pkg/internal/generator" diff --git a/pkg/show/cockroachdb/show_create_all_tables.go b/pkg/show/cockroachdb/show_create_all_tables.go index 80a398f..1a9919c 100644 --- a/pkg/show/cockroachdb/show_create_all_tables.go +++ b/pkg/show/cockroachdb/show_create_all_tables.go @@ -5,7 +5,8 @@ import ( "database/sql" sqlz "github.com/kunitsucom/util.go/database/sql" - errorz "github.com/kunitsucom/util.go/errors" + + apperr "github.com/kunitsucom/ddlctl/pkg/apperr" ) type sqlQueryerContext = interface { @@ -28,7 +29,7 @@ func ShowCreateAllTables(ctx context.Context, db sqlQueryerContext) (query strin createTableStmts := new([]*CreateStatement) if err := dbz.QueryContext(ctx, createTableStmts, queryShowCreateAllTables); err != nil { - return "", errorz.Errorf("dbz.QueryContext: %w", err) + return "", apperr.Errorf("dbz.QueryContext: %w", err) } for _, stmt := range *createTableStmts { query += stmt.CreateStatement + "\n" diff --git a/pkg/show/mysql/show_create_all_tables.go b/pkg/show/mysql/show_create_all_tables.go index 1f797c8..09b8acd 100644 --- a/pkg/show/mysql/show_create_all_tables.go +++ b/pkg/show/mysql/show_create_all_tables.go @@ -6,7 +6,8 @@ import ( "fmt" sqlz "github.com/kunitsucom/util.go/database/sql" - errorz "github.com/kunitsucom/util.go/errors" + + apperr "github.com/kunitsucom/ddlctl/pkg/apperr" ) type sqlQueryerContext = interface { @@ -53,7 +54,7 @@ func ShowCreateAllTables(ctx context.Context, db sqlQueryerContext, opts ...Show tableNames := new([]*TableName) tableNamesQuery := fmt.Sprintf("SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_SCHEMA = %s", databaseQuoted) if err := dbz.QueryContext(ctx, tableNames, tableNamesQuery); err != nil { - return "", errorz.Errorf("dbz.QueryContext: q=%s: %w", tableNamesQuery, err) + return "", apperr.Errorf("dbz.QueryContext: q=%s: %w", tableNamesQuery, err) } // type CreateStatement struct { @@ -67,7 +68,7 @@ func ShowCreateAllTables(ctx context.Context, db sqlQueryerContext, opts ...Show showCreateTable := new(ShowCreateTable) showCreateTableQuery := fmt.Sprintf("SHOW CREATE TABLE `%s`", tn.TableName) if err := dbz.QueryContext(ctx, showCreateTable, showCreateTableQuery); err != nil { - return "", errorz.Errorf("dbz.QueryContext: q=%s: %w", showCreateTableQuery, err) + return "", apperr.Errorf("dbz.QueryContext: q=%s: %w", showCreateTableQuery, err) } query += showCreateTable.CreateStatement + ";\n" @@ -75,7 +76,7 @@ func ShowCreateAllTables(ctx context.Context, db sqlQueryerContext, opts ...Show // showCreateIndex := fmt.Sprintf("SELECT CONCAT('CREATE INDEX ', INDEX_NAME, ' ON ', TABLE_NAME, ' (', GROUP_CONCAT(COLUMN_NAME ORDER BY SEQ_IN_INDEX), ');') AS 'create_statement' FROM information_schema.STATISTICS WHERE TABLE_SCHEMA = %s AND INDEX_NAME IS NOT NULL AND INDEX_NAME != 'PRIMARY' AND TABLE_NAME = '%s' GROUP BY INDEX_NAME, TABLE_NAME;", databaseQuoted, tn.TableName) // createStatements := new([]*CreateStatement) // if err := dbz.QueryContext(ctx, createStatements, showCreateIndex); err != nil { - // return "", errorz.Errorf("dbz.QueryContext: q=%s: %w", showCreateIndex, err) + // return "", apperr.Errorf("dbz.QueryContext: q=%s: %w", showCreateIndex, err) // } // for _, createStatement := range *createStatements { // query += createStatement.CreateStatement + "\n" diff --git a/pkg/show/postgres/show_create_all_tables.go b/pkg/show/postgres/show_create_all_tables.go index 8c6e202..f24510d 100644 --- a/pkg/show/postgres/show_create_all_tables.go +++ b/pkg/show/postgres/show_create_all_tables.go @@ -6,7 +6,8 @@ import ( "fmt" sqlz "github.com/kunitsucom/util.go/database/sql" - errorz "github.com/kunitsucom/util.go/errors" + + apperr "github.com/kunitsucom/ddlctl/pkg/apperr" ) type sqlQueryerContext = interface { @@ -126,7 +127,7 @@ func ShowCreateAllTables(ctx context.Context, db sqlQueryerContext, opts ...Show createTableStmts := new([]*CreateStatement) if err := dbz.QueryContext(ctx, createTableStmts, fmt.Sprintf(formatShowCreateAllTables, cfg.schema)); err != nil { - return "", errorz.Errorf("dbz.QueryContext: %w", err) + return "", apperr.Errorf("dbz.QueryContext: %w", err) } for _, stmt := range *createTableStmts { query += stmt.CreateStatement + "\n" @@ -134,7 +135,7 @@ func ShowCreateAllTables(ctx context.Context, db sqlQueryerContext, opts ...Show createIndexStmts := new([]*CreateStatement) if err := dbz.QueryContext(ctx, createIndexStmts, fmt.Sprintf(formatShowCreateAllIndexes, cfg.schema, cfg.schema)); err != nil { - return "", errorz.Errorf("dbz.QueryContext: %w", err) + return "", apperr.Errorf("dbz.QueryContext: %w", err) } for _, stmt := range *createIndexStmts { query += stmt.CreateStatement + ";\n" diff --git a/pkg/show/spanner/show_create_all_tables.go b/pkg/show/spanner/show_create_all_tables.go index 21bd7ab..204c225 100644 --- a/pkg/show/spanner/show_create_all_tables.go +++ b/pkg/show/spanner/show_create_all_tables.go @@ -6,11 +6,13 @@ import ( "fmt" sqlz "github.com/kunitsucom/util.go/database/sql" - errorz "github.com/kunitsucom/util.go/errors" + apperr "github.com/kunitsucom/ddlctl/pkg/apperr" "github.com/kunitsucom/ddlctl/pkg/internal/logs" ) +// NOTE: https://cloud.google.com/spanner/docs/information-schema?hl=ja + type sqlQueryerContext = interface { QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) } @@ -38,7 +40,7 @@ type informationSchemaColumn struct { func (c *informationSchemaColumn) String() string { d := fmt.Sprintf("%s %s", c.ColumnName, c.SpannerType) if c.ColumnDefault != nil { - d += fmt.Sprintf(" DEFAULT %s", *c.ColumnDefault) + d += fmt.Sprintf(" DEFAULT (%s)", *c.ColumnDefault) } if c.IsNullable == "NO" { d += " NOT NULL" @@ -47,7 +49,40 @@ func (c *informationSchemaColumn) String() string { } const ( - queryShowPrimaryKey = `SELECT i.INDEX_NAME, i.INDEX_TYPE, ic.COLUMN_NAME, ic.COLUMN_ORDERING, ic.ORDINAL_POSITION FROM INFORMATION_SCHEMA.INDEXES AS i INNER JOIN INFORMATION_SCHEMA.INDEX_COLUMNS AS ic ON i.TABLE_NAME = ic.TABLE_NAME WHERE i.TABLE_NAME = ? AND i.INDEX_TYPE = "PRIMARY_KEY" ORDER BY i.TABLE_NAME, ic.ORDINAL_POSITION;` + queryShowTableColumnOptions = `SELECT COLUMN_NAME, OPTION_NAME, OPTION_VALUE FROM INFORMATION_SCHEMA.COLUMN_OPTIONS WHERE TABLE_NAME = ?;` +) + +type informationSchemaColumnOption struct { + ColumnName string `db:"COLUMN_NAME"` + OptionName string `db:"OPTION_NAME"` + OptionValue string `db:"OPTION_VALUE"` +} + +func (c *informationSchemaColumnOption) String() string { + return fmt.Sprintf("%s = %s", c.OptionName, c.OptionValue) +} + +const ( + queryShowPrimaryKey = `-- SHOW TABLES +SELECT + i.INDEX_NAME, + i.INDEX_TYPE, + ic.COLUMN_NAME, + ic.COLUMN_ORDERING, + ic.ORDINAL_POSITION +FROM + INFORMATION_SCHEMA.INDEXES AS i +INNER JOIN + INFORMATION_SCHEMA.INDEX_COLUMNS AS ic +ON + i.TABLE_NAME = ic.TABLE_NAME +WHERE + i.TABLE_NAME = ? + AND i.INDEX_TYPE = "PRIMARY_KEY" +ORDER BY + i.TABLE_NAME, ic.ORDINAL_POSITION +; +` ) type informationSchemaPrimaryKey struct { @@ -70,7 +105,27 @@ type informationSchemaIndexName struct { } const ( - queryShowIndexes = `SELECT i.INDEX_NAME, i.INDEX_TYPE, ic.COLUMN_NAME, ic.COLUMN_ORDERING, ic.ORDINAL_POSITION FROM INFORMATION_SCHEMA.INDEXES AS i INNER JOIN INFORMATION_SCHEMA.INDEX_COLUMNS AS ic ON i.TABLE_NAME = ic.TABLE_NAME WHERE i.TABLE_NAME = ? AND i.INDEX_TYPE != "PRIMARY_KEY" ORDER BY i.TABLE_NAME, i.INDEX_NAME, ic.ORDINAL_POSITION;` + queryShowIndexes = `-- SHOW INDEXES +SELECT + ic.INDEX_NAME, + i.INDEX_TYPE, + ic.COLUMN_NAME, + ic.COLUMN_ORDERING, + ic.ORDINAL_POSITION +FROM + INFORMATION_SCHEMA.INDEXES AS i +INNER JOIN + INFORMATION_SCHEMA.INDEX_COLUMNS AS ic +ON + i.TABLE_NAME = ic.TABLE_NAME +WHERE + i.TABLE_NAME = ? + AND i.INDEX_TYPE != "PRIMARY_KEY" + AND ic.INDEX_NAME != "PRIMARY_KEY" +ORDER BY + i.TABLE_NAME, ic.INDEX_NAME, ic.ORDINAL_POSITION +; +` ) type informationSchemaIndex struct { @@ -114,7 +169,7 @@ func ShowCreateAllTables(ctx context.Context, db sqlQueryerContext, opts ...Show tables := make([]*informationSchemaTable, 0) if err := dbz.QueryContext(ctx, &tables, querySelectTableName); err != nil { - return "", errorz.Errorf("dbz.QueryContext: %w", err) + return "", apperr.Errorf("dbz.QueryContext: %w", err) } tablesLastIndex := len(tables) - 1 @@ -124,13 +179,36 @@ func ShowCreateAllTables(ctx context.Context, db sqlQueryerContext, opts ...Show columns := make([]*informationSchemaColumn, 0) if err := dbz.QueryContext(ctx, &columns, queryShowCreateAllTables, tbl.TableName); err != nil { - return "", errorz.Errorf("dbz.QueryContext: %w", err) + return "", apperr.Errorf("dbz.QueryContext: %w", err) + } + + allColumnOptions := make([]*informationSchemaColumnOption, 0) + if err := dbz.QueryContext(ctx, &allColumnOptions, queryShowTableColumnOptions, tbl.TableName); err != nil { + return "", apperr.Errorf("dbz.QueryContext: %w", err) } columnsLastIndex := len(columns) - 1 - for i, col := range columns { + for colIdx, col := range columns { d += fmt.Sprintf(" %s", col) - if i != columnsLastIndex { + if len(allColumnOptions) > 0 { + columnOptions := make([]*informationSchemaColumnOption, 0) + for _, opt := range allColumnOptions { + if col.ColumnName == opt.ColumnName { + columnOptions = append(columnOptions, opt) + } + } + if len(columnOptions) > 0 { + d += " OPTIONS (" + for columnOptionsIdx, opt := range columnOptions { + d += opt.String() + if columnOptionsLastIndex := len(columnOptions) - 1; columnOptionsIdx != columnOptionsLastIndex { + d += ", " + } + } + d += ")" + } + } + if colIdx != columnsLastIndex { d += "," } d += "\n" @@ -141,7 +219,7 @@ func ShowCreateAllTables(ctx context.Context, db sqlQueryerContext, opts ...Show primaryKeyColumns := make([]*informationSchemaPrimaryKey, 0) if err := dbz.QueryContext(ctx, &primaryKeyColumns, queryShowPrimaryKey, tbl.TableName); err != nil { - return "", errorz.Errorf("dbz.QueryContext: %w", err) + return "", apperr.Errorf("dbz.QueryContext: %w", err) } if len(primaryKeyColumns) > 0 { @@ -162,13 +240,13 @@ func ShowCreateAllTables(ctx context.Context, db sqlQueryerContext, opts ...Show // INDEX indexNames := make([]*informationSchemaIndexName, 0) if err := dbz.QueryContext(ctx, &indexNames, querySelectIndexes, tbl.TableName); err != nil { - return "", errorz.Errorf("dbz.QueryContext: %w", err) + return "", apperr.Errorf("dbz.QueryContext: %w", err) } for _, indexName := range indexNames { indexes := make([]*informationSchemaIndex, 0) if err := dbz.QueryContext(ctx, &indexes, queryShowIndexes, tbl.TableName); err != nil { - return "", errorz.Errorf("dbz.QueryContext: %w", err) + return "", apperr.Errorf("dbz.QueryContext: %w", err) } d := "CREATE "