diff --git a/drainer/translator/mysql.go b/drainer/translator/mysql.go index 4fa7ceff5..387fe48f7 100644 --- a/drainer/translator/mysql.go +++ b/drainer/translator/mysql.go @@ -31,7 +31,7 @@ import ( const implicitColID = -1 -func genDBInsert(schema string, ptable, table *model.TableInfo, row []byte) (names []string, args []interface{}, err error) { +func genDBInsert(schema string, ptable, table *model.TableInfo, row []byte, destDBType loader.DBType) (names []string, args []interface{}, err error) { columns := writableColumns(table) columnValues, err := insertRowToDatums(table, row) @@ -46,7 +46,7 @@ func genDBInsert(schema string, ptable, table *model.TableInfo, row []byte) (nam val = getDefaultOrZeroValue(ptable, col) } - value, err := formatData(val, col.FieldType) + value, err := formatData(val, col.FieldType, destDBType) if err != nil { return nil, nil, errors.Trace(err) } @@ -58,7 +58,7 @@ func genDBInsert(schema string, ptable, table *model.TableInfo, row []byte) (nam return names, args, nil } -func genDBUpdate(schema string, ptable, table *model.TableInfo, row []byte, canAppendDefaultValue bool) (names []string, values []interface{}, oldValues []interface{}, err error) { +func genDBUpdate(schema string, ptable, table *model.TableInfo, row []byte, canAppendDefaultValue bool, destDBType loader.DBType) (names []string, values []interface{}, oldValues []interface{}, err error) { columns := writableColumns(table) updtDecoder := newUpdateDecoder(ptable, table, canAppendDefaultValue) @@ -69,12 +69,12 @@ func genDBUpdate(schema string, ptable, table *model.TableInfo, row []byte, canA return nil, nil, nil, errors.Annotatef(err, "table `%s`.`%s`", schema, table.Name) } - _, oldValues, err = generateColumnAndValue(columns, oldColumnValues) + _, oldValues, err = generateColumnAndValue(columns, oldColumnValues, destDBType) if err != nil { return nil, nil, nil, errors.Trace(err) } - updateColumns, values, err = generateColumnAndValue(columns, newColumnValues) + updateColumns, values, err = generateColumnAndValue(columns, newColumnValues, destDBType) if err != nil { return nil, nil, nil, errors.Trace(err) } @@ -84,7 +84,7 @@ func genDBUpdate(schema string, ptable, table *model.TableInfo, row []byte, canA return } -func genDBDelete(schema string, table *model.TableInfo, row []byte) (names []string, values []interface{}, err error) { +func genDBDelete(schema string, table *model.TableInfo, row []byte, destDBType loader.DBType) (names []string, values []interface{}, err error) { columns := table.Columns colsTypeMap := util.ToColumnTypeMap(columns) @@ -93,7 +93,7 @@ func genDBDelete(schema string, table *model.TableInfo, row []byte) (names []str return nil, nil, errors.Trace(err) } - columns, values, err = generateColumnAndValue(columns, columnValues) + columns, values, err = generateColumnAndValue(columns, columnValues, destDBType) if err != nil { return nil, nil, errors.Trace(err) } @@ -144,33 +144,35 @@ func TiBinlogToTxn(infoGetter TableInfoGetter, schema string, table string, tiBi switch mutType { case tipb.MutationType_Insert: - names, args, err := genDBInsert(schema, pinfo, info, row) + names, args, err := genDBInsert(schema, pinfo, info, row, loader.MysqlDB) if err != nil { return nil, errors.Annotate(err, "gen insert fail") } dml := &loader.DML{ - Tp: loader.InsertDMLType, - Database: schema, - Table: table, - Values: make(map[string]interface{}), + Tp: loader.InsertDMLType, + Database: schema, + Table: table, + Values: make(map[string]interface{}), + DestDBType: loader.MysqlDB, } txn.DMLs = append(txn.DMLs, dml) for i, name := range names { dml.Values[name] = args[i] } case tipb.MutationType_Update: - names, args, oldArgs, err := genDBUpdate(schema, pinfo, info, row, canAppendDefaultValue) + names, args, oldArgs, err := genDBUpdate(schema, pinfo, info, row, canAppendDefaultValue, loader.MysqlDB) if err != nil { return nil, errors.Annotate(err, "gen update fail") } dml := &loader.DML{ - Tp: loader.UpdateDMLType, - Database: schema, - Table: table, - Values: make(map[string]interface{}), - OldValues: make(map[string]interface{}), + Tp: loader.UpdateDMLType, + Database: schema, + Table: table, + Values: make(map[string]interface{}), + OldValues: make(map[string]interface{}), + DestDBType: loader.MysqlDB, } txn.DMLs = append(txn.DMLs, dml) for i, name := range names { @@ -179,16 +181,17 @@ func TiBinlogToTxn(infoGetter TableInfoGetter, schema string, table string, tiBi } case tipb.MutationType_DeleteRow: - names, args, err := genDBDelete(schema, info, row) + names, args, err := genDBDelete(schema, info, row, loader.MysqlDB) if err != nil { return nil, errors.Annotate(err, "gen delete fail") } dml := &loader.DML{ - Tp: loader.DeleteDMLType, - Database: schema, - Table: table, - Values: make(map[string]interface{}), + Tp: loader.DeleteDMLType, + Database: schema, + Table: table, + Values: make(map[string]interface{}), + DestDBType: loader.MysqlDB, } txn.DMLs = append(txn.DMLs, dml) for i, name := range names { @@ -225,7 +228,7 @@ func genColumnNameList(columns []*model.ColumnInfo) (names []string) { return } -func generateColumnAndValue(columns []*model.ColumnInfo, columnValues map[int64]types.Datum) ([]*model.ColumnInfo, []interface{}, error) { +func generateColumnAndValue(columns []*model.ColumnInfo, columnValues map[int64]types.Datum, destDBType loader.DBType) ([]*model.ColumnInfo, []interface{}, error) { var newColumn []*model.ColumnInfo var newColumnsValues []interface{} @@ -233,7 +236,7 @@ func generateColumnAndValue(columns []*model.ColumnInfo, columnValues map[int64] val, ok := columnValues[col.ID] if ok { newColumn = append(newColumn, col) - value, err := formatData(val, col.FieldType) + value, err := formatData(val, col.FieldType, destDBType) if err != nil { return nil, nil, errors.Trace(err) } @@ -245,13 +248,19 @@ func generateColumnAndValue(columns []*model.ColumnInfo, columnValues map[int64] return newColumn, newColumnsValues, nil } -func formatData(data types.Datum, ft types.FieldType) (types.Datum, error) { +func formatData(data types.Datum, ft types.FieldType, destDBType loader.DBType) (types.Datum, error) { if data.GetValue() == nil { return data, nil } switch ft.Tp { - case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeNewDate, mysql.TypeTimestamp, mysql.TypeDuration, mysql.TypeNewDecimal, mysql.TypeJSON: + case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeNewDate, mysql.TypeTimestamp, mysql.TypeNewDecimal, mysql.TypeJSON: + data = types.NewDatum(fmt.Sprintf("%v", data.GetValue())) + case mysql.TypeDuration: + //only for oracle db + if destDBType == loader.OracleDB { + return types.Datum{}, errors.New("unsupported column type[time]") + } data = types.NewDatum(fmt.Sprintf("%v", data.GetValue())) case mysql.TypeEnum: data = types.NewDatum(data.GetMysqlEnum().Value) @@ -264,7 +273,21 @@ func formatData(data types.Datum, ft types.FieldType) (types.Datum, error) { return types.Datum{}, err } data = types.NewUintDatum(val) + case mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeBlob: + //only for oracle db + if destDBType == loader.OracleDB && isBlob(ft) { + data = types.NewBytesDatum(data.GetBytes()) + } } return data, nil } + +func isBlob(ft types.FieldType) bool { + stype := types.TypeToStr(ft.Tp, ft.Charset) + switch stype { + case "blob", "tinyblob", "mediumblob", "longblob": + return true + } + return false +} diff --git a/drainer/translator/oracle.go b/drainer/translator/oracle.go index 7d6625557..cd5a239d0 100644 --- a/drainer/translator/oracle.go +++ b/drainer/translator/oracle.go @@ -66,7 +66,7 @@ func TiBinlogToOracleTxn(infoGetter TableInfoGetter, schema string, table string switch mutType { case tipb.MutationType_Insert: - names, args, err := genDBInsert(schema, pinfo, info, row) + names, args, err := genDBInsert(schema, pinfo, info, row, loader.OracleDB) if err != nil { return nil, errors.Annotate(err, "gen insert fail") } @@ -77,13 +77,14 @@ func TiBinlogToOracleTxn(infoGetter TableInfoGetter, schema string, table string Table: downStreamTable, Values: make(map[string]interface{}), UpColumnsInfoMap: tableIDColumnsMap[mut.GetTableId()], + DestDBType: loader.OracleDB, } txn.DMLs = append(txn.DMLs, dml) for i, name := range names { dml.Values[strings.ToUpper(name)] = args[i] } case tipb.MutationType_Update: - names, args, oldArgs, err := genDBUpdate(schema, pinfo, info, row, canAppendDefaultValue) + names, args, oldArgs, err := genDBUpdate(schema, pinfo, info, row, canAppendDefaultValue, loader.OracleDB) if err != nil { return nil, errors.Annotate(err, "gen update fail") } @@ -95,6 +96,7 @@ func TiBinlogToOracleTxn(infoGetter TableInfoGetter, schema string, table string Values: make(map[string]interface{}), OldValues: make(map[string]interface{}), UpColumnsInfoMap: tableIDColumnsMap[mut.GetTableId()], + DestDBType: loader.OracleDB, } txn.DMLs = append(txn.DMLs, dml) for i, name := range names { @@ -103,7 +105,7 @@ func TiBinlogToOracleTxn(infoGetter TableInfoGetter, schema string, table string } case tipb.MutationType_DeleteRow: - names, args, err := genDBDelete(schema, info, row) + names, args, err := genDBDelete(schema, info, row, loader.OracleDB) if err != nil { return nil, errors.Annotate(err, "gen delete fail") } @@ -114,6 +116,7 @@ func TiBinlogToOracleTxn(infoGetter TableInfoGetter, schema string, table string Table: downStreamTable, Values: make(map[string]interface{}), UpColumnsInfoMap: tableIDColumnsMap[mut.GetTableId()], + DestDBType: loader.OracleDB, } txn.DMLs = append(txn.DMLs, dml) for i, name := range names { diff --git a/drainer/translator/pb.go b/drainer/translator/pb.go index de0d19f15..9e78040c3 100644 --- a/drainer/translator/pb.go +++ b/drainer/translator/pb.go @@ -29,6 +29,7 @@ import ( "github.com/pingcap/tidb/util/codec" tipb "github.com/pingcap/tipb/go-binlog" + "github.com/pingcap/tidb-binlog/pkg/loader" "github.com/pingcap/tidb-binlog/pkg/util" pb "github.com/pingcap/tidb-binlog/proto/binlog" ) @@ -137,7 +138,7 @@ func genInsert(schema string, ptable, table *model.TableInfo, row []byte) (event val = getDefaultOrZeroValue(ptable, col) } - value, err := formatData(val, col.FieldType) + value, err := formatData(val, col.FieldType, loader.DBTypeUnknown) if err != nil { return nil, errors.Trace(err) } @@ -173,11 +174,11 @@ func genUpdate(schema string, ptable, table *model.TableInfo, row []byte, canApp for _, col := range columns { val, ok := newColumnValues[col.ID] if ok { - oldValue, err := formatData(oldColumnValues[col.ID], col.FieldType) + oldValue, err := formatData(oldColumnValues[col.ID], col.FieldType, loader.DBTypeUnknown) if err != nil { return nil, errors.Trace(err) } - newValue, err := formatData(val, col.FieldType) + newValue, err := formatData(val, col.FieldType, loader.DBTypeUnknown) if err != nil { return nil, errors.Trace(err) } @@ -217,7 +218,7 @@ func genDelete(schema string, table *model.TableInfo, row []byte) (event *pb.Eve for _, col := range columns { val, ok := columnValues[col.ID] if ok { - value, err := formatData(val, col.FieldType) + value, err := formatData(val, col.FieldType, loader.DBTypeUnknown) if err != nil { return nil, errors.Trace(err) } diff --git a/go.mod b/go.mod index e360927c0..23a4a558d 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/Shopify/sarama v1.30.0 github.com/dustin/go-humanize v1.0.0 github.com/go-sql-driver/mysql v1.6.0 - github.com/godror/godror v0.29.0 + github.com/godror/godror v0.33.0 github.com/gogo/protobuf v1.3.2 github.com/golang/mock v1.6.0 github.com/golang/protobuf v1.5.2 @@ -71,7 +71,8 @@ require ( github.com/eapache/queue v1.1.0 // indirect github.com/fatih/color v1.13.0 // indirect github.com/form3tech-oss/jwt-go v3.2.5+incompatible // indirect - github.com/go-logfmt/logfmt v0.5.0 // indirect + github.com/go-logfmt/logfmt v0.5.1 // indirect + github.com/go-logr/logr v1.2.3 // indirect github.com/go-ole/go-ole v1.2.6 // indirect github.com/godror/knownpb v0.1.0 // indirect github.com/golang/glog v1.0.0 // indirect diff --git a/go.sum b/go.sum index 540b47ad5..86a38bcd1 100644 --- a/go.sum +++ b/go.sum @@ -76,7 +76,6 @@ github.com/Shopify/sarama v1.30.0 h1:TOZL6r37xJBDEMLx4yjB77jxbZYXPaDow08TSK6vIL0 github.com/Shopify/sarama v1.30.0/go.mod h1:zujlQQx1kzHsh4jfV1USnptCQrHAEZ2Hk8fTKCulPVs= github.com/Shopify/toxiproxy/v2 v2.1.6-0.20210914104332-15ea381dcdae h1:ePgznFqEG1v3AjMklnK8H7BSc++FDSo7xfK9K7Af+0Y= github.com/Shopify/toxiproxy/v2 v2.1.6-0.20210914104332-15ea381dcdae/go.mod h1:/cvHQkZ1fst0EmZnA5dFtiQdWCNCFYzb+uE2vqVgvx0= -github.com/UNO-SOFT/knownpb v0.0.2/go.mod h1:p80FhK7Efqtw1I44+KdbwHKT2Fg2KluTHKtkGN8YXfE= github.com/VividCortex/ewma v1.1.1 h1:MnEK4VOv6n0RSY4vtRe3h11qjxL3+t0B8yOL8iMXdcM= github.com/VividCortex/ewma v1.1.1/go.mod h1:2Tkkvm3sRDVXaiyucHiACn4cqf7DpdyLvmxzcbUokwA= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= @@ -218,8 +217,11 @@ github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2 github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= -github.com/go-logfmt/logfmt v0.5.0 h1:TrB8swr/68K7m9CcGut2g3UOihhbcbiMAYiuTXdEih4= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= +github.com/go-logfmt/logfmt v0.5.1 h1:otpy5pqBCBZ1ng9RQ0dPu4PN7ba75Y/aA+UpowDyNVA= +github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= +github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0= +github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= @@ -228,8 +230,8 @@ github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LB github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/godror/godror v0.29.0 h1:J5PiWMy7glh4cZnExYk5ryAYx0c972YQUavh/ml+wlM= -github.com/godror/godror v0.29.0/go.mod h1:dwNYusI/Ug2JlbJuVvQQMhzlxVEJeq+MwaXwTYlDyC8= +github.com/godror/godror v0.33.0 h1:ZK1W7GohHVDPoLp/37U9QCSHARnYB4vVxNJya+CyWQ4= +github.com/godror/godror v0.33.0/go.mod h1:qHYnDISFm/h0vM+HDwg0LpyoLvxRKFRSwvhYF7ufjZ8= github.com/godror/knownpb v0.1.0 h1:dJPK8s/I3PQzGGaGcUStL2zIaaICNzKKAK8BzP1uLio= github.com/godror/knownpb v0.1.0/go.mod h1:4nRFbQo1dDuwKnblRXDxrfCFYeT4hjg3GjMqef58eRE= github.com/gogo/protobuf v0.0.0-20171007142547-342cbe0a0415/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= diff --git a/pkg/loader/executor.go b/pkg/loader/executor.go index 226e12d92..1656a7977 100644 --- a/pkg/loader/executor.go +++ b/pkg/loader/executor.go @@ -43,7 +43,7 @@ var ( type executor struct { db *gosql.DB - destDBType string + destDBType DBType batchSize int workerCount int info *loopbacksync.LoopBackSync @@ -70,7 +70,7 @@ func (e *executor) withRefreshTableInfo(fn func(schema string, table string) (in return e } -func (e *executor) withDestDBType(destDBType string) *executor { +func (e *executor) withDestDBType(destDBType DBType) *executor { e.destDBType = destDBType return e } @@ -109,15 +109,7 @@ type tx struct { // wrap of sql.Tx.Exec() func (tx *tx) exec(query string, args ...interface{}) (gosql.Result, error) { start := time.Now() - var ( - res gosql.Result - err error - ) - if len(args) == 0 { - res, err = tx.Tx.Exec(query) - } else { - res, err = tx.Tx.Exec(query, args...) - } + res, err := tx.Tx.Exec(query, args...) if tx.queryHistogramVec != nil { tx.queryHistogramVec.WithLabelValues("exec").Observe(time.Since(start).Seconds()) } @@ -126,11 +118,7 @@ func (tx *tx) exec(query string, args ...interface{}) (gosql.Result, error) { } func (tx *tx) autoRollbackExec(query string, args ...interface{}) (res gosql.Result, err error) { - if len(args) == 0 { - res, err = tx.exec(query) - } else { - res, err = tx.exec(query, args...) - } + res, err = tx.exec(query, args...) if err != nil { log.Error("Exec fail, will rollback", zap.String("query", query), zap.Reflect("args", args), zap.Error(err)) if rbErr := tx.Rollback(); rbErr != nil { @@ -225,10 +213,10 @@ func (e *executor) bulkReplace(inserts []*DML) error { var builder strings.Builder - cols := "(" + buildColumnList(info.columns) + ")" + cols := "(" + buildColumnList(info.columns, e.destDBType) + ")" builder.WriteString("REPLACE INTO " + inserts[0].TableName() + cols + " VALUES ") - holder := fmt.Sprintf("(%s)", holderString(len(info.columns))) + holder := fmt.Sprintf("(%s)", holderString(len(info.columns), e.destDBType)) for i := 0; i < len(inserts); i++ { if i > 0 { builder.WriteByte(',') @@ -265,8 +253,8 @@ func (e *executor) oracleBulkOperation(dmls []*DML) error { return errors.Trace(err) } for _, dml := range dmls { - sql := dml.oracleSQL() - _, err = tx.autoRollbackExec(sql) + sql, args := dml.sql() + _, err = tx.autoRollbackExec(sql, args...) if err != nil { return errors.Trace(err) } @@ -296,7 +284,7 @@ func (e *executor) execTableBatch(ctx context.Context, dmls []*DML) error { if allDeletes, ok := types[DeleteDMLType]; ok { bulkDelete := e.bulkDelete - if e.destDBType == "oracle" { + if e.destDBType == OracleDB { bulkDelete = e.oracleBulkOperation } if err := e.splitExecDML(ctx, allDeletes, bulkDelete); err != nil { @@ -306,7 +294,7 @@ func (e *executor) execTableBatch(ctx context.Context, dmls []*DML) error { if allInserts, ok := types[InsertDMLType]; ok { bulkInsert := e.bulkReplace - if e.destDBType == "oracle" { + if e.destDBType == OracleDB { bulkInsert = e.oracleBulkOperation } if err := e.splitExecDML(ctx, allInserts, bulkInsert); err != nil { @@ -316,7 +304,7 @@ func (e *executor) execTableBatch(ctx context.Context, dmls []*DML) error { if allUpdates, ok := types[UpdateDMLType]; ok { bulkUpdate := e.bulkReplace - if e.destDBType == "oracle" { + if e.destDBType == OracleDB { bulkUpdate = e.oracleBulkOperation } if err := e.splitExecDML(ctx, allUpdates, bulkUpdate); err != nil { @@ -463,43 +451,43 @@ func (e *executor) singleOracleExec(dmls []*DML, safeMode bool) error { for _, dml := range dmls { if safeMode && dml.Tp == UpdateDMLType { //delete old row - sql := dml.oracleDeleteSQL() + sql, args := dml.deleteSQL() log.Debug("safeMode and UpdateDMLType", zap.String("delete old", sql)) - _, err := tx.autoRollbackExec(sql) + _, err := tx.autoRollbackExec(sql, args...) if err != nil { return errors.Trace(err) } //delete new row - sql = dml.oracleDeleteNewValueSQL() + sql, args = dml.oracleDeleteNewValueSQL() log.Debug("safeMode and UpdateDMLType", zap.String("delete new old", sql)) - _, err = tx.autoRollbackExec(sql) + _, err = tx.autoRollbackExec(sql, args...) if err != nil { return errors.Trace(err) } //insert new row - sql = dml.oracleInsertSQL() + sql, args = dml.insertSQL() log.Debug("safeMode and UpdateDMLType", zap.String("insert new old", sql)) - _, err = tx.autoRollbackExec(sql) + _, err = tx.autoRollbackExec(sql, args...) if err != nil { return errors.Trace(err) } } else if safeMode && dml.Tp == InsertDMLType { - sql := dml.oracleDeleteSQL() + sql, args := dml.deleteSQL() log.Debug("safeMode and InsertDMLType", zap.String("delete sql", sql)) - _, err := tx.autoRollbackExec(sql) + _, err := tx.autoRollbackExec(sql, args...) if err != nil { return errors.Trace(err) } - sql = dml.oracleInsertSQL() + sql, args = dml.insertSQL() log.Debug("safeMode and InsertDMLType", zap.String("insert sql", sql)) - _, err = tx.autoRollbackExec(sql) + _, err = tx.autoRollbackExec(sql, args...) if err != nil { return errors.Trace(err) } } else { - sql := dml.oracleSQL() + sql, args := dml.sql() log.Debug("normal sql with no safeMode", zap.String("sql", sql)) - _, err := tx.autoRollbackExec(sql) + _, err := tx.autoRollbackExec(sql, args...) if err != nil { return errors.Trace(err) } diff --git a/pkg/loader/executor_test.go b/pkg/loader/executor_test.go index 677884d73..5d7d1c99d 100644 --- a/pkg/loader/executor_test.go +++ b/pkg/loader/executor_test.go @@ -17,11 +17,12 @@ import ( "context" "database/sql" "fmt" + "regexp" + "sync/atomic" + "github.com/pingcap/tidb/parser/model" tmysql "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/types" - "regexp" - "sync/atomic" sqlmock "github.com/DATA-DOG/go-sqlmock" "github.com/go-sql-driver/mysql" @@ -247,6 +248,7 @@ func (s *singleExecSuite) TestOracleSafeUpdate(c *C) { "age": { FieldType: types.FieldType{Tp: tmysql.TypeInt24}}, }, + DestDBType: OracleDB, } delSQL := "DELETE FROM unicorn.users.*" insertSQL := "INSERT INTO unicorn.users.*" @@ -255,7 +257,7 @@ func (s *singleExecSuite) TestOracleSafeUpdate(c *C) { s.dbMock.ExpectExec(delSQL).WillReturnError(errors.New("del")) e := newExecutor(s.db) - e.destDBType = "oracle" + e.destDBType = OracleDB err := e.singleOracleExec([]*DML{&dml}, true) c.Assert(err, ErrorMatches, "del") c.Assert(s.dbMock.ExpectationsWereMet(), IsNil) @@ -266,7 +268,7 @@ func (s *singleExecSuite) TestOracleSafeUpdate(c *C) { s.dbMock.ExpectExec(delSQL).WillReturnResult(sqlmock.NewResult(1, 1)) s.dbMock.ExpectExec(delSQL).WillReturnError(errors.New("del")) e = newExecutor(s.db) - e.destDBType = "oracle" + e.destDBType = OracleDB err = e.singleOracleExec([]*DML{&dml}, true) c.Assert(err, ErrorMatches, "del") c.Assert(s.dbMock.ExpectationsWereMet(), IsNil) @@ -279,7 +281,7 @@ func (s *singleExecSuite) TestOracleSafeUpdate(c *C) { s.dbMock.ExpectExec(insertSQL).WillReturnError(errors.New("insert")) e = newExecutor(s.db) err = e.singleOracleExec([]*DML{&dml}, true) - e.destDBType = "oracle" + e.destDBType = OracleDB c.Assert(err, ErrorMatches, "insert") c.Assert(s.dbMock.ExpectationsWereMet(), IsNil) s.resetMock(c) @@ -291,7 +293,7 @@ func (s *singleExecSuite) TestOracleSafeUpdate(c *C) { s.dbMock.ExpectExec(insertSQL).WillReturnResult(sqlmock.NewResult(1, 1)) s.dbMock.ExpectCommit() e = newExecutor(s.db) - e.destDBType = "oracle" + e.destDBType = OracleDB err = e.singleOracleExec([]*DML{&dml}, true) c.Assert(err, IsNil) c.Assert(s.dbMock.ExpectationsWereMet(), IsNil) @@ -319,6 +321,7 @@ func (s *singleExecSuite) TestOracleSafeInsert(c *C) { "age": { FieldType: types.FieldType{Tp: tmysql.TypeInt24}}, }, + DestDBType: OracleDB, } delSQL := "DELETE FROM unicorn.users.*" insertSQL := "INSERT INTO unicorn.users.*" @@ -328,7 +331,7 @@ func (s *singleExecSuite) TestOracleSafeInsert(c *C) { s.dbMock.ExpectExec(delSQL).WillReturnError(errors.New("del")) e := newExecutor(s.db) - e.destDBType = "oracle" + e.destDBType = OracleDB err := e.singleOracleExec([]*DML{&dml}, true) c.Assert(err, ErrorMatches, "del") c.Assert(s.dbMock.ExpectationsWereMet(), IsNil) @@ -339,7 +342,7 @@ func (s *singleExecSuite) TestOracleSafeInsert(c *C) { s.dbMock.ExpectExec(delSQL).WillReturnResult(sqlmock.NewResult(1, 1)) s.dbMock.ExpectExec(insertSQL).WillReturnError(errors.New("insert")) e = newExecutor(s.db) - e.destDBType = "oracle" + e.destDBType = OracleDB err = e.singleOracleExec([]*DML{&dml}, true) c.Assert(err, ErrorMatches, "insert") c.Assert(s.dbMock.ExpectationsWereMet(), IsNil) @@ -351,7 +354,7 @@ func (s *singleExecSuite) TestOracleSafeInsert(c *C) { s.dbMock.ExpectExec(insertSQL).WillReturnResult(sqlmock.NewResult(1, 1)) s.dbMock.ExpectCommit() e = newExecutor(s.db) - e.destDBType = "oracle" + e.destDBType = OracleDB err = e.singleOracleExec([]*DML{&dml}, true) c.Assert(err, IsNil) c.Assert(s.dbMock.ExpectationsWereMet(), IsNil) @@ -378,6 +381,7 @@ func (s *singleExecSuite) TestOracleSafeDelete(c *C) { "age": { FieldType: types.FieldType{Tp: tmysql.TypeInt24}}, }, + DestDBType: OracleDB, } delSQL := "DELETE FROM unicorn.users.*" @@ -386,7 +390,7 @@ func (s *singleExecSuite) TestOracleSafeDelete(c *C) { s.dbMock.ExpectExec(delSQL).WillReturnError(errors.New("del")) e := newExecutor(s.db) - e.destDBType = "oracle" + e.destDBType = OracleDB err := e.singleOracleExec([]*DML{&dml}, true) c.Assert(err, ErrorMatches, "del") c.Assert(s.dbMock.ExpectationsWereMet(), IsNil) @@ -397,7 +401,7 @@ func (s *singleExecSuite) TestOracleSafeDelete(c *C) { s.dbMock.ExpectExec(delSQL).WillReturnResult(sqlmock.NewResult(1, 1)) s.dbMock.ExpectCommit() e = newExecutor(s.db) - e.destDBType = "oracle" + e.destDBType = OracleDB err = e.singleOracleExec([]*DML{&dml}, true) c.Assert(err, IsNil) c.Assert(s.dbMock.ExpectationsWereMet(), IsNil) diff --git a/pkg/loader/load.go b/pkg/loader/load.go index 3f10a693a..6ea416ce0 100644 --- a/pkg/loader/load.go +++ b/pkg/loader/load.go @@ -68,7 +68,7 @@ type loaderImpl struct { // like column name, pk & uk db *gosql.DB //downStream db type, mysql,tidb,oracle - destDBType string + destDBType DBType // only set for test getTableInfoFromDB func(db *gosql.DB, schema string, table string) (info *tableInfo, err error) opts options @@ -130,7 +130,7 @@ type options struct { enableDispatch bool enableCausality bool merge bool - destDBType string + destDBType DBType } var defaultLoaderOptions = options{ @@ -143,7 +143,7 @@ var defaultLoaderOptions = options{ enableDispatch: true, enableCausality: true, merge: false, - destDBType: "tidb", + destDBType: MysqlDB, } // A Option sets options such batch size, worker count etc. @@ -196,8 +196,16 @@ func Merge(v bool) Option { //DestinationDBType set destDBType option. func DestinationDBType(t string) Option { + destDBType := DBTypeUnknown + if t == "oracle" { + destDBType = OracleDB + } else if t == "tidb" { + destDBType = TiDB + } else if t == "mysql" { + destDBType = MysqlDB + } return func(o *options) { - o.destDBType = t + o.destDBType = destDBType } } @@ -259,7 +267,7 @@ func NewLoader(db *gosql.DB, opt ...Option) (Loader, error) { ctx: ctx, cancel: cancel, } - if opts.destDBType == "oracle" { + if opts.destDBType == OracleDB { s.getTableInfoFromDB = getOracleTableInfo fGetAppliedTS = getOracleAppliedTS } @@ -395,7 +403,7 @@ func (s *loaderImpl) execDDL(ddl *DDL) error { if ddl.ShouldSkip { return nil } - if s.destDBType == "oracle" { + if s.destDBType == OracleDB { return s.processOracleDDL(ddl) } return s.processMysqlDDL(ddl) @@ -750,7 +758,7 @@ func filterGeneratedCols(dml *DML) { func (s *loaderImpl) getExecutor() *executor { e := newExecutor(s.db).withBatchSize(s.batchSize).withDestDBType(s.destDBType) - if s.destDBType == "oracle" { + if s.destDBType == OracleDB { e.fTryRefreshTableErr = tryRefreshTableOracleErr e.fSingleExec = e.singleOracleExec } diff --git a/pkg/loader/load_test.go b/pkg/loader/load_test.go index e05f51588..10e75eca2 100644 --- a/pkg/loader/load_test.go +++ b/pkg/loader/load_test.go @@ -350,7 +350,7 @@ func (s *execDDLSuite) TestShouldExecInTransaction(c *check.C) { mock.ExpectExec("CREATE TABLE `t` \\(`id` INT\\)").WillReturnResult(sqlmock.NewResult(0, 0)) mock.ExpectCommit() - loader := &loaderImpl{db: db, ctx: context.Background(), destDBType: "mysql"} + loader := &loaderImpl{db: db, ctx: context.Background(), destDBType: MysqlDB} ddl := DDL{SQL: "CREATE TABLE `t` (`id` INT)"} err = loader.execDDL(&ddl) @@ -365,7 +365,7 @@ func (s *execDDLSuite) TestOracleTruncateDDL(c *check.C) { mock.ExpectExec("BEGIN test.do_truncate\\('test.t1',''\\);END;").WillReturnResult(sqlmock.NewResult(0, 0)) mock.ExpectCommit() - loader := &loaderImpl{db: db, ctx: context.Background(), destDBType: "oracle"} + loader := &loaderImpl{db: db, ctx: context.Background(), destDBType: OracleDB} ddl := DDL{SQL: "truncate table t1", Database: "test", Table: "t1"} err = loader.execDDL(&ddl) @@ -389,7 +389,7 @@ func (s *execDDLSuite) TestShouldUseDatabase(c *check.C) { mock.ExpectExec("CREATE TABLE `t` \\(`id` INT\\)").WillReturnResult(sqlmock.NewResult(0, 0)) mock.ExpectCommit() - loader := &loaderImpl{db: db, ctx: context.Background(), destDBType: "mysql"} + loader := &loaderImpl{db: db, ctx: context.Background(), destDBType: MysqlDB} ddl := DDL{SQL: "CREATE TABLE `t` (`id` INT)", Database: "test_db"} err = loader.execDDL(&ddl) diff --git a/pkg/loader/model.go b/pkg/loader/model.go index 06aa86546..474124366 100644 --- a/pkg/loader/model.go +++ b/pkg/loader/model.go @@ -19,10 +19,8 @@ import ( "strconv" "strings" - "github.com/pingcap/tidb/parser/model" - "github.com/pingcap/tidb/parser/mysql" - "github.com/pingcap/log" + "github.com/pingcap/tidb/parser/model" "go.uber.org/zap" ) @@ -37,6 +35,17 @@ const ( DeleteDMLType DMLType = 3 ) +// DBType can be Mysql/Tidb or Oracle +type DBType int + +// DBType types +const ( + DBTypeUnknown DBType = iota + MysqlDB + TiDB + OracleDB +) + // DML holds the dml info type DML struct { Database string @@ -50,6 +59,8 @@ type DML struct { info *tableInfo UpColumnsInfoMap map[string]*model.ColumnInfo + + DestDBType DBType } // DDL holds the ddl info @@ -167,19 +178,23 @@ func (dml *DML) oldPrimaryKeyValues() []interface{} { // TableName returns the fully qualified name of the DML's table func (dml *DML) TableName() string { + if dml.DestDBType == OracleDB { + return fmt.Sprintf("%s.%s", dml.Database, dml.Table) + } return quoteSchema(dml.Database, dml.Table) } -// OracleTableName returns the fully qualified name of the DML's table in oracle db -func (dml *DML) OracleTableName() string { - return fmt.Sprintf("%s.%s", dml.Database, dml.Table) +func (dml *DML) updateSQL() (sql string, args []interface{}) { + if dml.DestDBType == OracleDB { + return dml.updateOracleSQL() + } + return dml.updateTiDBSQL() } -func (dml *DML) updateSQL() (sql string, args []interface{}) { +func (dml *DML) updateTiDBSQL() (sql string, args []interface{}) { builder := new(strings.Builder) fmt.Fprintf(builder, "UPDATE %s SET ", dml.TableName()) - for _, name := range dml.columnNames() { if len(args) > 0 { builder.WriteByte(',') @@ -191,41 +206,45 @@ func (dml *DML) updateSQL() (sql string, args []interface{}) { builder.WriteString(" WHERE ") - whereArgs := dml.buildWhere(builder) + whereArgs := dml.buildTiDBWhere(builder) args = append(args, whereArgs...) - builder.WriteString(" LIMIT 1") sql = builder.String() return } -func (dml *DML) oracleUpdateSQL() (sql string) { +func (dml *DML) updateOracleSQL() (sql string, args []interface{}) { builder := new(strings.Builder) - fmt.Fprintf(builder, "UPDATE %s SET ", dml.OracleTableName()) - - for i, name := range dml.columnNames() { - if i > 0 { + fmt.Fprintf(builder, "UPDATE %s SET ", dml.TableName()) + oracleHolderPos := 1 + for _, name := range dml.columnNames() { + if len(args) > 0 { builder.WriteByte(',') } - value := dml.Values[name] - if value == nil { - fmt.Fprintf(builder, "%s = NULL", escapeName(name)) - } else { - fmt.Fprintf(builder, "%s = %s", escapeName(name), genOracleValue(dml.UpColumnsInfoMap[name], value)) - } + arg := dml.Values[name] + fmt.Fprintf(builder, "%s = :%d", escapeName(name), oracleHolderPos) + oracleHolderPos++ + args = append(args, arg) } builder.WriteString(" WHERE ") - dml.buildOracleWhere(builder) + whereArgs := dml.buildOracleWhere(builder, oracleHolderPos) + args = append(args, whereArgs...) builder.WriteString(" AND rownum <=1") - sql = builder.String() return } -func (dml *DML) buildWhere(builder *strings.Builder) (args []interface{}) { +func (dml *DML) buildWhere(builder *strings.Builder, oracleHolderPos int) (args []interface{}) { + if dml.DestDBType == OracleDB { + dml.buildOracleWhere(builder, oracleHolderPos) + } + return dml.buildTiDBWhere(builder) +} + +func (dml *DML) buildTiDBWhere(builder *strings.Builder) (args []interface{}) { wnames, wargs := dml.whereSlice() for i := 0; i < len(wnames); i++ { if i > 0 { @@ -241,18 +260,22 @@ func (dml *DML) buildWhere(builder *strings.Builder) (args []interface{}) { return } -func (dml *DML) buildOracleWhere(builder *strings.Builder) { - colNames, colValues := dml.whereSlice() - for i := 0; i < len(colNames); i++ { +func (dml *DML) buildOracleWhere(builder *strings.Builder, oracleHolderPos int) (args []interface{}) { + wnames, wargs := dml.whereSlice() + pOracleHolderPos := oracleHolderPos + for i := 0; i < len(wnames); i++ { if i > 0 { builder.WriteString(" AND ") } - if colValues[i] == nil { - builder.WriteString(escapeName(colNames[i]) + " IS NULL") + if wargs[i] == nil { + builder.WriteString(escapeName(wnames[i]) + " IS NULL") } else { - builder.WriteString(fmt.Sprintf("%s = %s", escapeName(colNames[i]), genOracleValue(dml.UpColumnsInfoMap[colNames[i]], colValues[i]))) + builder.WriteString(fmt.Sprintf("%s = :%d", escapeName(wnames[i]), pOracleHolderPos)) + pOracleHolderPos++ + args = append(args, wargs[i]) } } + return } func (dml *DML) whereValues(names []string) (values []interface{}) { @@ -290,29 +313,39 @@ func (dml *DML) whereSlice() (colNames []string, args []interface{}) { } func (dml *DML) deleteSQL() (sql string, args []interface{}) { + if dml.DestDBType == OracleDB { + return dml.deleteOracleSQL() + } + return dml.deleteTiDBSQL() +} + +func (dml *DML) deleteTiDBSQL() (sql string, args []interface{}) { builder := new(strings.Builder) fmt.Fprintf(builder, "DELETE FROM %s WHERE ", dml.TableName()) - args = dml.buildWhere(builder) + args = dml.buildTiDBWhere(builder) + builder.WriteString(" LIMIT 1") sql = builder.String() return } -func (dml *DML) oracleDeleteSQL() (sql string) { +func (dml *DML) deleteOracleSQL() (sql string, args []interface{}) { builder := new(strings.Builder) - fmt.Fprintf(builder, "DELETE FROM %s WHERE ", dml.OracleTableName()) - dml.buildOracleWhere(builder) + fmt.Fprintf(builder, "DELETE FROM %s WHERE ", dml.TableName()) + args = dml.buildOracleWhere(builder, 1) + builder.WriteString(" AND rownum <=1") + sql = builder.String() return } -func (dml *DML) oracleDeleteNewValueSQL() (sql string) { +func (dml *DML) oracleDeleteNewValueSQL() (sql string, args []interface{}) { builder := new(strings.Builder) - fmt.Fprintf(builder, "DELETE FROM %s WHERE ", dml.OracleTableName()) + fmt.Fprintf(builder, "DELETE FROM %s WHERE ", dml.TableName()) valueMap := dml.Values colNames := make([]string, 0) @@ -342,7 +375,7 @@ func (dml *DML) oracleDeleteNewValueSQL() (sql string) { colValues = append(colValues, valueMap[col]) } } - + oracleHolderPos := 1 for i := 0; i < len(colNames); i++ { if i > 0 { builder.WriteString(" AND ") @@ -350,7 +383,9 @@ func (dml *DML) oracleDeleteNewValueSQL() (sql string) { if colValues[i] == nil { builder.WriteString(escapeName(colNames[i]) + " IS NULL") } else { - builder.WriteString(fmt.Sprintf("%s = %s", colNames[i], genOracleValue(dml.UpColumnsInfoMap[colNames[i]], colValues[i]))) + builder.WriteString(fmt.Sprintf("%s = :%d", colNames[i], oracleHolderPos)) + oracleHolderPos++ + args = append(args, colValues[i]) } } builder.WriteString(" AND rownum <=1") @@ -371,7 +406,7 @@ func (dml *DML) columnNames() []string { func (dml *DML) replaceSQL() (sql string, args []interface{}) { names := dml.columnNames() - sql = fmt.Sprintf("REPLACE INTO %s(%s) VALUES(%s)", dml.TableName(), buildColumnList(names), holderString(len(names))) + sql = fmt.Sprintf("REPLACE INTO %s(%s) VALUES(%s)", dml.TableName(), buildColumnList(names, dml.DestDBType), holderString(len(names), dml.DestDBType)) for _, name := range names { v := dml.Values[name] args = append(args, v) @@ -385,23 +420,6 @@ func (dml *DML) insertSQL() (sql string, args []interface{}) { return } -func (dml *DML) oracleInsertSQL() (sql string) { - builder := new(strings.Builder) - columns, values := dml.buildOracleInsertColAndValue() - fmt.Fprintf(builder, "INSERT INTO %s (%s) VALUES (%s)", dml.OracleTableName(), columns, values) - sql = builder.String() - return -} - -func (dml *DML) buildOracleInsertColAndValue() (string, string) { - names := dml.columnNames() - values := make([]string, 0, len(dml.Values)) - for _, name := range names { - values = append(values, genOracleValue(dml.UpColumnsInfoMap[name], dml.Values[name])) - } - return strings.Join(names, ", "), strings.Join(values, ", ") -} - func (dml *DML) sql() (sql string, args []interface{}) { switch dml.Tp { case InsertDMLType: @@ -417,21 +435,6 @@ func (dml *DML) sql() (sql string, args []interface{}) { return } -func (dml *DML) oracleSQL() (sql string) { - switch dml.Tp { - case InsertDMLType: - return dml.oracleInsertSQL() - case UpdateDMLType: - return dml.oracleUpdateSQL() - case DeleteDMLType: - return dml.oracleDeleteSQL() - } - - log.Debug("get sql for dml", zap.Reflect("dml", dml), zap.String("sql", sql)) - - return -} - func formatKey(values []interface{}) string { builder := new(strings.Builder) for i, v := range values { @@ -498,31 +501,3 @@ func getKeys(dml *DML) (keys []string) { return } - -func genOracleValue(column *model.ColumnInfo, value interface{}) string { - if value == nil { - return "NULL" - } - switch column.Tp { - case mysql.TypeDate: - return fmt.Sprintf("TO_DATE('%v', 'yyyy-mm-dd')", value) - case mysql.TypeDatetime: - if column.Decimal == 0 { - return fmt.Sprintf("TO_DATE('%v', 'yyyy-mm-dd hh24:mi:ss')", value) - } - return fmt.Sprintf("TO_TIMESTAMP('%v', 'yyyy-mm-dd hh24:mi:ss.ff%d')", value, column.Decimal) - case mysql.TypeTimestamp: - return fmt.Sprintf("TO_TIMESTAMP('%s', 'yyyy-mm-dd hh24:mi:ss.ff%d')", value, column.Decimal) - case mysql.TypeDuration: - return fmt.Sprintf("TO_DATE('%s', 'hh24:mi:ss')", value) - case mysql.TypeTiny, mysql.TypeShort, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeInt24, - mysql.TypeYear, mysql.TypeFloat, mysql.TypeDouble, mysql.TypeNewDecimal: - return fmt.Sprintf("%v", value) - default: - return fmt.Sprintf("'%s'", processOracleQuoteStringValue(fmt.Sprintf("%v", value))) - } -} - -func processOracleQuoteStringValue(data string) string { - return strings.ReplaceAll(data, "'", "''") -} diff --git a/pkg/loader/model_test.go b/pkg/loader/model_test.go index 3747d3b57..100b3d613 100644 --- a/pkg/loader/model_test.go +++ b/pkg/loader/model_test.go @@ -47,6 +47,7 @@ func getDML(key bool, tp DMLType) *DML { dml.Database = "test" dml.Table = "test" dml.Tp = tp + dml.DestDBType = TiDB return dml } @@ -76,7 +77,7 @@ func (d *dmlSuite) testWhere(c *check.C, tp DMLType) { c.Assert(args, check.DeepEquals, []interface{}{1}) builder := new(strings.Builder) - args = dml.buildWhere(builder) + args = dml.buildWhere(builder, 0) c.Assert(args, check.DeepEquals, []interface{}{1}) c.Assert(strings.Count(builder.String(), "?"), check.Equals, len(args)) @@ -94,14 +95,14 @@ func (d *dmlSuite) testWhere(c *check.C, tp DMLType) { c.Assert(args, check.DeepEquals, []interface{}{1, 1}) builder.Reset() - args = dml.buildWhere(builder) + args = dml.buildWhere(builder, 0) c.Assert(args, check.DeepEquals, []interface{}{1, 1}) c.Assert(strings.Count(builder.String(), "?"), check.Equals, len(args)) // set a1 to NULL value values["a1"] = nil builder.Reset() - args = dml.buildWhere(builder) + args = dml.buildWhere(builder, 0) c.Assert(args, check.DeepEquals, []interface{}{1}) c.Assert(strings.Count(builder.String(), "?"), check.Equals, len(args)) } @@ -188,6 +189,7 @@ func (s *SQLSuite) TestInsertSQL(c *check.C) { info: &tableInfo{ columns: []string{"name", "age"}, }, + DestDBType: TiDB, } sql, args := dml.sql() c.Assert(sql, check.Equals, "INSERT INTO `test`.`hello`(`age`,`name`) VALUES(?,?)") @@ -208,6 +210,7 @@ func (s *SQLSuite) TestDeleteSQL(c *check.C) { info: &tableInfo{ columns: []string{"name", "age"}, }, + DestDBType: TiDB, } sql, args := dml.sql() c.Assert( @@ -232,6 +235,7 @@ func (s *SQLSuite) TestUpdateSQL(c *check.C) { info: &tableInfo{ columns: []string{"name"}, }, + DestDBType: TiDB, } sql, args := dml.sql() c.Assert( @@ -289,11 +293,17 @@ func (s *SQLSuite) TestOracleUpdateSQL(c *check.C) { "NAME": { FieldType: types.FieldType{Tp: mysql.TypeVarString}}, }, + DestDBType: OracleDB, } - sql := dml.oracleSQL() + sql, args := dml.sql() c.Assert( sql, check.Equals, - "UPDATE db.tbl SET ID = 123,NAME = 'pc' WHERE ID = 123 AND NAME = 'pingcap' AND rownum <=1") + "UPDATE db.tbl SET ID = :1,NAME = :2 WHERE ID = :3 AND NAME = :4 AND rownum <=1") + c.Assert(args, check.HasLen, 4) + c.Assert(args[0], check.Equals, 123) + c.Assert(args[1], check.Equals, "pc") + c.Assert(args[2], check.Equals, 123) + c.Assert(args[3], check.Equals, "pingcap") } func (s *SQLSuite) TestOracleUpdateSQLPrimaryKey(c *check.C) { @@ -328,11 +338,16 @@ func (s *SQLSuite) TestOracleUpdateSQLPrimaryKey(c *check.C) { "NAME": { FieldType: types.FieldType{Tp: mysql.TypeVarString}}, }, + DestDBType: OracleDB, } - sql := dml.oracleSQL() + sql, args := dml.sql() c.Assert( sql, check.Equals, - "UPDATE db.tbl SET ID = 123,NAME = 'pc' WHERE ID = 123 AND rownum <=1") + "UPDATE db.tbl SET ID = :1,NAME = :2 WHERE ID = :3 AND rownum <=1") + c.Assert(args, check.HasLen, 3) + c.Assert(args[0], check.Equals, 123) + c.Assert(args[1], check.Equals, "pc") + c.Assert(args[2], check.Equals, 123) } func (s *SQLSuite) TestOracleDeleteSQL(c *check.C) { @@ -353,11 +368,15 @@ func (s *SQLSuite) TestOracleDeleteSQL(c *check.C) { "NAME": { FieldType: types.FieldType{Tp: mysql.TypeVarString}}, }, + DestDBType: OracleDB, } - sql := dml.oracleSQL() + sql, args := dml.sql() c.Assert( sql, check.Equals, - "DELETE FROM db.tbl WHERE ID = 123 AND NAME = 'pc' AND rownum <=1") + "DELETE FROM db.tbl WHERE ID = :1 AND NAME = :2 AND rownum <=1") + c.Assert(args, check.HasLen, 2) + c.Assert(args[0], check.Equals, 123) + c.Assert(args[1], check.Equals, "pc") } func (s *SQLSuite) TestOracleInsertSQL(c *check.C) { @@ -381,78 +400,16 @@ func (s *SQLSuite) TestOracleInsertSQL(c *check.C) { "C2": { FieldType: types.FieldType{Tp: mysql.TypeVarString}}, }, + DestDBType: OracleDB, } - sql := dml.oracleSQL() + sql, args := dml.sql() c.Assert( sql, check.Equals, - "INSERT INTO db.tbl (C2, ID, NAME) VALUES (NULL, 123, 'pc')") -} - -func (s *SQLSuite) TestGenOracleValue(c *check.C) { - columnInfo := model.ColumnInfo{ - FieldType: types.FieldType{Tp: mysql.TypeDate}, - } - colVaue := "2021-09-13" - val := genOracleValue(&columnInfo, colVaue) - c.Assert( - val, check.Equals, - "TO_DATE('2021-09-13', 'yyyy-mm-dd')") - - columnInfo = model.ColumnInfo{ - FieldType: types.FieldType{Tp: mysql.TypeDatetime, Decimal: 0}, - } - colVaue = "2021-09-13 10:10:23" - val = genOracleValue(&columnInfo, colVaue) - c.Assert( - val, check.Equals, - "TO_DATE('2021-09-13 10:10:23', 'yyyy-mm-dd hh24:mi:ss')") - - columnInfo = model.ColumnInfo{ - FieldType: types.FieldType{Tp: mysql.TypeDatetime, Decimal: 6}, - } - colVaue = "2021-09-13 10:10:23.123456" - val = genOracleValue(&columnInfo, colVaue) - c.Assert( - val, check.Equals, - "TO_TIMESTAMP('2021-09-13 10:10:23.123456', 'yyyy-mm-dd hh24:mi:ss.ff6')") - - columnInfo = model.ColumnInfo{ - FieldType: types.FieldType{Tp: mysql.TypeTimestamp, Decimal: 5}, - } - colVaue = "2021-09-13 10:10:23.12345" - val = genOracleValue(&columnInfo, colVaue) - c.Assert( - val, check.Equals, - "TO_TIMESTAMP('2021-09-13 10:10:23.12345', 'yyyy-mm-dd hh24:mi:ss.ff5')") - - columnInfo = model.ColumnInfo{ - FieldType: types.FieldType{Tp: mysql.TypeYear}, - } - colVaue = "2021" - val = genOracleValue(&columnInfo, colVaue) - c.Assert( - val, check.Equals, "2021") - - columnInfo = model.ColumnInfo{ - FieldType: types.FieldType{Tp: mysql.TypeVarchar}, - } - colVaue = "2021" - val = genOracleValue(&columnInfo, colVaue) - c.Assert( - val, check.Equals, "'2021'") - - columnInfo = model.ColumnInfo{ - FieldType: types.FieldType{Tp: mysql.TypeDuration}, - } - colVaue = "23:11:59" - val = genOracleValue(&columnInfo, colVaue) - c.Assert( - val, check.Equals, "TO_DATE('23:11:59', 'hh24:mi:ss')") - - var colVaue2 interface{} - val = genOracleValue(&columnInfo, colVaue2) - c.Assert( - val, check.Equals, "NULL") + "INSERT INTO db.tbl(C2,ID,NAME) VALUES(:1,:2,:3)") + c.Assert(args, check.HasLen, 3) + c.Assert(args[0], check.Equals, nil) + c.Assert(args[1], check.Equals, 123) + c.Assert(args[2], check.Equals, "pc") } func (s *SQLSuite) TestOracleDeleteNewValueSQLWithOneUK(c *check.C) { @@ -482,12 +439,15 @@ func (s *SQLSuite) TestOracleDeleteNewValueSQLWithOneUK(c *check.C) { "C2": { FieldType: types.FieldType{Tp: mysql.TypeVarString}}, }, + DestDBType: OracleDB, } - sql := dml.oracleDeleteNewValueSQL() + sql, args := dml.oracleDeleteNewValueSQL() c.Assert( sql, check.Equals, - "DELETE FROM db.tbl WHERE ID = 123 AND rownum <=1") + "DELETE FROM db.tbl WHERE ID = :1 AND rownum <=1") + c.Assert(args, check.HasLen, 1) + c.Assert(args[0], check.Equals, 123) // column in UK have nil value, so fall back to all columns dml = DML{ @@ -516,11 +476,15 @@ func (s *SQLSuite) TestOracleDeleteNewValueSQLWithOneUK(c *check.C) { "C2": { FieldType: types.FieldType{Tp: mysql.TypeVarString}}, }, + DestDBType: OracleDB, } - sql = dml.oracleDeleteNewValueSQL() + sql, args = dml.oracleDeleteNewValueSQL() c.Assert( sql, check.Equals, - "DELETE FROM db.tbl WHERE C2 IS NULL AND ID = 123 AND NAME = 'pc' AND rownum <=1") + "DELETE FROM db.tbl WHERE C2 IS NULL AND ID = :1 AND NAME = :2 AND rownum <=1") + c.Assert(args, check.HasLen, 2) + c.Assert(args[0], check.Equals, 123) + c.Assert(args[1], check.Equals, "pc") } func (s *SQLSuite) TestOracleDeleteNewValueSQLWithMultiUK(c *check.C) { @@ -557,12 +521,15 @@ func (s *SQLSuite) TestOracleDeleteNewValueSQLWithMultiUK(c *check.C) { "C2": { FieldType: types.FieldType{Tp: mysql.TypeVarString}}, }, + DestDBType: OracleDB, } - sql := dml.oracleDeleteNewValueSQL() + sql, args := dml.oracleDeleteNewValueSQL() c.Assert( sql, check.Equals, - "DELETE FROM db.tbl WHERE ID2 = '456' AND rownum <=1") + "DELETE FROM db.tbl WHERE ID2 = :1 AND rownum <=1") + c.Assert(args, check.HasLen, 1) + c.Assert(args[0], check.Equals, "456") } func (s *SQLSuite) TestOracleDeleteNewValueSQLWithNoUK(c *check.C) { @@ -589,10 +556,15 @@ func (s *SQLSuite) TestOracleDeleteNewValueSQLWithNoUK(c *check.C) { "C2": { FieldType: types.FieldType{Tp: mysql.TypeVarString}}, }, + DestDBType: OracleDB, } - sql := dml.oracleDeleteNewValueSQL() + sql, args := dml.oracleDeleteNewValueSQL() c.Assert( sql, check.Equals, - "DELETE FROM db.tbl WHERE C2 IS NULL AND ID = 123 AND ID2 = '456' AND NAME = 'pc' AND rownum <=1") + "DELETE FROM db.tbl WHERE C2 IS NULL AND ID = :1 AND ID2 = :2 AND NAME = :3 AND rownum <=1") + c.Assert(args, check.HasLen, 3) + c.Assert(args[0], check.Equals, 123) + c.Assert(args[1], check.Equals, "456") + c.Assert(args[2], check.Equals, "pc") } diff --git a/pkg/loader/util.go b/pkg/loader/util.go index fc54c5224..b2b1c5a4f 100644 --- a/pkg/loader/util.go +++ b/pkg/loader/util.go @@ -230,6 +230,7 @@ func CreateOracleDB(user string, password string, host string, port int, service Timezone: loc, }, } + oraDSN.OnInitStmts = []string{"ALTER SESSION SET NLS_DATE_FORMAT='YYYY-MM-DD HH24.MI.SS' NLS_TIMESTAMP_FORMAT='YYYY-MM-DD HH24.MI.SS.FF' NLS_TIMESTAMP_TZ_FORMAT='YYYY-MM-DD HH24.MI.SS.FF TZR' NLS_TIME_FORMAT='HH24.MI.SS.FF' NLS_TIME_TZ_FORMAT='HH24.MI.SS.FF TZR'"} sqlDB := gosql.OpenDB(godror.NewConnector(oraDSN)) err = sqlDB.Ping() if err != nil { @@ -250,7 +251,14 @@ func escapeName(name string) string { return strings.Replace(name, "`", "``", -1) } -func holderString(n int) string { +func holderString(n int, destDBType DBType) string { + if destDBType == OracleDB { + return holderStringOracle(n) + } + return holderStringTiDB(n) +} + +func holderStringTiDB(n int) string { builder := new(strings.Builder) for i := 0; i < n; i++ { if i > 0 { @@ -261,6 +269,17 @@ func holderString(n int) string { return builder.String() } +func holderStringOracle(n int) string { + builder := new(strings.Builder) + for i := 0; i < n; i++ { + if i > 0 { + builder.WriteString(",") + } + builder.WriteString(":" + strconv.Itoa(i+1)) + } + return builder.String() +} + func genHashKey(key string) uint32 { return crc32.ChecksumIEEE([]byte(key)) } @@ -277,13 +296,17 @@ func splitDMLs(dmls []*DML, size int) (res [][]*DML) { return } -func buildColumnList(names []string) string { +func buildColumnList(names []string, destDBType DBType) string { var b strings.Builder for i, name := range names { if i > 0 { b.WriteString(",") } - b.WriteString(quoteName(name)) + if destDBType == OracleDB { + b.WriteString(escapeName(name)) + } else { + b.WriteString(quoteName(name)) + } }