diff --git a/br/pkg/lightning/backend/kv/session.go b/br/pkg/lightning/backend/kv/session.go index eb6148a097b8e..e1aca81bd581e 100644 --- a/br/pkg/lightning/backend/kv/session.go +++ b/br/pkg/lightning/backend/kv/session.go @@ -286,11 +286,13 @@ func NewSession(options *encode.SessionOptions, logger log.Logger) *Session { vars.StmtCtx.InInsertStmt = true vars.StmtCtx.BatchCheck = true vars.StmtCtx.BadNullAsWarning = !sqlMode.HasStrictMode() - vars.StmtCtx.TruncateAsWarning = !sqlMode.HasStrictMode() vars.StmtCtx.OverflowAsWarning = !sqlMode.HasStrictMode() vars.StmtCtx.AllowInvalidDate = sqlMode.HasAllowInvalidDatesMode() vars.StmtCtx.IgnoreZeroInDate = !sqlMode.HasStrictMode() || sqlMode.HasAllowInvalidDatesMode() vars.SQLMode = sqlMode + + typeFlags := vars.StmtCtx.TypeFlags().WithTruncateAsWarning(!sqlMode.HasStrictMode()) + vars.StmtCtx.SetTypeFlags(typeFlags) if options.SysVars != nil { for k, v := range options.SysVars { // since 6.3(current master) tidb checks whether we can set a system variable diff --git a/br/pkg/lightning/backend/tidb/tidb.go b/br/pkg/lightning/backend/tidb/tidb.go index 7636e960e498e..4109c27bc834f 100644 --- a/br/pkg/lightning/backend/tidb/tidb.go +++ b/br/pkg/lightning/backend/tidb/tidb.go @@ -455,7 +455,7 @@ func (enc *tidbEncoder) appendSQL(sb *strings.Builder, datum *types.Datum, _ *ta case types.KindMysqlBit: var buffer [20]byte - intValue, err := datum.GetBinaryLiteral().ToInt(nil) + intValue, err := datum.GetBinaryLiteral().ToInt(types.DefaultNoWarningContext) if err != nil { return err } diff --git a/pkg/ddl/backfilling_scheduler.go b/pkg/ddl/backfilling_scheduler.go index 02c48a532ee12..5a6a5d9008b78 100644 --- a/pkg/ddl/backfilling_scheduler.go +++ b/pkg/ddl/backfilling_scheduler.go @@ -163,12 +163,15 @@ func initSessCtx( return errors.Trace(err) } sessCtx.GetSessionVars().StmtCtx.BadNullAsWarning = !sqlMode.HasStrictMode() - sessCtx.GetSessionVars().StmtCtx.TruncateAsWarning = !sqlMode.HasStrictMode() sessCtx.GetSessionVars().StmtCtx.OverflowAsWarning = !sqlMode.HasStrictMode() sessCtx.GetSessionVars().StmtCtx.AllowInvalidDate = sqlMode.HasAllowInvalidDatesMode() sessCtx.GetSessionVars().StmtCtx.DividedByZeroAsWarning = !sqlMode.HasStrictMode() sessCtx.GetSessionVars().StmtCtx.IgnoreZeroInDate = !sqlMode.HasStrictMode() || sqlMode.HasAllowInvalidDatesMode() sessCtx.GetSessionVars().StmtCtx.NoZeroDate = sqlMode.HasStrictMode() + + typeFlags := sessCtx.GetSessionVars().StmtCtx.TypeFlags().WithTruncateAsWarning(!sqlMode.HasStrictMode()) + sessCtx.GetSessionVars().StmtCtx.SetTypeFlags(typeFlags) + // Prevent initializing the mock context in the workers concurrently. // For details, see https://github.com/pingcap/tidb/issues/40879. _ = sessCtx.GetDomainInfoSchema() diff --git a/pkg/ddl/ddl_api.go b/pkg/ddl/ddl_api.go index c87bc650a8048..9aae98f0a323d 100644 --- a/pkg/ddl/ddl_api.go +++ b/pkg/ddl/ddl_api.go @@ -1352,7 +1352,7 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu return str, false, err } // For other kind of fields (e.g. INT), we supply its integer as string value. - value, err := v.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx) + value, err := v.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx.TypeCtx) if err != nil { return nil, false, err } @@ -5617,12 +5617,11 @@ func GetModifiableColumnJob( } pAst := at.Specs[0].Partition sv := sctx.GetSessionVars().StmtCtx - oldTruncAsWarn, oldIgnoreTrunc := sv.TruncateAsWarning, sv.IgnoreTruncate.Load() - sv.TruncateAsWarning = false - sv.IgnoreTruncate.Store(false) + oldTypeFlags := sv.TypeFlags() + newTypeFlags := oldTypeFlags.WithTruncateAsWarning(false).WithIgnoreTruncateErr(false) + sv.SetTypeFlags(newTypeFlags) _, err = buildPartitionDefinitionsInfo(sctx, pAst.Definitions, &newTblInfo, uint64(len(newTblInfo.Partition.Definitions))) - sv.TruncateAsWarning = oldTruncAsWarn - sv.IgnoreTruncate.Store(oldIgnoreTrunc) + sv.SetTypeFlags(oldTypeFlags) if err != nil { return nil, dbterror.ErrUnsupportedModifyColumn.GenWithStack("New column does not match partition definitions: %s", err.Error()) } diff --git a/pkg/executor/aggfuncs/func_group_concat.go b/pkg/executor/aggfuncs/func_group_concat.go index 6c9e9739e50b6..afe1d204f7fb9 100644 --- a/pkg/executor/aggfuncs/func_group_concat.go +++ b/pkg/executor/aggfuncs/func_group_concat.go @@ -72,7 +72,7 @@ func (e *baseGroupConcat4String) AppendFinalResult2Chunk(_ sessionctx.Context, p func (e *baseGroupConcat4String) handleTruncateError(sctx sessionctx.Context) (err error) { if atomic.CompareAndSwapInt32(e.truncated, 0, 1) { - if !sctx.GetSessionVars().StmtCtx.TruncateAsWarning { + if !sctx.GetSessionVars().StmtCtx.TypeFlags().TruncateAsWarning() { return expression.ErrCutValueGroupConcat.GenWithStackByArgs(e.args[0].String()) } sctx.GetSessionVars().StmtCtx.AppendWarning(expression.ErrCutValueGroupConcat.GenWithStackByArgs(e.args[0].String())) diff --git a/pkg/executor/coprocessor.go b/pkg/executor/coprocessor.go index c835dde6d67f9..acc6cfc93fba5 100644 --- a/pkg/executor/coprocessor.go +++ b/pkg/executor/coprocessor.go @@ -182,14 +182,14 @@ func (h *CoprocessorDAGHandler) buildDAGExecutor(req *coprocessor.Request) (exec } stmtCtx := h.sctx.GetSessionVars().StmtCtx - stmtCtx.SetFlagsFromPBFlag(dagReq.Flags) + tz, err := timeutil.ConstructTimeZone(dagReq.TimeZoneName, int(dagReq.TimeZoneOffset)) if err != nil { return nil, errors.Trace(err) } - - stmtCtx.SetTimeZone(tz) h.sctx.GetSessionVars().TimeZone = tz + stmtCtx.InitFromPBFlagAndTz(dagReq.Flags, tz) + h.dagReq = dagReq is := h.sctx.GetInfoSchema().(infoschema.InfoSchema) // Build physical plan. diff --git a/pkg/executor/executor.go b/pkg/executor/executor.go index fca75d706f40d..9eca11f41cac2 100644 --- a/pkg/executor/executor.go +++ b/pkg/executor/executor.go @@ -907,6 +907,13 @@ func (e *CheckTableExec) Next(ctx context.Context, _ *chunk.Chunk) error { } defer func() { e.done = true }() + // See the comment of `ColumnInfos2ColumnsAndNames`. It's fixing #42341 + originalTypeFlags := e.Ctx().GetSessionVars().StmtCtx.TypeFlags() + defer func() { + e.Ctx().GetSessionVars().StmtCtx.SetTypeFlags(originalTypeFlags) + }() + e.Ctx().GetSessionVars().StmtCtx.SetTypeFlags(originalTypeFlags.WithIgnoreTruncateErr(true)) + idxNames := make([]string, 0, len(e.indexInfos)) for _, idx := range e.indexInfos { if idx.MVIndex { @@ -2062,6 +2069,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.InRestrictedSQL = vars.InRestrictedSQL switch stmt := s.(type) { + // `ResetUpdateStmtCtx` and `ResetDeleteStmtCtx` may modify the flags, so we'll need to store them. case *ast.UpdateStmt: ResetUpdateStmtCtx(sc, stmt, vars) case *ast.DeleteStmt: @@ -2075,17 +2083,17 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.BadNullAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr sc.IgnoreNoPartition = stmt.IgnoreErr sc.ErrAutoincReadFailedAsWarning = stmt.IgnoreErr - sc.TruncateAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr sc.DividedByZeroAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode() sc.IgnoreZeroInDate = !vars.SQLMode.HasNoZeroInDateMode() || !vars.SQLMode.HasNoZeroDateMode() || !vars.StrictSQLMode || stmt.IgnoreErr || sc.AllowInvalidDate sc.Priority = stmt.Priority + sc.SetTypeFlags(sc.TypeFlags().WithTruncateAsWarning(!vars.StrictSQLMode || stmt.IgnoreErr)) case *ast.CreateTableStmt, *ast.AlterTableStmt: sc.InCreateOrAlterStmt = true sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode() sc.IgnoreZeroInDate = !vars.SQLMode.HasNoZeroInDateMode() || !vars.StrictSQLMode || sc.AllowInvalidDate sc.NoZeroDate = vars.SQLMode.HasNoZeroDateMode() - sc.TruncateAsWarning = !vars.StrictSQLMode + sc.SetTypeFlags(sc.TypeFlags().WithTruncateAsWarning(!vars.StrictSQLMode)) case *ast.LoadDataStmt: sc.InLoadDataStmt = true // return warning instead of error when load data meet no partition for value @@ -2100,7 +2108,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.OverflowAsWarning = true // Return warning for truncate error in selection. - sc.TruncateAsWarning = true + sc.SetTypeFlags(sc.TypeFlags().WithTruncateAsWarning(true)) sc.IgnoreZeroInDate = true sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode() if opts := stmt.SelectStmtOpts; opts != nil { @@ -2111,11 +2119,11 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { case *ast.SetOprStmt: sc.InSelectStmt = true sc.OverflowAsWarning = true - sc.TruncateAsWarning = true + sc.SetTypeFlags(sc.TypeFlags().WithTruncateAsWarning(true)) sc.IgnoreZeroInDate = true sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode() case *ast.ShowStmt: - sc.IgnoreTruncate.Store(true) + sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) sc.IgnoreZeroInDate = true sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode() if stmt.Tp == ast.ShowWarnings || stmt.Tp == ast.ShowErrors || stmt.Tp == ast.ShowSessionStates { @@ -2123,26 +2131,24 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.SetWarnings(vars.StmtCtx.GetWarnings()) } case *ast.SplitRegionStmt: - sc.IgnoreTruncate.Store(false) + sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(false)) sc.IgnoreZeroInDate = true sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode() case *ast.SetSessionStatesStmt: sc.InSetSessionStatesStmt = true - sc.IgnoreTruncate.Store(true) + sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) sc.IgnoreZeroInDate = true sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode() default: - sc.IgnoreTruncate.Store(true) + sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) sc.IgnoreZeroInDate = true sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode() } - sc.UpdateTypeFlags(func(flags types.Flags) types.Flags { - return flags. - WithSkipUTF8Check(vars.SkipUTF8Check). - WithSkipSACIICheck(vars.SkipASCIICheck). - WithSkipUTF8MB4Check(!globalConfig.Instance.CheckMb4ValueInUTF8.Load()) - }) + sc.SetTypeFlags(sc.TypeFlags(). + WithSkipUTF8Check(vars.SkipUTF8Check). + WithSkipSACIICheck(vars.SkipASCIICheck). + WithSkipUTF8MB4Check(!globalConfig.Instance.CheckMb4ValueInUTF8.Load())) vars.PlanCacheParams.Reset() if priority := mysql.PriorityEnum(atomic.LoadInt32(&variable.ForcePriority)); priority != mysql.NoPriority { @@ -2192,12 +2198,12 @@ func ResetUpdateStmtCtx(sc *stmtctx.StatementContext, stmt *ast.UpdateStmt, vars sc.InUpdateStmt = true sc.DupKeyAsWarning = stmt.IgnoreErr sc.BadNullAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr - sc.TruncateAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr sc.DividedByZeroAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode() sc.IgnoreZeroInDate = !vars.SQLMode.HasNoZeroInDateMode() || !vars.SQLMode.HasNoZeroDateMode() || !vars.StrictSQLMode || stmt.IgnoreErr || sc.AllowInvalidDate sc.Priority = stmt.Priority sc.IgnoreNoPartition = stmt.IgnoreErr + sc.SetTypeFlags(sc.TypeFlags().WithTruncateAsWarning(!vars.StrictSQLMode || stmt.IgnoreErr)) } // ResetDeleteStmtCtx resets statement context for DeleteStmt. @@ -2205,11 +2211,11 @@ func ResetDeleteStmtCtx(sc *stmtctx.StatementContext, stmt *ast.DeleteStmt, vars sc.InDeleteStmt = true sc.DupKeyAsWarning = stmt.IgnoreErr sc.BadNullAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr - sc.TruncateAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr sc.DividedByZeroAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode() sc.IgnoreZeroInDate = !vars.SQLMode.HasNoZeroInDateMode() || !vars.SQLMode.HasNoZeroDateMode() || !vars.StrictSQLMode || stmt.IgnoreErr || sc.AllowInvalidDate sc.Priority = stmt.Priority + sc.SetTypeFlags(sc.TypeFlags().WithTruncateAsWarning(!vars.StrictSQLMode || stmt.IgnoreErr)) } func setOptionForTopSQL(sc *stmtctx.StatementContext, snapshot kv.Snapshot) { diff --git a/pkg/executor/insert_common.go b/pkg/executor/insert_common.go index e8d48aaf45709..09ddfd5500206 100644 --- a/pkg/executor/insert_common.go +++ b/pkg/executor/insert_common.go @@ -706,7 +706,7 @@ func (e *InsertValues) fillRow(ctx context.Context, row []types.Datum, hasValue if err != nil && gCol.FieldType.IsArray() { return nil, completeError(tbl, gCol.Offset, rowIdx, err) } - if e.Ctx().GetSessionVars().StmtCtx.HandleTruncate(err) != nil { + if e.Ctx().GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(err) != nil { return nil, err } row[colIdx], err = table.CastValue(e.Ctx(), val, gCol.ToInfo(), false, false) @@ -791,7 +791,7 @@ func setDatumAutoIDAndCast(ctx sessionctx.Context, d *types.Datum, id int64, col // Auto ID is out of range. sc := ctx.GetSessionVars().StmtCtx insertPlan, ok := sc.GetPlan().(*core.Insert) - if ok && sc.TruncateAsWarning && len(insertPlan.OnDuplicate) > 0 { + if ok && sc.TypeFlags().TruncateAsWarning() && len(insertPlan.OnDuplicate) > 0 { // Fix issue #38950: AUTO_INCREMENT is incompatible with mysql // An auto id out of range error occurs in `insert ignore into ... on duplicate ...`. // We should allow the SQL to be executed successfully. diff --git a/pkg/executor/load_data.go b/pkg/executor/load_data.go index e882fae1cbe9e..fd0a6c288b07f 100644 --- a/pkg/executor/load_data.go +++ b/pkg/executor/load_data.go @@ -99,8 +99,9 @@ func setNonRestrictiveFlags(stmtCtx *stmtctx.StatementContext) { // TODO: DupKeyAsWarning represents too many "ignore error" paths, the // meaning of this flag is not clear. I can only reuse it here. stmtCtx.DupKeyAsWarning = true - stmtCtx.TruncateAsWarning = true stmtCtx.BadNullAsWarning = true + + stmtCtx.SetTypeFlags(stmtCtx.TypeFlags().WithTruncateAsWarning(true)) } // NewLoadDataWorker creates a new LoadDataWorker that is ready to work. diff --git a/pkg/executor/test/loaddatatest/load_data_test.go b/pkg/executor/test/loaddatatest/load_data_test.go index 19cbb874e3d58..1284475340815 100644 --- a/pkg/executor/test/loaddatatest/load_data_test.go +++ b/pkg/executor/test/loaddatatest/load_data_test.go @@ -141,11 +141,11 @@ func TestLoadData(t *testing.T) { selectSQL := "select * from load_data_test;" sc := ctx.GetSessionVars().StmtCtx - originIgnoreTruncate := sc.IgnoreTruncate.Load() + oldFlags := sc.TypeFlags() defer func() { - sc.IgnoreTruncate.Store(originIgnoreTruncate) + sc.SetTypeFlags(oldFlags) }() - sc.IgnoreTruncate.Store(false) + sc.SetTypeFlags(oldFlags.WithIgnoreTruncateErr(false)) // fields and lines are default, ReadOneBatchRows returns data is nil tests := []testCase{ // In MySQL we have 4 warnings: 1*"Incorrect integer value: '' for column 'id' at row", 3*"Row 1 doesn't contain data for all columns" diff --git a/pkg/executor/test/writetest/write_test.go b/pkg/executor/test/writetest/write_test.go index fcb756c4d91fc..5d7423a17e8be 100644 --- a/pkg/executor/test/writetest/write_test.go +++ b/pkg/executor/test/writetest/write_test.go @@ -1317,11 +1317,11 @@ func TestIssue18681(t *testing.T) { ctx.GetSessionVars().StmtCtx.BadNullAsWarning = true sc := ctx.GetSessionVars().StmtCtx - originIgnoreTruncate := sc.IgnoreTruncate.Load() + oldTypeFlags := sc.TypeFlags() defer func() { - sc.IgnoreTruncate.Store(originIgnoreTruncate) + sc.SetTypeFlags(oldTypeFlags) }() - sc.IgnoreTruncate.Store(false) + sc.SetTypeFlags(oldTypeFlags.WithIgnoreTruncateErr(true)) tests := []testCase{ {[]byte("true\tfalse\t0\t1\n"), []string{"1|0|0|1"}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 0"}, } diff --git a/pkg/expression/aggregation/aggregation.go b/pkg/expression/aggregation/aggregation.go index d26698abdb9a8..e417b1e71cc78 100644 --- a/pkg/expression/aggregation/aggregation.go +++ b/pkg/expression/aggregation/aggregation.go @@ -140,7 +140,7 @@ func (af *aggFunction) ResetContext(sc *stmtctx.StatementContext, evalCtx *AggEv evalCtx.Value.SetNull() } -func (af *aggFunction) updateSum(sc *stmtctx.StatementContext, evalCtx *AggEvaluateContext, row chunk.Row) error { +func (af *aggFunction) updateSum(ctx types.Context, evalCtx *AggEvaluateContext, row chunk.Row) error { a := af.Args[0] value, err := a.Eval(row) if err != nil { @@ -158,7 +158,7 @@ func (af *aggFunction) updateSum(sc *stmtctx.StatementContext, evalCtx *AggEvalu return nil } } - evalCtx.Value, err = calculateSum(sc, evalCtx.Value, value) + evalCtx.Value, err = calculateSum(ctx, evalCtx.Value, value) if err != nil { return err } diff --git a/pkg/expression/aggregation/avg.go b/pkg/expression/aggregation/avg.go index e15f0ce0f7be6..3fa1911f9e3a4 100644 --- a/pkg/expression/aggregation/avg.go +++ b/pkg/expression/aggregation/avg.go @@ -27,7 +27,7 @@ type avgFunction struct { aggFunction } -func (af *avgFunction) updateAvg(sc *stmtctx.StatementContext, evalCtx *AggEvaluateContext, row chunk.Row) error { +func (af *avgFunction) updateAvg(ctx types.Context, evalCtx *AggEvaluateContext, row chunk.Row) error { a := af.Args[1] value, err := a.Eval(row) if err != nil { @@ -36,7 +36,7 @@ func (af *avgFunction) updateAvg(sc *stmtctx.StatementContext, evalCtx *AggEvalu if value.IsNull() { return nil } - evalCtx.Value, err = calculateSum(sc, evalCtx.Value, value) + evalCtx.Value, err = calculateSum(ctx, evalCtx.Value, value) if err != nil { return err } @@ -60,9 +60,9 @@ func (af *avgFunction) ResetContext(sc *stmtctx.StatementContext, evalCtx *AggEv func (af *avgFunction) Update(evalCtx *AggEvaluateContext, sc *stmtctx.StatementContext, row chunk.Row) (err error) { switch af.Mode { case Partial1Mode, CompleteMode: - err = af.updateSum(sc, evalCtx, row) + err = af.updateSum(sc.TypeCtx, evalCtx, row) case Partial2Mode, FinalMode: - err = af.updateAvg(sc, evalCtx, row) + err = af.updateAvg(sc.TypeCtx, evalCtx, row) case DedupMode: panic("DedupMode is not supported now.") } diff --git a/pkg/expression/aggregation/sum.go b/pkg/expression/aggregation/sum.go index 0c1dd13ae192a..5169682cc3bbf 100644 --- a/pkg/expression/aggregation/sum.go +++ b/pkg/expression/aggregation/sum.go @@ -26,7 +26,7 @@ type sumFunction struct { // Update implements Aggregation interface. func (sf *sumFunction) Update(evalCtx *AggEvaluateContext, sc *stmtctx.StatementContext, row chunk.Row) error { - return sf.updateSum(sc, evalCtx, row) + return sf.updateSum(sc.TypeCtx, evalCtx, row) } // GetResult implements Aggregation interface. diff --git a/pkg/expression/aggregation/util.go b/pkg/expression/aggregation/util.go index 1842d385e346d..c253a675e51c3 100644 --- a/pkg/expression/aggregation/util.go +++ b/pkg/expression/aggregation/util.go @@ -55,7 +55,7 @@ func (d *distinctChecker) Check(values []types.Datum) (bool, error) { } // calculateSum adds v to sum. -func calculateSum(sc *stmtctx.StatementContext, sum, v types.Datum) (data types.Datum, err error) { +func calculateSum(ctx types.Context, sum, v types.Datum) (data types.Datum, err error) { // for avg and sum calculation // avg and sum use decimal for integer and decimal type, use float for others // see https://dev.mysql.com/doc/refman/5.7/en/group-by-functions.html @@ -64,7 +64,7 @@ func calculateSum(sc *stmtctx.StatementContext, sum, v types.Datum) (data types. case types.KindNull: case types.KindInt64, types.KindUint64: var d *types.MyDecimal - d, err = v.ToDecimal(sc) + d, err = v.ToDecimal(ctx) if err == nil { data = types.NewDecimalDatum(d) } @@ -72,7 +72,7 @@ func calculateSum(sc *stmtctx.StatementContext, sum, v types.Datum) (data types. v.Copy(&data) default: var f float64 - f, err = v.ToFloat64(sc) + f, err = v.ToFloat64(ctx) if err == nil { data = types.NewFloat64Datum(f) } diff --git a/pkg/expression/builtin_arithmetic.go b/pkg/expression/builtin_arithmetic.go index 1bff34ccd91a1..6f6a84e5ea9ef 100644 --- a/pkg/expression/builtin_arithmetic.go +++ b/pkg/expression/builtin_arithmetic.go @@ -727,7 +727,7 @@ func (s *builtinArithmeticDivideDecimalSig) evalDecimal(row chunk.Row) (*types.M return c, true, handleDivisionByZeroError(s.ctx) } else if err == types.ErrTruncated { sc := s.ctx.GetSessionVars().StmtCtx - err = sc.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c)) + err = sc.TypeCtx.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c)) } else if err == nil { _, frac := c.PrecisionAndFrac() if frac < s.baseBuiltinFunc.tp.GetDecimal() { @@ -846,7 +846,7 @@ func (s *builtinArithmeticIntDivideDecimalSig) evalInt(row chunk.Row) (ret int64 return 0, true, handleDivisionByZeroError(s.ctx) } if err == types.ErrTruncated { - err = sc.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c)) + err = sc.TypeCtx.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c)) } if err == types.ErrOverflow { newErr := errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c) diff --git a/pkg/expression/builtin_arithmetic_vec.go b/pkg/expression/builtin_arithmetic_vec.go index 49c62896ecf6d..c4951c581bc38 100644 --- a/pkg/expression/builtin_arithmetic_vec.go +++ b/pkg/expression/builtin_arithmetic_vec.go @@ -95,7 +95,7 @@ func (b *builtinArithmeticDivideDecimalSig) vecEvalDecimal(input *chunk.Chunk, r result.SetNull(i, true) continue } else if err == types.ErrTruncated { - if err = sc.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", to)); err != nil { + if err = sc.TypeCtx.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", to)); err != nil { return err } } else if err == nil { @@ -617,7 +617,7 @@ func (b *builtinArithmeticIntDivideDecimalSig) vecEvalInt(input *chunk.Chunk, re continue } if err == types.ErrTruncated { - err = sc.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c)) + err = sc.TypeCtx.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c)) } else if err == types.ErrOverflow { newErr := errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c) err = sc.HandleOverflow(newErr, newErr) diff --git a/pkg/expression/builtin_cast.go b/pkg/expression/builtin_cast.go index 4d49e0a4be94c..01bd3438ef417 100644 --- a/pkg/expression/builtin_cast.go +++ b/pkg/expression/builtin_cast.go @@ -534,7 +534,7 @@ func convertJSON2Tp(evalType types.EvalType) func(*stmtctx.StatementContext, typ if item.TypeCode != types.JSONTypeCodeString { return nil, ErrInvalidJSONForFuncIndex } - return types.ProduceStrWithSpecifiedTp(string(item.GetString()), tp, sc, false) + return types.ProduceStrWithSpecifiedTp(string(item.GetString()), tp, sc.TypeCtx, false) } case types.ETInt: return func(sc *stmtctx.StatementContext, item types.BinaryJSON, tp *types.FieldType) (any, error) { @@ -552,7 +552,7 @@ func convertJSON2Tp(evalType types.EvalType) func(*stmtctx.StatementContext, typ if item.TypeCode != types.JSONTypeCodeFloat64 && item.TypeCode != types.JSONTypeCodeInt64 && item.TypeCode != types.JSONTypeCodeUint64 { return nil, ErrInvalidJSONForFuncIndex } - return types.ConvertJSONToFloat(sc, item) + return types.ConvertJSONToFloat(sc.TypeCtx, item) } case types.ETDatetime: return func(sc *stmtctx.StatementContext, item types.BinaryJSON, tp *types.FieldType) (any, error) { @@ -730,7 +730,7 @@ func (b *builtinCastIntAsStringSig) evalString(row chunk.Row) (res string, isNul if tp.GetType() == mysql.TypeYear && res == "0" { res = "0000" } - res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, b.ctx.GetSessionVars().StmtCtx, false) + res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx, false) if err != nil { return res, false, err } @@ -790,7 +790,7 @@ func (b *builtinCastIntAsDurationSig) evalDuration(row chunk.Row) (res types.Dur err = b.ctx.GetSessionVars().StmtCtx.HandleOverflow(err, err) } if types.ErrTruncatedWrongVal.Equal(err) { - err = b.ctx.GetSessionVars().StmtCtx.HandleTruncate(err) + err = b.ctx.GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(err) } return res, true, err } @@ -1045,7 +1045,7 @@ func (b *builtinCastRealAsStringSig) evalString(row chunk.Row) (res string, isNu // If we strconv.FormatFloat the value with 64bits, the result is incorrect! bits = 32 } - res, err = types.ProduceStrWithSpecifiedTp(strconv.FormatFloat(val, 'f', -1, bits), b.tp, b.ctx.GetSessionVars().StmtCtx, false) + res, err = types.ProduceStrWithSpecifiedTp(strconv.FormatFloat(val, 'f', -1, bits), b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx, false) if err != nil { return res, false, err } @@ -1102,7 +1102,7 @@ func (b *builtinCastRealAsDurationSig) evalDuration(row chunk.Row) (res types.Du res, _, err = types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, strconv.FormatFloat(val, 'f', -1, 64), b.tp.GetDecimal()) if err != nil { if types.ErrTruncatedWrongVal.Equal(err) { - err = b.ctx.GetSessionVars().StmtCtx.HandleTruncate(err) + err = b.ctx.GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(err) // ErrTruncatedWrongVal needs to be considered NULL. return res, true, err } @@ -1191,7 +1191,7 @@ func (b *builtinCastDecimalAsStringSig) evalString(row chunk.Row) (res string, i return res, isNull, err } sc := b.ctx.GetSessionVars().StmtCtx - res, err = types.ProduceStrWithSpecifiedTp(string(val.ToString()), b.tp, sc, false) + res, err = types.ProduceStrWithSpecifiedTp(string(val.ToString()), b.tp, sc.TypeCtx, false) if err != nil { return res, false, err } @@ -1279,7 +1279,7 @@ func (b *builtinCastDecimalAsDurationSig) evalDuration(row chunk.Row) (res types } res, _, err = types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, string(val.ToString()), b.tp.GetDecimal()) if types.ErrTruncatedWrongVal.Equal(err) { - err = b.ctx.GetSessionVars().StmtCtx.HandleTruncate(err) + err = b.ctx.GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(err) // ErrTruncatedWrongVal needs to be considered NULL. return res, true, err } @@ -1301,8 +1301,7 @@ func (b *builtinCastStringAsStringSig) evalString(row chunk.Row) (res string, is if isNull || err != nil { return res, isNull, err } - sc := b.ctx.GetSessionVars().StmtCtx - res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, sc, false) + res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx, false) if err != nil { return res, false, err } @@ -1367,7 +1366,7 @@ func (b *builtinCastStringAsIntSig) evalInt(row chunk.Row) (res int64, isNull bo var ures uint64 sc := b.ctx.GetSessionVars().StmtCtx if !isNegative { - ures, err = types.StrToUint(sc, val, true) + ures, err = types.StrToUint(sc.TypeCtx, val, true) res = int64(ures) if err == nil && !mysql.HasUnsignedFlag(b.tp.GetFlag()) && ures > uint64(math.MaxInt64) { @@ -1376,7 +1375,7 @@ func (b *builtinCastStringAsIntSig) evalInt(row chunk.Row) (res int64, isNull bo } else if b.inUnion && mysql.HasUnsignedFlag(b.tp.GetFlag()) { res = 0 } else { - res, err = types.StrToInt(sc, val, true) + res, err = types.StrToInt(sc.TypeCtx, val, true) if err == nil && mysql.HasUnsignedFlag(b.tp.GetFlag()) { // If overflow, don't append this warnings sc.AppendWarning(types.ErrCastNegIntAsUnsigned) @@ -1412,7 +1411,7 @@ func (b *builtinCastStringAsRealSig) evalReal(row chunk.Row) (res float64, isNul return res, isNull, err } sc := b.ctx.GetSessionVars().StmtCtx - res, err = types.StrToFloat(sc, val, true) + res, err = types.StrToFloat(sc.TypeCtx, val, true) if err != nil { return 0, false, err } @@ -1450,7 +1449,7 @@ func (b *builtinCastStringAsDecimalSig) evalDecimal(row chunk.Row) (res *types.M if err == types.ErrTruncated { err = types.ErrTruncatedWrongVal.GenWithStackByArgs("DECIMAL", []byte(val)) } - err = sc.HandleTruncate(err) + err = sc.TypeCtx.HandleTruncate(err) if err != nil { return res, false, err } @@ -1507,7 +1506,7 @@ func (b *builtinCastStringAsDurationSig) evalDuration(row chunk.Row) (res types. res, isNull, err = types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, val, b.tp.GetDecimal()) if types.ErrTruncatedWrongVal.Equal(err) { sc := b.ctx.GetSessionVars().StmtCtx - err = sc.HandleTruncate(err) + err = sc.TypeCtx.HandleTruncate(err) } return res, isNull, err } @@ -1620,7 +1619,7 @@ func (b *builtinCastTimeAsStringSig) evalString(row chunk.Row) (res string, isNu return res, isNull, err } sc := b.ctx.GetSessionVars().StmtCtx - res, err = types.ProduceStrWithSpecifiedTp(val.String(), b.tp, sc, false) + res, err = types.ProduceStrWithSpecifiedTp(val.String(), b.tp, sc.TypeCtx, false) if err != nil { return res, false, err } @@ -1753,7 +1752,7 @@ func (b *builtinCastDurationAsStringSig) evalString(row chunk.Row) (res string, return res, isNull, err } sc := b.ctx.GetSessionVars().StmtCtx - res, err = types.ProduceStrWithSpecifiedTp(val.String(), b.tp, sc, false) + res, err = types.ProduceStrWithSpecifiedTp(val.String(), b.tp, sc.TypeCtx, false) if err != nil { return res, false, err } @@ -1855,7 +1854,7 @@ func (b *builtinCastJSONAsRealSig) evalReal(row chunk.Row) (res float64, isNull return res, isNull, err } sc := b.ctx.GetSessionVars().StmtCtx - res, err = types.ConvertJSONToFloat(sc, val) + res, err = types.ConvertJSONToFloat(sc.TypeCtx, val) return } @@ -1875,7 +1874,7 @@ func (b *builtinCastJSONAsDecimalSig) evalDecimal(row chunk.Row) (res *types.MyD return res, isNull, err } sc := b.ctx.GetSessionVars().StmtCtx - res, err = types.ConvertJSONToDecimal(sc, val) + res, err = types.ConvertJSONToDecimal(sc.TypeCtx, val) if err != nil { return res, false, err } @@ -1898,7 +1897,7 @@ func (b *builtinCastJSONAsStringSig) evalString(row chunk.Row) (res string, isNu if isNull || err != nil { return res, isNull, err } - s, err := types.ProduceStrWithSpecifiedTp(val.String(), b.tp, b.ctx.GetSessionVars().StmtCtx, false) + s, err := types.ProduceStrWithSpecifiedTp(val.String(), b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx, false) if err != nil { return res, false, err } @@ -1961,7 +1960,7 @@ func (b *builtinCastJSONAsTimeSig) evalTime(row chunk.Row) (res types.Time, isNu return res, isNull, err default: err = types.ErrTruncatedWrongVal.GenWithStackByArgs(types.TypeStr(b.tp.GetType()), val.String()) - return res, true, b.ctx.GetSessionVars().StmtCtx.HandleTruncate(err) + return res, true, b.ctx.GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(err) } } @@ -2003,12 +2002,12 @@ func (b *builtinCastJSONAsDurationSig) evalDuration(row chunk.Row) (res types.Du res, _, err = types.ParseDuration(stmtCtx, s, b.tp.GetDecimal()) if types.ErrTruncatedWrongVal.Equal(err) { sc := b.ctx.GetSessionVars().StmtCtx - err = sc.HandleTruncate(err) + err = sc.TypeCtx.HandleTruncate(err) } return res, isNull, err default: err = types.ErrTruncatedWrongVal.GenWithStackByArgs("TIME", val.String()) - return res, true, stmtCtx.HandleTruncate(err) + return res, true, stmtCtx.TypeCtx.HandleTruncate(err) } } diff --git a/pkg/expression/builtin_cast_test.go b/pkg/expression/builtin_cast_test.go index f6142331f2270..45dbf10d5a94d 100644 --- a/pkg/expression/builtin_cast_test.go +++ b/pkg/expression/builtin_cast_test.go @@ -36,14 +36,11 @@ func TestCastFunctions(t *testing.T) { sc := ctx.GetSessionVars().StmtCtx // Test `cast as char[(N)]` and `cast as binary[(N)]`. - originIgnoreTruncate := sc.IgnoreTruncate.Load() - originTruncateAsWarning := sc.TruncateAsWarning - sc.IgnoreTruncate.Store(false) - sc.TruncateAsWarning = true + oldTypeFlags := sc.TypeFlags() defer func() { - sc.IgnoreTruncate.Store(originIgnoreTruncate) - sc.TruncateAsWarning = originTruncateAsWarning + sc.SetTypeFlags(oldTypeFlags) }() + sc.SetTypeFlags(oldTypeFlags.WithTruncateAsWarning(true)) tp := types.NewFieldType(mysql.TypeString) tp.SetFlen(5) @@ -296,14 +293,20 @@ func TestCastFuncSig(t *testing.T) { ctx := createContext(t) sc := ctx.GetSessionVars().StmtCtx - originIgnoreTruncate := sc.IgnoreTruncate.Load() + originIgnoreTruncate := sc.TypeFlags().IgnoreTruncateErr() originTZ := sc.TimeZone() - sc.IgnoreTruncate.Store(true) + sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) sc.SetTimeZone(time.UTC) defer func() { - sc.IgnoreTruncate.Store(originIgnoreTruncate) + sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(originIgnoreTruncate)) sc.SetTimeZone(originTZ) }() + + oldTypeFlags := sc.TypeFlags() + defer func() { + sc.SetTypeFlags(oldTypeFlags) + }() + sc.SetTypeFlags(oldTypeFlags.WithIgnoreTruncateErr(true)) var sig builtinFunc durationColumn := &Column{RetType: types.NewFieldType(mysql.TypeDuration), Index: 0} @@ -1105,11 +1108,11 @@ func TestCastFuncSig(t *testing.T) { func TestCastJSONAsDecimalSig(t *testing.T) { ctx := createContext(t) sc := ctx.GetSessionVars().StmtCtx - originIgnoreTruncate := sc.IgnoreTruncate.Load() - sc.IgnoreTruncate.Store(true) + oldTypeFlags := sc.TypeFlags() defer func() { - sc.IgnoreTruncate.Store(originIgnoreTruncate) + sc.SetTypeFlags(oldTypeFlags) }() + sc.SetTypeFlags(oldTypeFlags.WithIgnoreTruncateErr(true)) col := &Column{RetType: types.NewFieldType(mysql.TypeJSON), Index: 0} b, err := newBaseBuiltinFunc(ctx, "", []Expression{col}, types.NewFieldType(mysql.TypeNewDecimal)) @@ -1587,11 +1590,11 @@ func TestCastConstAsDecimalFieldType(t *testing.T) { func TestCastBinaryStringAsJSONSig(t *testing.T) { ctx := createContext(t) sc := ctx.GetSessionVars().StmtCtx - originIgnoreTruncate := sc.IgnoreTruncate.Load() - sc.IgnoreTruncate.Store(true) + oldTypeFlags := sc.TypeFlags() defer func() { - sc.IgnoreTruncate.Store(originIgnoreTruncate) + sc.SetTypeFlags(oldTypeFlags) }() + sc.SetTypeFlags(oldTypeFlags.WithIgnoreTruncateErr(true)) // BINARY STRING will be converted to a JSON opaque // and yield "base64:typeXX:" finally diff --git a/pkg/expression/builtin_cast_vec.go b/pkg/expression/builtin_cast_vec.go index 947aceecbf8da..4a9a1ba079d1d 100644 --- a/pkg/expression/builtin_cast_vec.go +++ b/pkg/expression/builtin_cast_vec.go @@ -51,7 +51,7 @@ func (b *builtinCastIntAsDurationSig) vecEvalDuration(input *chunk.Chunk, result err = b.ctx.GetSessionVars().StmtCtx.HandleOverflow(err, err) } if types.ErrTruncatedWrongVal.Equal(err) { - err = b.ctx.GetSessionVars().StmtCtx.HandleTruncate(err) + err = b.ctx.GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(err) } if err != nil { return err @@ -223,7 +223,7 @@ func (b *builtinCastRealAsStringSig) vecEvalString(input *chunk.Chunk, result *c result.AppendNull() continue } - res, err = types.ProduceStrWithSpecifiedTp(strconv.FormatFloat(v, 'f', -1, bits), b.tp, sc, false) + res, err = types.ProduceStrWithSpecifiedTp(strconv.FormatFloat(v, 'f', -1, bits), b.tp, sc.TypeCtx, false) if err != nil { return err } @@ -263,7 +263,7 @@ func (b *builtinCastDecimalAsStringSig) vecEvalString(input *chunk.Chunk, result result.AppendNull() continue } - res, e := types.ProduceStrWithSpecifiedTp(string(v.ToString()), b.tp, sc, false) + res, e := types.ProduceStrWithSpecifiedTp(string(v.ToString()), b.tp, sc.TypeCtx, false) if e != nil { return e } @@ -456,7 +456,7 @@ func (b *builtinCastJSONAsRealSig) vecEvalReal(input *chunk.Chunk, result *chunk if result.IsNull(i) { continue } - f64s[i], err = types.ConvertJSONToFloat(sc, buf.GetJSON(i)) + f64s[i], err = types.ConvertJSONToFloat(sc.TypeCtx, buf.GetJSON(i)) if err != nil { return err } @@ -716,7 +716,7 @@ func (b *builtinCastIntAsStringSig) vecEvalString(input *chunk.Chunk, result *ch if isYearType && str == "0" { str = "0000" } - str, err = types.ProduceStrWithSpecifiedTp(str, b.tp, b.ctx.GetSessionVars().StmtCtx, false) + str, err = types.ProduceStrWithSpecifiedTp(str, b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx, false) if err != nil { return err } @@ -951,7 +951,7 @@ func (b *builtinCastStringAsIntSig) vecEvalInt(input *chunk.Chunk, result *chunk return err } result.MergeNulls(buf) - sc := b.ctx.GetSessionVars().StmtCtx + typeCtx := b.ctx.GetSessionVars().StmtCtx.TypeCtx i64s := result.Int64s() isUnsigned := mysql.HasUnsignedFlag(b.tp.GetFlag()) unionUnsigned := isUnsigned && b.inUnion @@ -966,18 +966,18 @@ func (b *builtinCastStringAsIntSig) vecEvalInt(input *chunk.Chunk, result *chunk val := strings.TrimSpace(buf.GetString(i)) isNegative := len(val) > 1 && val[0] == '-' if !isNegative { - ures, err = types.StrToUint(sc, val, true) + ures, err = types.StrToUint(typeCtx, val, true) if !isUnsigned && err == nil && ures > uint64(math.MaxInt64) { - sc.AppendWarning(types.ErrCastAsSignedOverflow) + typeCtx.AppendWarning(types.ErrCastAsSignedOverflow) } res = int64(ures) } else if unionUnsigned { res = 0 } else { - res, err = types.StrToInt(sc, val, true) + res, err = types.StrToInt(typeCtx, val, true) if err == nil && isUnsigned { // If overflow, don't append this warnings - sc.AppendWarning(types.ErrCastNegIntAsUnsigned) + typeCtx.AppendWarning(types.ErrCastNegIntAsUnsigned) } } res, err = b.handleOverflow(res, val, err, isNegative) @@ -1013,7 +1013,7 @@ func (b *builtinCastStringAsDurationSig) vecEvalDuration(input *chunk.Chunk, res dur, isNull, err := types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, buf.GetString(i), b.tp.GetDecimal()) if err != nil { if types.ErrTruncatedWrongVal.Equal(err) { - err = b.ctx.GetSessionVars().StmtCtx.HandleTruncate(err) + err = b.ctx.GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(err) } if err != nil { return err @@ -1187,7 +1187,7 @@ func (b *builtinCastJSONAsStringSig) vecEvalString(input *chunk.Chunk, result *c result.AppendNull() continue } - s, err := types.ProduceStrWithSpecifiedTp(buf.GetJSON(i).String(), b.tp, b.ctx.GetSessionVars().StmtCtx, false) + s, err := types.ProduceStrWithSpecifiedTp(buf.GetJSON(i).String(), b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx, false) if err != nil { return err } @@ -1291,7 +1291,7 @@ func (b *builtinCastRealAsDurationSig) vecEvalDuration(input *chunk.Chunk, resul dur, _, err := types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, strconv.FormatFloat(f64s[i], 'f', -1, 64), b.tp.GetDecimal()) if err != nil { if types.ErrTruncatedWrongVal.Equal(err) { - err = b.ctx.GetSessionVars().StmtCtx.HandleTruncate(err) + err = b.ctx.GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(err) if err != nil { return err } @@ -1395,7 +1395,7 @@ func (b *builtinCastDurationAsStringSig) vecEvalString(input *chunk.Chunk, resul result.AppendNull() continue } - res, err = types.ProduceStrWithSpecifiedTp(buf.GetDuration(i, fsp).String(), b.tp, sc, false) + res, err = types.ProduceStrWithSpecifiedTp(buf.GetDuration(i, fsp).String(), b.tp, sc.TypeCtx, false) if err != nil { return err } @@ -1600,7 +1600,7 @@ func (b *builtinCastTimeAsStringSig) vecEvalString(input *chunk.Chunk, result *c result.AppendNull() continue } - res, err = types.ProduceStrWithSpecifiedTp(v.String(), b.tp, sc, false) + res, err = types.ProduceStrWithSpecifiedTp(v.String(), b.tp, sc.TypeCtx, false) if err != nil { return err } @@ -1639,7 +1639,7 @@ func (b *builtinCastJSONAsDecimalSig) vecEvalDecimal(input *chunk.Chunk, result if result.IsNull(i) { continue } - tempres, err := types.ConvertJSONToDecimal(sc, buf.GetJSON(i)) + tempres, err := types.ConvertJSONToDecimal(sc.TypeCtx, buf.GetJSON(i)) if err != nil { return err } @@ -1685,7 +1685,7 @@ func (b *builtinCastStringAsRealSig) vecEvalReal(input *chunk.Chunk, result *chu if result.IsNull(i) { continue } - res, err := types.StrToFloat(sc, buf.GetString(i), true) + res, err := types.StrToFloat(sc.TypeCtx, buf.GetString(i), true) if err != nil { return err } @@ -1730,7 +1730,7 @@ func (b *builtinCastStringAsDecimalSig) vecEvalDecimal(input *chunk.Chunk, resul isNegative := len(val) > 0 && val[0] == '-' dec := new(types.MyDecimal) if !(b.inUnion && mysql.HasUnsignedFlag(b.tp.GetFlag()) && isNegative) { - if err := stmtCtx.HandleTruncate(dec.FromString([]byte(val))); err != nil { + if err := stmtCtx.TypeCtx.HandleTruncate(dec.FromString([]byte(val))); err != nil { return err } dec, err := types.ProduceDecWithSpecifiedTp(dec, b.tp, stmtCtx) @@ -1874,7 +1874,7 @@ func (b *builtinCastDecimalAsDurationSig) vecEvalDuration(input *chunk.Chunk, re dur, _, err := types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, string(args[i].ToString()), b.tp.GetDecimal()) if err != nil { if types.ErrTruncatedWrongVal.Equal(err) { - err = b.ctx.GetSessionVars().StmtCtx.HandleTruncate(err) + err = b.ctx.GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(err) if err != nil { return err } @@ -1913,7 +1913,7 @@ func (b *builtinCastStringAsStringSig) vecEvalString(input *chunk.Chunk, result result.AppendNull() continue } - res, err = types.ProduceStrWithSpecifiedTp(buf.GetString(i), b.tp, sc, false) + res, err = types.ProduceStrWithSpecifiedTp(buf.GetString(i), b.tp, sc.TypeCtx, false) if err != nil { return err } @@ -1979,7 +1979,7 @@ func (b *builtinCastJSONAsDurationSig) vecEvalDuration(input *chunk.Chunk, resul } dur, _, err = types.ParseDuration(stmtCtx, s, b.tp.GetDecimal()) if types.ErrTruncatedWrongVal.Equal(err) { - err = stmtCtx.HandleTruncate(err) + err = stmtCtx.TypeCtx.HandleTruncate(err) } if err != nil { return err @@ -1987,7 +1987,7 @@ func (b *builtinCastJSONAsDurationSig) vecEvalDuration(input *chunk.Chunk, resul ds[i] = dur.Duration default: err = types.ErrTruncatedWrongVal.GenWithStackByArgs(types.TypeStr(b.tp.GetType()), val.String()) - err = stmtCtx.HandleTruncate(err) + err = stmtCtx.TypeCtx.HandleTruncate(err) if err != nil { return err } diff --git a/pkg/expression/builtin_compare_test.go b/pkg/expression/builtin_compare_test.go index 8df462792a2a0..698eea969c211 100644 --- a/pkg/expression/builtin_compare_test.go +++ b/pkg/expression/builtin_compare_test.go @@ -233,11 +233,11 @@ func TestIntervalFunc(t *testing.T) { ctx := createContext(t) sc := ctx.GetSessionVars().StmtCtx - origin := sc.IgnoreTruncate.Load() - sc.IgnoreTruncate.Store(true) + oldTypeFlags := sc.TypeFlags() defer func() { - sc.IgnoreTruncate.Store(origin) + sc.SetTypeFlags(oldTypeFlags) }() + sc.SetTypeFlags(oldTypeFlags.WithIgnoreTruncateErr(true)) for _, test := range []struct { args []types.Datum @@ -289,13 +289,14 @@ func TestIntervalFunc(t *testing.T) { func TestGreatestLeastFunc(t *testing.T) { ctx := createContext(t) sc := ctx.GetSessionVars().StmtCtx - originIgnoreTruncate := sc.IgnoreTruncate.Load() - sc.IgnoreTruncate.Store(true) - decG := &types.MyDecimal{} - decL := &types.MyDecimal{} + oldTypeFlags := sc.TypeFlags() defer func() { - sc.IgnoreTruncate.Store(originIgnoreTruncate) + sc.SetTypeFlags(oldTypeFlags) }() + sc.SetTypeFlags(oldTypeFlags.WithIgnoreTruncateErr(true)) + + decG := &types.MyDecimal{} + decL := &types.MyDecimal{} for _, test := range []struct { args []interface{} diff --git a/pkg/expression/builtin_control_test.go b/pkg/expression/builtin_control_test.go index 329df3c1cbce8..53a87a5bcf83b 100644 --- a/pkg/expression/builtin_control_test.go +++ b/pkg/expression/builtin_control_test.go @@ -61,11 +61,11 @@ func TestCaseWhen(t *testing.T) { func TestIf(t *testing.T) { ctx := createContext(t) stmtCtx := ctx.GetSessionVars().StmtCtx - origin := stmtCtx.IgnoreTruncate.Load() - stmtCtx.IgnoreTruncate.Store(true) + oldTypeFlags := stmtCtx.TypeFlags() defer func() { - stmtCtx.IgnoreTruncate.Store(origin) + stmtCtx.SetTypeFlags(oldTypeFlags) }() + stmtCtx.SetTypeFlags(oldTypeFlags.WithIgnoreTruncateErr(true)) tbl := []struct { Arg1 interface{} Arg2 interface{} diff --git a/pkg/expression/builtin_math_test.go b/pkg/expression/builtin_math_test.go index 9e44f6e5167e5..3f56b92180025 100644 --- a/pkg/expression/builtin_math_test.go +++ b/pkg/expression/builtin_math_test.go @@ -61,11 +61,11 @@ func TestAbs(t *testing.T) { func TestCeil(t *testing.T) { ctx := createContext(t) sc := ctx.GetSessionVars().StmtCtx - tmpIT := sc.IgnoreTruncate.Load() - sc.IgnoreTruncate.Store(true) + oldTypeFlags := sc.TypeFlags() defer func() { - sc.IgnoreTruncate.Store(tmpIT) + sc.SetTypeFlags(oldTypeFlags) }() + sc.SetTypeFlags(oldTypeFlags.WithIgnoreTruncateErr(true)) type testCase struct { arg interface{} @@ -177,11 +177,11 @@ func TestExp(t *testing.T) { func TestFloor(t *testing.T) { ctx := createContext(t) sc := ctx.GetSessionVars().StmtCtx - tmpIT := sc.IgnoreTruncate.Load() - sc.IgnoreTruncate.Store(true) + oldTypeFlags := sc.TypeFlags() defer func() { - sc.IgnoreTruncate.Store(tmpIT) + sc.SetTypeFlags(oldTypeFlags) }() + sc.SetTypeFlags(oldTypeFlags.WithIgnoreTruncateErr(true)) genDuration := func(h, m, s int64) types.Duration { duration := time.Duration(h)*time.Hour + @@ -631,11 +631,11 @@ func TestConv(t *testing.T) { func TestSign(t *testing.T) { ctx := createContext(t) sc := ctx.GetSessionVars().StmtCtx - tmpIT := sc.IgnoreTruncate.Load() - sc.IgnoreTruncate.Store(true) + oldTypeFlags := sc.TypeFlags() defer func() { - sc.IgnoreTruncate.Store(tmpIT) + sc.SetTypeFlags(oldTypeFlags) }() + sc.SetTypeFlags(oldTypeFlags.WithIgnoreTruncateErr(true)) for _, tt := range []struct { num []interface{} @@ -666,7 +666,12 @@ func TestSign(t *testing.T) { func TestDegrees(t *testing.T) { ctx := createContext(t) sc := ctx.GetSessionVars().StmtCtx - sc.IgnoreTruncate.Store(false) + oldTypeFlags := sc.TypeFlags() + defer func() { + sc.SetTypeFlags(oldTypeFlags) + }() + sc.SetTypeFlags(oldTypeFlags.WithIgnoreTruncateErr(false)) + cases := []struct { args interface{} expected float64 diff --git a/pkg/expression/builtin_op_test.go b/pkg/expression/builtin_op_test.go index e3cb53d605ca8..c5580fb913b84 100644 --- a/pkg/expression/builtin_op_test.go +++ b/pkg/expression/builtin_op_test.go @@ -70,11 +70,11 @@ func TestUnary(t *testing.T) { func TestLogicAnd(t *testing.T) { ctx := createContext(t) sc := ctx.GetSessionVars().StmtCtx - origin := sc.IgnoreTruncate.Load() + oldTypeFlags := sc.TypeFlags() defer func() { - sc.IgnoreTruncate.Store(origin) + sc.SetTypeFlags(oldTypeFlags) }() - sc.IgnoreTruncate.Store(true) + sc.SetTypeFlags(oldTypeFlags.WithIgnoreTruncateErr(true)) cases := []struct { args []interface{} @@ -243,11 +243,11 @@ func TestBitXor(t *testing.T) { func TestBitOr(t *testing.T) { ctx := createContext(t) sc := ctx.GetSessionVars().StmtCtx - origin := sc.IgnoreTruncate.Load() + oldTypeFlags := sc.TypeFlags() defer func() { - sc.IgnoreTruncate.Store(origin) + sc.SetTypeFlags(oldTypeFlags) }() - sc.IgnoreTruncate.Store(true) + sc.SetTypeFlags(oldTypeFlags.WithIgnoreTruncateErr(true)) cases := []struct { args []interface{} @@ -289,11 +289,11 @@ func TestBitOr(t *testing.T) { func TestLogicOr(t *testing.T) { ctx := createContext(t) sc := ctx.GetSessionVars().StmtCtx - origin := sc.IgnoreTruncate.Load() + oldTypeFlags := sc.TypeFlags() defer func() { - sc.IgnoreTruncate.Store(origin) + sc.SetTypeFlags(oldTypeFlags) }() - sc.IgnoreTruncate.Store(true) + sc.SetTypeFlags(oldTypeFlags.WithIgnoreTruncateErr(true)) cases := []struct { args []interface{} @@ -395,11 +395,11 @@ func TestBitAnd(t *testing.T) { func TestBitNeg(t *testing.T) { ctx := createContext(t) sc := ctx.GetSessionVars().StmtCtx - origin := sc.IgnoreTruncate.Load() + oldTypeFlags := sc.TypeFlags() defer func() { - sc.IgnoreTruncate.Store(origin) + sc.SetTypeFlags(oldTypeFlags) }() - sc.IgnoreTruncate.Store(true) + sc.SetTypeFlags(oldTypeFlags.WithIgnoreTruncateErr(true)) cases := []struct { args []interface{} @@ -441,11 +441,11 @@ func TestBitNeg(t *testing.T) { func TestUnaryNot(t *testing.T) { ctx := createContext(t) sc := ctx.GetSessionVars().StmtCtx - origin := sc.IgnoreTruncate.Load() + oldTypeFlags := sc.TypeFlags() defer func() { - sc.IgnoreTruncate.Store(origin) + sc.SetTypeFlags(oldTypeFlags) }() - sc.IgnoreTruncate.Store(true) + sc.SetTypeFlags(oldTypeFlags.WithIgnoreTruncateErr(true)) cases := []struct { args []interface{} @@ -495,11 +495,11 @@ func TestUnaryNot(t *testing.T) { func TestIsTrueOrFalse(t *testing.T) { ctx := createContext(t) sc := ctx.GetSessionVars().StmtCtx - origin := sc.IgnoreTruncate.Load() + oldTypeFlags := sc.TypeFlags() defer func() { - sc.IgnoreTruncate.Store(origin) + sc.SetTypeFlags(oldTypeFlags) }() - sc.IgnoreTruncate.Store(true) + sc.SetTypeFlags(oldTypeFlags.WithIgnoreTruncateErr(true)) testCases := []struct { args []interface{} @@ -602,11 +602,11 @@ func TestIsTrueOrFalse(t *testing.T) { func TestLogicXor(t *testing.T) { ctx := createContext(t) sc := ctx.GetSessionVars().StmtCtx - origin := sc.IgnoreTruncate.Load() + oldTypeFlags := sc.TypeFlags() defer func() { - sc.IgnoreTruncate.Store(origin) + sc.SetTypeFlags(oldTypeFlags) }() - sc.IgnoreTruncate.Store(true) + sc.SetTypeFlags(oldTypeFlags.WithIgnoreTruncateErr(true)) cases := []struct { args []interface{} diff --git a/pkg/expression/builtin_other.go b/pkg/expression/builtin_other.go index 16065e20429ed..aa8b40bc34a6f 100644 --- a/pkg/expression/builtin_other.go +++ b/pkg/expression/builtin_other.go @@ -1210,7 +1210,7 @@ func (b *builtinValuesIntSig) evalInt(_ chunk.Row) (int64, bool, error) { } if len(val) < 8 { var binary types.BinaryLiteral = val - v, err := binary.ToInt(b.ctx.GetSessionVars().StmtCtx) + v, err := binary.ToInt(b.ctx.GetSessionVars().StmtCtx.TypeCtx) if err != nil { return 0, true, errors.Trace(err) } diff --git a/pkg/expression/builtin_other_test.go b/pkg/expression/builtin_other_test.go index e6e9a9e13a184..c0d364edc9cfc 100644 --- a/pkg/expression/builtin_other_test.go +++ b/pkg/expression/builtin_other_test.go @@ -32,11 +32,11 @@ import ( func TestBitCount(t *testing.T) { ctx := createContext(t) stmtCtx := ctx.GetSessionVars().StmtCtx - origin := stmtCtx.IgnoreTruncate.Load() - stmtCtx.IgnoreTruncate.Store(true) + oldTypeFlags := stmtCtx.TypeFlags() defer func() { - stmtCtx.IgnoreTruncate.Store(origin) + stmtCtx.SetTypeFlags(oldTypeFlags) }() + stmtCtx.SetTypeFlags(oldTypeFlags.WithIgnoreTruncateErr(true)) fc := funcs[ast.BitCount] var bitCountCases = []struct { origin interface{} @@ -67,8 +67,8 @@ func TestBitCount(t *testing.T) { require.Nil(t, test.count) continue } - sc := stmtctx.NewStmtCtx() - sc.IgnoreTruncate.Store(true) + sc := stmtctx.NewStmtCtxWithTimeZone(stmtCtx.TimeZone()) + sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) res, err := count.ToInt64(sc) require.NoError(t, err) require.Equal(t, test.count, res) diff --git a/pkg/expression/builtin_string_test.go b/pkg/expression/builtin_string_test.go index 286f8c3385cc0..a8f632a07001d 100644 --- a/pkg/expression/builtin_string_test.go +++ b/pkg/expression/builtin_string_test.go @@ -393,11 +393,11 @@ func TestConcatWSSig(t *testing.T) { func TestLeft(t *testing.T) { ctx := createContext(t) stmtCtx := ctx.GetSessionVars().StmtCtx - origin := stmtCtx.IgnoreTruncate.Load() - stmtCtx.IgnoreTruncate.Store(true) + oldTypeFlags := stmtCtx.TypeFlags() defer func() { - stmtCtx.IgnoreTruncate.Store(origin) + stmtCtx.SetTypeFlags(oldTypeFlags) }() + stmtCtx.SetTypeFlags(oldTypeFlags.WithIgnoreTruncateErr(true)) cases := []struct { args []interface{} @@ -443,11 +443,11 @@ func TestLeft(t *testing.T) { func TestRight(t *testing.T) { ctx := createContext(t) stmtCtx := ctx.GetSessionVars().StmtCtx - origin := stmtCtx.IgnoreTruncate.Load() - stmtCtx.IgnoreTruncate.Store(true) + oldTypeFlags := stmtCtx.TypeFlags() defer func() { - stmtCtx.IgnoreTruncate.Store(origin) + stmtCtx.SetTypeFlags(oldTypeFlags) }() + stmtCtx.SetTypeFlags(oldTypeFlags.WithIgnoreTruncateErr(true)) cases := []struct { args []interface{} @@ -956,11 +956,11 @@ func TestSubstringIndex(t *testing.T) { func TestSpace(t *testing.T) { ctx := createContext(t) stmtCtx := ctx.GetSessionVars().StmtCtx - origin := stmtCtx.IgnoreTruncate.Load() - stmtCtx.IgnoreTruncate.Store(true) + oldTypeFlags := stmtCtx.TypeFlags() defer func() { - stmtCtx.IgnoreTruncate.Store(origin) + stmtCtx.SetTypeFlags(oldTypeFlags) }() + stmtCtx.SetTypeFlags(oldTypeFlags.WithIgnoreTruncateErr(true)) cases := []struct { arg interface{} @@ -1388,7 +1388,8 @@ func TestBitLength(t *testing.T) { func TestChar(t *testing.T) { ctx := createContext(t) - ctx.GetSessionVars().StmtCtx.IgnoreTruncate.Store(true) + typeFlags := ctx.GetSessionVars().StmtCtx.TypeFlags() + ctx.GetSessionVars().StmtCtx.SetTypeFlags(typeFlags.WithIgnoreTruncateErr(true)) tbl := []struct { str string iNum int64 @@ -1509,11 +1510,11 @@ func TestFindInSet(t *testing.T) { func TestField(t *testing.T) { ctx := createContext(t) stmtCtx := ctx.GetSessionVars().StmtCtx - origin := stmtCtx.IgnoreTruncate.Load() - stmtCtx.IgnoreTruncate.Store(true) + oldTypeFlags := stmtCtx.TypeFlags() defer func() { - stmtCtx.IgnoreTruncate.Store(origin) + stmtCtx.SetTypeFlags(oldTypeFlags) }() + stmtCtx.SetTypeFlags(oldTypeFlags.WithIgnoreTruncateErr(true)) tbl := []struct { argLst []interface{} @@ -1991,8 +1992,8 @@ func TestFormat(t *testing.T) { testutil.DatumEqual(t, types.NewDatum(tt.ret), r) } - origConfig := ctx.GetSessionVars().StmtCtx.TruncateAsWarning - ctx.GetSessionVars().StmtCtx.TruncateAsWarning = true + origTypeFlags := ctx.GetSessionVars().StmtCtx.TypeFlags() + ctx.GetSessionVars().StmtCtx.SetTypeFlags(origTypeFlags.WithTruncateAsWarning(true)) for _, tt := range formatTests1 { f, err := fc.getFunction(ctx, datumsToConstants(types.MakeDatums(tt.number, tt.precision))) require.NoError(t, err) @@ -2009,7 +2010,7 @@ func TestFormat(t *testing.T) { ctx.GetSessionVars().StmtCtx.SetWarnings([]stmtctx.SQLWarn{}) } } - ctx.GetSessionVars().StmtCtx.TruncateAsWarning = origConfig + ctx.GetSessionVars().StmtCtx.SetTypeFlags(origTypeFlags) f2, err := fc.getFunction(ctx, datumsToConstants(types.MakeDatums(formatTests2.number, formatTests2.precision, formatTests2.locale))) require.NoError(t, err) @@ -2306,7 +2307,8 @@ func TestBin(t *testing.T) { fc := funcs[ast.Bin] dtbl := tblToDtbl(tbl) ctx := mock.NewContext() - ctx.GetSessionVars().StmtCtx.IgnoreTruncate.Store(true) + typeFlags := ctx.GetSessionVars().StmtCtx.TypeFlags() + ctx.GetSessionVars().StmtCtx.SetTypeFlags(typeFlags.WithIgnoreTruncateErr(true)) for _, c := range dtbl { f, err := fc.getFunction(ctx, datumsToConstants(c["Input"])) require.NoError(t, err) diff --git a/pkg/expression/builtin_time.go b/pkg/expression/builtin_time.go index 5af1a08e1be25..f9f9940fbec16 100644 --- a/pkg/expression/builtin_time.go +++ b/pkg/expression/builtin_time.go @@ -600,7 +600,7 @@ func calculateTimeDiff(sc *stmtctx.StatementContext, lhs, rhs types.Time) (d typ d = lhs.Sub(sc, &rhs) d.Duration, err = types.TruncateOverflowMySQLTime(d.Duration) if types.ErrTruncatedWrongVal.Equal(err) { - err = sc.HandleTruncate(err) + err = sc.TypeCtx.HandleTruncate(err) } return d, err != nil, err } @@ -615,7 +615,7 @@ func calculateDurationTimeDiff(ctx sessionctx.Context, lhs, rhs types.Duration) d.Duration, err = types.TruncateOverflowMySQLTime(d.Duration) if types.ErrTruncatedWrongVal.Equal(err) { sc := ctx.GetSessionVars().StmtCtx - err = sc.HandleTruncate(err) + err = sc.TypeCtx.HandleTruncate(err) } return d, err != nil, err } @@ -2275,7 +2275,7 @@ func (b *builtinTimeSig) evalDuration(row chunk.Row) (res types.Duration, isNull sc := b.ctx.GetSessionVars().StmtCtx res, _, err = types.ParseDuration(sc, expr, fsp) if types.ErrTruncatedWrongVal.Equal(err) { - err = sc.HandleTruncate(err) + err = sc.TypeCtx.HandleTruncate(err) } return res, isNull, err } @@ -5570,7 +5570,7 @@ func (b *builtinSecToTimeSig) evalDuration(row chunk.Row) (types.Duration, bool, minute = 59 second = 59 demical = 0 - err = b.ctx.GetSessionVars().StmtCtx.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("time", strconv.FormatFloat(secondsFloat, 'f', -1, 64))) + err = b.ctx.GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("time", strconv.FormatFloat(secondsFloat, 'f', -1, 64))) if err != nil { return types.Duration{}, err != nil, err } diff --git a/pkg/expression/builtin_time_test.go b/pkg/expression/builtin_time_test.go index 456477b2b8961..85c175617414a 100644 --- a/pkg/expression/builtin_time_test.go +++ b/pkg/expression/builtin_time_test.go @@ -1492,11 +1492,11 @@ func TestStrToDate(t *testing.T) { func TestFromDays(t *testing.T) { ctx := createContext(t) stmtCtx := ctx.GetSessionVars().StmtCtx - origin := stmtCtx.IgnoreTruncate.Load() - stmtCtx.IgnoreTruncate.Store(true) + oldTypeFlags := stmtCtx.TypeFlags() defer func() { - stmtCtx.IgnoreTruncate.Store(origin) + stmtCtx.SetTypeFlags(oldTypeFlags) }() + stmtCtx.SetTypeFlags(oldTypeFlags.WithIgnoreTruncateErr(true)) tests := []struct { day int64 expect string @@ -1781,7 +1781,7 @@ func TestTimestampDiff(t *testing.T) { } sc := ctx.GetSessionVars().StmtCtx - sc.IgnoreTruncate.Store(true) + sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) sc.IgnoreZeroInDate = true resetStmtContext(ctx) f, err := fc.getFunction(ctx, datumsToConstants([]types.Datum{types.NewStringDatum("DAY"), @@ -2721,11 +2721,11 @@ func TestTimeToSec(t *testing.T) { func TestSecToTime(t *testing.T) { ctx := createContext(t) stmtCtx := ctx.GetSessionVars().StmtCtx - origin := stmtCtx.IgnoreTruncate.Load() - stmtCtx.IgnoreTruncate.Store(true) + oldTypeFlags := stmtCtx.TypeFlags() defer func() { - stmtCtx.IgnoreTruncate.Store(origin) + stmtCtx.SetTypeFlags(oldTypeFlags) }() + stmtCtx.SetTypeFlags(oldTypeFlags.WithIgnoreTruncateErr(true)) fc := funcs[ast.SecToTime] diff --git a/pkg/expression/builtin_time_vec.go b/pkg/expression/builtin_time_vec.go index 0ff70bf96e7c6..851f2b2b2de2e 100644 --- a/pkg/expression/builtin_time_vec.go +++ b/pkg/expression/builtin_time_vec.go @@ -1928,7 +1928,7 @@ func (b *builtinSecToTimeSig) vecEvalDuration(input *chunk.Chunk, result *chunk. minute = 59 second = 59 demical = 0 - err = b.ctx.GetSessionVars().StmtCtx.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("time", strconv.FormatFloat(secondsFloat, 'f', -1, 64))) + err = b.ctx.GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("time", strconv.FormatFloat(secondsFloat, 'f', -1, 64))) if err != nil { return err } @@ -2411,7 +2411,7 @@ func (b *builtinTimeSig) vecEvalDuration(input *chunk.Chunk, result *chunk.Colum res, _, err := types.ParseDuration(sc, expr, fsp) if types.ErrTruncatedWrongVal.Equal(err) { - err = sc.HandleTruncate(err) + err = sc.TypeCtx.HandleTruncate(err) } if err != nil { return err diff --git a/pkg/expression/builtin_time_vec_test.go b/pkg/expression/builtin_time_vec_test.go index b7a1d2f3ca40e..6bdbf7d783862 100644 --- a/pkg/expression/builtin_time_vec_test.go +++ b/pkg/expression/builtin_time_vec_test.go @@ -581,7 +581,8 @@ func BenchmarkVectorizedBuiltinTimeFunc(b *testing.B) { func TestVecMonth(t *testing.T) { ctx := mock.NewContext() ctx.GetSessionVars().SQLMode |= mysql.ModeNoZeroDate - ctx.GetSessionVars().StmtCtx.TruncateAsWarning = true + typeFlags := ctx.GetSessionVars().StmtCtx.TypeFlags() + ctx.GetSessionVars().StmtCtx.SetTypeFlags(typeFlags.WithTruncateAsWarning(true)) input := chunk.New([]*types.FieldType{types.NewFieldType(mysql.TypeDatetime)}, 3, 3) input.Reset() input.AppendTime(0, types.ZeroDate) @@ -594,6 +595,6 @@ func TestVecMonth(t *testing.T) { require.Equal(t, 0, len(ctx.GetSessionVars().StmtCtx.GetWarnings())) ctx.GetSessionVars().StmtCtx.InInsertStmt = true - ctx.GetSessionVars().StmtCtx.TruncateAsWarning = false + ctx.GetSessionVars().StmtCtx.SetTypeFlags(typeFlags.WithTruncateAsWarning(false)) require.NoError(t, f.vecEvalInt(input, result)) } diff --git a/pkg/expression/column.go b/pkg/expression/column.go index 95e3bd513a228..437d081e10152 100644 --- a/pkg/expression/column.go +++ b/pkg/expression/column.go @@ -422,7 +422,7 @@ func (col *Column) EvalInt(ctx sessionctx.Context, row chunk.Row) (int64, bool, return 0, true, nil } if val.Kind() == types.KindMysqlBit { - val, err := val.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx) + val, err := val.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx.TypeCtx) return int64(val), err != nil, err } res, err := val.ToInt64(ctx.GetSessionVars().StmtCtx) diff --git a/pkg/expression/constant.go b/pkg/expression/constant.go index bb6ea63068ebe..c9c7b8197d5de 100644 --- a/pkg/expression/constant.go +++ b/pkg/expression/constant.go @@ -278,13 +278,13 @@ func (c *Constant) EvalInt(ctx sessionctx.Context, row chunk.Row) (int64, bool, if c.GetType().GetType() == mysql.TypeNull || dt.IsNull() { return 0, true, nil } else if dt.Kind() == types.KindBinaryLiteral { - val, err := dt.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx) + val, err := dt.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx.TypeCtx) return int64(val), err != nil, err } else if c.GetType().Hybrid() || dt.Kind() == types.KindString { res, err := dt.ToInt64(ctx.GetSessionVars().StmtCtx) return res, false, err } else if dt.Kind() == types.KindMysqlBit { - uintVal, err := dt.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx) + uintVal, err := dt.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx.TypeCtx) return int64(uintVal), false, err } return dt.GetInt64(), false, nil @@ -303,7 +303,7 @@ func (c *Constant) EvalReal(ctx sessionctx.Context, row chunk.Row) (float64, boo return 0, true, nil } if c.GetType().Hybrid() || dt.Kind() == types.KindBinaryLiteral || dt.Kind() == types.KindString { - res, err := dt.ToFloat64(ctx.GetSessionVars().StmtCtx) + res, err := dt.ToFloat64(ctx.GetSessionVars().StmtCtx.TypeCtx) return res, false, err } return dt.GetFloat64(), false, nil @@ -337,7 +337,7 @@ func (c *Constant) EvalDecimal(ctx sessionctx.Context, row chunk.Row) (*types.My if c.GetType().GetType() == mysql.TypeNull || dt.IsNull() { return nil, true, nil } - res, err := dt.ToDecimal(ctx.GetSessionVars().StmtCtx) + res, err := dt.ToDecimal(ctx.GetSessionVars().StmtCtx.TypeCtx) if err != nil { return nil, false, err } diff --git a/pkg/expression/constant_fold.go b/pkg/expression/constant_fold.go index 170f7100d1381..691c8d50a29d6 100644 --- a/pkg/expression/constant_fold.go +++ b/pkg/expression/constant_fold.go @@ -208,7 +208,7 @@ func foldConstant(expr Expression) (Expression, bool) { // of Constant to nil is ok. return &Constant{Value: value, RetType: x.RetType}, false } - if isTrue, err := value.ToBool(sc); err == nil && isTrue == 0 { + if isTrue, err := value.ToBool(sc.TypeCtx); err == nil && isTrue == 0 { // This Constant is created to compose the result expression of EvaluateExprWithNull when InNullRejectCheck // is true. We just check whether the result expression is null or false and then let it die. Basically, // the constant is used once briefly and will not be retained for a long time. Hence setting DeferredExpr diff --git a/pkg/expression/errors.go b/pkg/expression/errors.go index 6f0b037af984a..9cdf0a582ec42 100644 --- a/pkg/expression/errors.go +++ b/pkg/expression/errors.go @@ -76,7 +76,7 @@ func handleInvalidTimeError(ctx sessionctx.Context, err error) error { return err } sc := ctx.GetSessionVars().StmtCtx - err = sc.HandleTruncate(err) + err = sc.TypeCtx.HandleTruncate(err) if ctx.GetSessionVars().StrictSQLMode && (sc.InInsertStmt || sc.InUpdateStmt || sc.InDeleteStmt) { return err } @@ -104,7 +104,7 @@ func handleAllowedPacketOverflowed(ctx sessionctx.Context, exprName string, maxA sc := ctx.GetSessionVars().StmtCtx // insert|update|delete ignore ... - if sc.TruncateAsWarning { + if sc.TypeFlags().TruncateAsWarning() { sc.AppendWarning(err) return nil } diff --git a/pkg/expression/evaluator_test.go b/pkg/expression/evaluator_test.go index daeefbfacff8b..9c727ba10f616 100644 --- a/pkg/expression/evaluator_test.go +++ b/pkg/expression/evaluator_test.go @@ -206,7 +206,7 @@ func TestBinopComparison(t *testing.T) { require.NoError(t, err) v, err := evalBuiltinFunc(f, chunk.Row{}) require.NoError(t, err) - val, err := v.ToBool(ctx.GetSessionVars().StmtCtx) + val, err := v.ToBool(ctx.GetSessionVars().StmtCtx.TypeCtx) require.NoError(t, err) require.Equal(t, tt.result, val) } @@ -407,10 +407,10 @@ func TestBinopNumeric(t *testing.T) { default: // we use float64 as the result type check for all. sc := ctx.GetSessionVars().StmtCtx - f, err := v.ToFloat64(sc) + f, err := v.ToFloat64(sc.TypeCtx) require.NoError(t, err) d := types.NewDatum(tt.ret) - r, err := d.ToFloat64(sc) + r, err := d.ToFloat64(sc.TypeCtx) require.NoError(t, err) require.Equal(t, r, f) } diff --git a/pkg/expression/expression.go b/pkg/expression/expression.go index 23fa0ac2f4bbf..f81baa6d875ae 100644 --- a/pkg/expression/expression.go +++ b/pkg/expression/expression.go @@ -274,7 +274,7 @@ func EvalBool(ctx sessionctx.Context, exprList CNFExprs, row chunk.Row) (bool, b continue } - i, err := data.ToBool(ctx.GetSessionVars().StmtCtx) + i, err := data.ToBool(ctx.GetSessionVars().StmtCtx.TypeCtx) if err != nil { i, err = HandleOverflowOnSelection(ctx.GetSessionVars().StmtCtx, i, err) if err != nil { @@ -494,14 +494,14 @@ func toBool(sc *stmtctx.StatementContext, tp *types.FieldType, eType types.EvalT } case mysql.TypeBit: var bl types.BinaryLiteral = buf.GetBytes(i) - iVal, err := bl.ToInt(sc) + iVal, err := bl.ToInt(sc.TypeCtx) if err != nil { return err } fVal = float64(iVal) } } else { - fVal, err = types.StrToFloat(sc, sVal, false) + fVal, err = types.StrToFloat(sc.TypeCtx, sVal, false) if err != nil { return err } @@ -963,6 +963,8 @@ func TableInfo2SchemaAndNames(ctx sessionctx.Context, dbName model.CIStr, tbl *m } // ColumnInfos2ColumnsAndNames converts the ColumnInfo to the *Column and NameSlice. +// This function is **unsafe** to be called concurrently, unless the `IgnoreTruncate` has been set to `true`. The only +// known case which will call this function concurrently is `CheckTableExec`. Ref #18408 and #42341. func ColumnInfos2ColumnsAndNames(ctx sessionctx.Context, dbName, tblName model.CIStr, colInfos []*model.ColumnInfo, tblInfo *model.TableInfo) ([]*Column, types.NameSlice, error) { columns := make([]*Column, 0, len(colInfos)) names := make([]*types.FieldName, 0, len(colInfos)) @@ -987,11 +989,14 @@ func ColumnInfos2ColumnsAndNames(ctx sessionctx.Context, dbName, tblName model.C // Resolve virtual generated column. mockSchema := NewSchema(columns...) // Ignore redundant warning here. - save := ctx.GetSessionVars().StmtCtx.IgnoreTruncate.Load() - defer func() { - ctx.GetSessionVars().StmtCtx.IgnoreTruncate.Store(save) - }() - ctx.GetSessionVars().StmtCtx.IgnoreTruncate.Store(true) + flags := ctx.GetSessionVars().StmtCtx.TypeFlags() + if !flags.IgnoreTruncateErr() { + defer func() { + ctx.GetSessionVars().StmtCtx.SetTypeFlags(flags) + }() + ctx.GetSessionVars().StmtCtx.SetTypeFlags(flags.WithIgnoreTruncateErr(true)) + } + for i, col := range colInfos { if col.IsVirtualGenerated() { expr, err := generatedexpr.ParseExpression(col.GeneratedExprString) diff --git a/pkg/expression/helper.go b/pkg/expression/helper.go index 2cb0f161bb13f..6308608794319 100644 --- a/pkg/expression/helper.go +++ b/pkg/expression/helper.go @@ -182,7 +182,7 @@ func getStmtTimestamp(ctx sessionctx.Context) (time.Time, error) { return now, err } - timestamp, err := types.StrToFloat(sessionVars.StmtCtx, timestampStr, false) + timestamp, err := types.StrToFloat(sessionVars.StmtCtx.TypeCtx, timestampStr, false) if err != nil { return time.Time{}, err } diff --git a/pkg/expression/main_test.go b/pkg/expression/main_test.go index 274600f443ef3..8813cba966c88 100644 --- a/pkg/expression/main_test.go +++ b/pkg/expression/main_test.go @@ -59,7 +59,7 @@ func createContext(t *testing.T) *mock.Context { ctx := mock.NewContext() ctx.GetSessionVars().StmtCtx.SetTimeZone(time.Local) sc := ctx.GetSessionVars().StmtCtx - sc.TruncateAsWarning = true + sc.SetTypeFlags(sc.TypeFlags().WithTruncateAsWarning(true)) require.NoError(t, ctx.GetSessionVars().SetSystemVar("max_allowed_packet", "67108864")) ctx.GetSessionVars().PlanColumnID.Store(0) return ctx diff --git a/pkg/expression/scalar_function.go b/pkg/expression/scalar_function.go index a349670be4e65..31c5d84117a78 100644 --- a/pkg/expression/scalar_function.go +++ b/pkg/expression/scalar_function.go @@ -419,7 +419,7 @@ func (sf *ScalarFunction) Eval(row chunk.Row) (d types.Datum, err error) { res, err = types.ParseEnum(tp.GetElems(), str, tp.GetCollate()) if ctx := sf.GetCtx(); ctx != nil { if sc := ctx.GetSessionVars().StmtCtx; sc != nil { - err = sc.HandleTruncate(err) + err = sc.TypeCtx.HandleTruncate(err) } } } else { diff --git a/pkg/planner/cardinality/selectivity.go b/pkg/planner/cardinality/selectivity.go index d2d5d8ff8041a..4fca0886eb9c9 100644 --- a/pkg/planner/cardinality/selectivity.go +++ b/pkg/planner/cardinality/selectivity.go @@ -277,7 +277,7 @@ func Selectivity( ret *= 0 mask &^= 1 << uint64(i) delete(notCoveredConstants, i) - } else if isTrue, err := c.Value.ToBool(sc); err == nil { + } else if isTrue, err := c.Value.ToBool(sc.TypeCtx); err == nil { if isTrue == 0 { // c is false ret *= 0 diff --git a/pkg/planner/core/casetest/physicalplantest/physical_plan_test.go b/pkg/planner/core/casetest/physicalplantest/physical_plan_test.go index 0d481614c6e25..d05180be31f25 100644 --- a/pkg/planner/core/casetest/physicalplantest/physical_plan_test.go +++ b/pkg/planner/core/casetest/physicalplantest/physical_plan_test.go @@ -123,7 +123,7 @@ func TestRefine(t *testing.T) { stmt, err := p.ParseOneStmt(tt, "", "") require.NoError(t, err, comment) sc := tk.Session().GetSessionVars().StmtCtx - sc.IgnoreTruncate.Store(false) + sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(false)) p, _, err := planner.Optimize(context.TODO(), tk.Session(), stmt, is) require.NoError(t, err, comment) testdata.OnRecord(func() { @@ -156,7 +156,7 @@ func TestAggEliminator(t *testing.T) { stmt, err := p.ParseOneStmt(tt, "", "") require.NoError(t, err, comment) sc := tk.Session().GetSessionVars().StmtCtx - sc.IgnoreTruncate.Store(false) + sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(false)) p, _, err := planner.Optimize(context.TODO(), tk.Session(), stmt, is) require.NoError(t, err) testdata.OnRecord(func() { diff --git a/pkg/planner/core/logical_plan_builder.go b/pkg/planner/core/logical_plan_builder.go index c4febef86bfab..ff6ddec9e4c42 100644 --- a/pkg/planner/core/logical_plan_builder.go +++ b/pkg/planner/core/logical_plan_builder.go @@ -2444,8 +2444,8 @@ func getUintFromNode(ctx sessionctx.Context, n ast.Node, mustInt64orUint64 bool) return uint64(v), false, true } case string: - sc := ctx.GetSessionVars().StmtCtx - uVal, err := types.StrToUint(sc, v, false) + ctx := ctx.GetSessionVars().StmtCtx.TypeCtx + uVal, err := types.StrToUint(ctx, v, false) if err != nil { return 0, false, false } @@ -6780,10 +6780,12 @@ func (b *PlanBuilder) buildWindowFunctionFrameBound(_ context.Context, spec *ast // If it has paramMarker and is in prepare stmt. We don't need to eval it since its value is not decided yet. if !checker.InPrepareStmt { // Do not raise warnings for truncate. - oriIgnoreTruncate := b.ctx.GetSessionVars().StmtCtx.IgnoreTruncate.Load() - b.ctx.GetSessionVars().StmtCtx.IgnoreTruncate.Store(true) + sc := b.ctx.GetSessionVars().StmtCtx + oldTypeFlags := sc.TypeFlags() + newTypeFlags := oldTypeFlags.WithIgnoreTruncateErr(true) + sc.SetTypeFlags(newTypeFlags) uVal, isNull, err := expr.EvalInt(b.ctx, chunk.Row{}) - b.ctx.GetSessionVars().StmtCtx.IgnoreTruncate.Store(oriIgnoreTruncate) + sc.SetTypeFlags(oldTypeFlags) if uVal < 0 || isNull || err != nil { return nil, ErrWindowFrameIllegal.GenWithStackByArgs(getWindowName(spec.Name.O)) } diff --git a/pkg/planner/core/logical_plans.go b/pkg/planner/core/logical_plans.go index ddec51d8b2064..c7ad4e5f42ebf 100644 --- a/pkg/planner/core/logical_plans.go +++ b/pkg/planner/core/logical_plans.go @@ -373,7 +373,7 @@ func (p *LogicalJoin) extractFDForOuterJoin(filtersFromApply []expression.Expres // if one of the inner condition is constant false, the inner side are all null, left make constant all of that. for _, one := range innerCondition { if c, ok := one.(*expression.Constant); ok && c.DeferredExpr == nil && c.ParamMarker == nil { - if isTrue, err := c.Value.ToBool(p.SCtx().GetSessionVars().StmtCtx); err == nil { + if isTrue, err := c.Value.ToBool(p.SCtx().GetSessionVars().StmtCtx.TypeCtx); err == nil { if isTrue == 0 { // c is false opt.InnerIsFalse = true diff --git a/pkg/planner/core/planbuilder.go b/pkg/planner/core/planbuilder.go index 4e109bd6a5eb1..700890377be4c 100644 --- a/pkg/planner/core/planbuilder.go +++ b/pkg/planner/core/planbuilder.go @@ -3028,7 +3028,7 @@ func handleAnalyzeOptionsV2(opts []ast.AnalyzeOpt) (map[ast.AnalyzeOptionType]ui optMap[opt.Type] = v case ast.AnalyzeOptSampleRate: // Only Int/Float/decimal is accepted, so pass nil here is safe. - fVal, err := datumValue.ToFloat64(nil) + fVal, err := datumValue.ToFloat64(types.DefaultNoWarningContext) if err != nil { return nil, err } @@ -3091,7 +3091,7 @@ func handleAnalyzeOptions(opts []ast.AnalyzeOpt, statsVer int) (map[ast.AnalyzeO optMap[opt.Type] = v case ast.AnalyzeOptSampleRate: // Only Int/Float/decimal is accepted, so pass nil here is safe. - fVal, err := datumValue.ToFloat64(nil) + fVal, err := datumValue.ToFloat64(types.DefaultNoWarningContext) if err != nil { return nil, err } diff --git a/pkg/planner/core/rule_partition_processor.go b/pkg/planner/core/rule_partition_processor.go index ccf16b62920eb..99367e8faaf71 100644 --- a/pkg/planner/core/rule_partition_processor.go +++ b/pkg/planner/core/rule_partition_processor.go @@ -518,7 +518,7 @@ func newListPartitionPruner(ctx sessionctx.Context, tbl table.Table, partitionNa func (l *listPartitionPruner) locatePartition(cond expression.Expression) (tables.ListPartitionLocation, bool, error) { switch sf := cond.(type) { case *expression.Constant: - b, err := sf.Value.ToBool(l.ctx.GetSessionVars().StmtCtx) + b, err := sf.Value.ToBool(l.ctx.GetSessionVars().StmtCtx.TypeCtx) if err == nil && b == 0 { // A constant false expression. return nil, false, nil @@ -1297,7 +1297,7 @@ type rangePruner struct { func (p *rangePruner) partitionRangeForExpr(sctx sessionctx.Context, expr expression.Expression) (start int, end int, ok bool) { if constExpr, ok := expr.(*expression.Constant); ok { - if b, err := constExpr.Value.ToBool(sctx.GetSessionVars().StmtCtx); err == nil && b == 0 { + if b, err := constExpr.Value.ToBool(sctx.GetSessionVars().StmtCtx.TypeCtx); err == nil && b == 0 { // A constant false expression. return 0, 0, true } diff --git a/pkg/planner/core/rule_predicate_push_down.go b/pkg/planner/core/rule_predicate_push_down.go index dd2f82cd6ac09..bbbf233a3efcc 100644 --- a/pkg/planner/core/rule_predicate_push_down.go +++ b/pkg/planner/core/rule_predicate_push_down.go @@ -447,7 +447,7 @@ func isNullRejected(ctx sessionctx.Context, schema *expression.Schema, expr expr } if x.Value.IsNull() { return true - } else if isTrue, err := x.Value.ToBool(sc); err == nil && isTrue == 0 { + } else if isTrue, err := x.Value.ToBool(sc.TypeCtxOrDefault()); err == nil && isTrue == 0 { return true } } @@ -707,7 +707,7 @@ func Conds2TableDual(p LogicalPlan, conds []expression.Expression) LogicalPlan { if expression.MaybeOverOptimized4PlanCache(p.SCtx(), []expression.Expression{con}) { return nil } - if isTrue, err := con.Value.ToBool(sc); (err == nil && isTrue == 0) || con.Value.IsNull() { + if isTrue, err := con.Value.ToBool(sc.TypeCtxOrDefault()); (err == nil && isTrue == 0) || con.Value.IsNull() { dual := LogicalTableDual{}.Init(p.SCtx(), p.SelectBlockOffset()) dual.SetSchema(p.Schema()) return dual @@ -729,7 +729,7 @@ func DeleteTrueExprs(p LogicalPlan, conds []expression.Expression) []expression. continue } sc := p.SCtx().GetSessionVars().StmtCtx - if isTrue, err := con.Value.ToBool(sc); err == nil && isTrue == 1 { + if isTrue, err := con.Value.ToBool(sc.TypeCtx); err == nil && isTrue == 1 { continue } newConds = append(newConds, cond) diff --git a/pkg/server/internal/parse/parse.go b/pkg/server/internal/parse/parse.go index e55ca68eb6656..1b132a9d0d2eb 100644 --- a/pkg/server/internal/parse/parse.go +++ b/pkg/server/internal/parse/parse.go @@ -246,7 +246,7 @@ func ExecArgs(sc *stmtctx.StatementContext, params []expression.Expression, boun args[i] = types.NewDecimalDatum(nil) } else { var dec types.MyDecimal - err = sc.HandleTruncate(dec.FromString(v)) + err = sc.TypeCtx.HandleTruncate(dec.FromString(v)) if err != nil { return err } diff --git a/pkg/sessionctx/stmtctx/BUILD.bazel b/pkg/sessionctx/stmtctx/BUILD.bazel index 7595601c9dd19..1178d300448ea 100644 --- a/pkg/sessionctx/stmtctx/BUILD.bazel +++ b/pkg/sessionctx/stmtctx/BUILD.bazel @@ -7,7 +7,6 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/domain/resourcegroup", - "//pkg/errno", "//pkg/parser", "//pkg/parser/ast", "//pkg/parser/model", @@ -53,7 +52,6 @@ go_test( "@com_github_pingcap_errors//:errors", "@com_github_stretchr_testify//require", "@com_github_tikv_client_go_v2//util", - "@org_uber_go_atomic//:atomic", "@org_uber_go_goleak//:goleak", ], ) diff --git a/pkg/sessionctx/stmtctx/stmtctx.go b/pkg/sessionctx/stmtctx/stmtctx.go index dfe39df7db671..2f28d0a3fe3bc 100644 --- a/pkg/sessionctx/stmtctx/stmtctx.go +++ b/pkg/sessionctx/stmtctx/stmtctx.go @@ -29,7 +29,6 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/domain/resourcegroup" - "github.com/pingcap/tidb/pkg/errno" "github.com/pingcap/tidb/pkg/parser" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/model" @@ -177,13 +176,11 @@ type StatementContext struct { InCreateOrAlterStmt bool InSetSessionStatesStmt bool InPreparedPlanBuilding bool - IgnoreTruncate atomic2.Bool IgnoreZeroInDate bool NoZeroDate bool DupKeyAsWarning bool BadNullAsWarning bool DividedByZeroAsWarning bool - TruncateAsWarning bool OverflowAsWarning bool ErrAutoincReadFailedAsWarning bool InShowWarning bool @@ -1019,39 +1016,6 @@ func (sc *StatementContext) AppendExtraError(warn error) { } } -// HandleTruncate ignores or returns the error based on the StatementContext state. -func (sc *StatementContext) HandleTruncate(err error) error { - // TODO: At present we have not checked whether the error can be ignored or treated as warning. - // We will do that later, and then append WarnDataTruncated instead of the error itself. - if err == nil { - return nil - } - - err = errors.Cause(err) - if e, ok := err.(*errors.Error); !ok || - (e.Code() != errno.ErrTruncatedWrongValue && - e.Code() != errno.ErrDataTooLong && - e.Code() != errno.ErrTruncatedWrongValueForField && - e.Code() != errno.ErrWarnDataOutOfRange && - e.Code() != errno.ErrDataOutOfRange && - e.Code() != errno.ErrBadNumber && - e.Code() != errno.ErrWrongValueForType && - e.Code() != errno.ErrDatetimeFunctionOverflow && - e.Code() != errno.WarnDataTruncated && - e.Code() != errno.ErrIncorrectDatetimeValue) { - return err - } - - if sc.IgnoreTruncate.Load() { - return nil - } - if sc.TruncateAsWarning { - sc.AppendWarning(err) - return nil - } - return err -} - // HandleOverflow treats ErrOverflow as warnings or returns the error based on the StmtCtx.OverflowAsWarning state. func (sc *StatementContext) HandleOverflow(err error, warnErr error) error { if err == nil { @@ -1169,7 +1133,8 @@ func (sc *StatementContext) ShouldClipToZero() bool { // ShouldIgnoreOverflowError indicates whether we should ignore the error when type conversion overflows, // so we can leave it for further processing like clipping values less than 0 to 0 for unsigned integer types. func (sc *StatementContext) ShouldIgnoreOverflowError() bool { - if (sc.InInsertStmt && sc.TruncateAsWarning) || sc.InLoadDataStmt { + // TODO: move this function into `/types` pkg + if (sc.InInsertStmt && sc.TypeCtx.Flags().TruncateAsWarning()) || sc.InLoadDataStmt { return true } return false @@ -1185,9 +1150,9 @@ func (sc *StatementContext) PushDownFlags() uint64 { } else if sc.InSelectStmt { flags |= model.FlagInSelectStmt } - if sc.IgnoreTruncate.Load() { + if sc.TypeCtx.Flags().IgnoreTruncateErr() { flags |= model.FlagIgnoreTruncate - } else if sc.TruncateAsWarning { + } else if sc.TypeCtx.Flags().TruncateAsWarning() { flags |= model.FlagTruncateAsWarning } if sc.OverflowAsWarning { @@ -1252,15 +1217,21 @@ func (sc *StatementContext) CopTasksDetails() *CopTasksDetails { return d } -// SetFlagsFromPBFlag set the flag of StatementContext from a `tipb.SelectRequest.Flags`. -func (sc *StatementContext) SetFlagsFromPBFlag(flags uint64) { - sc.IgnoreTruncate.Store((flags & model.FlagIgnoreTruncate) > 0) - sc.TruncateAsWarning = (flags & model.FlagTruncateAsWarning) > 0 +// InitFromPBFlagAndTz set the flag and timezone of StatementContext from a `tipb.SelectRequest.Flags` and `*time.Location`. +func (sc *StatementContext) InitFromPBFlagAndTz(flags uint64, tz *time.Location) { sc.InInsertStmt = (flags & model.FlagInInsertStmt) > 0 sc.InSelectStmt = (flags & model.FlagInSelectStmt) > 0 + sc.InDeleteStmt = (flags & model.FlagInUpdateOrDeleteStmt) > 0 sc.OverflowAsWarning = (flags & model.FlagOverflowAsWarning) > 0 sc.IgnoreZeroInDate = (flags & model.FlagIgnoreZeroInDate) > 0 sc.DividedByZeroAsWarning = (flags & model.FlagDividedByZeroAsWarning) > 0 + sc.SetTimeZone(tz) + + typeFlags := sc.TypeCtx.Flags() + typeFlags = typeFlags. + WithIgnoreTruncateErr((flags & model.FlagIgnoreTruncate) > 0). + WithTruncateAsWarning((flags & model.FlagTruncateAsWarning) > 0) + sc.TypeCtx = typectx.NewContext(typeFlags, tz, sc.AppendWarning) } // GetLockWaitStartTime returns the statement pessimistic lock wait start time @@ -1399,6 +1370,18 @@ func (sc *StatementContext) RecordedStatsLoadStatusCnt() (cnt int) { return } +// TypeCtxOrDefault returns the reference to the `TypeCtx` inside the statement context. +// If the statement context is nil, it'll return a newly created default type context. +// **don't** use this function if you can make sure the `sc` is not nil. We should limit the usage of this function as +// little as possible. +func (sc *StatementContext) TypeCtxOrDefault() typectx.Context { + if sc != nil { + return sc.TypeCtx + } + + return typectx.DefaultNoWarningContext +} + // UsedStatsInfoForTable records stats that are used during query and their information. type UsedStatsInfoForTable struct { Name string diff --git a/pkg/sessionctx/stmtctx/stmtctx_test.go b/pkg/sessionctx/stmtctx/stmtctx_test.go index c9f986abf808c..1ae5c513ea7a8 100644 --- a/pkg/sessionctx/stmtctx/stmtctx_test.go +++ b/pkg/sessionctx/stmtctx/stmtctx_test.go @@ -34,7 +34,6 @@ import ( "github.com/pingcap/tidb/pkg/util/execdetails" "github.com/stretchr/testify/require" "github.com/tikv/client-go/v2/util" - "go.uber.org/atomic" ) func TestCopTasksDetails(t *testing.T) { @@ -95,19 +94,19 @@ func TestStatementContextPushDownFLags(t *testing.T) { {newStmtCtx(func(sc *stmtctx.StatementContext) { sc.InUpdateStmt = true }), 16}, {newStmtCtx(func(sc *stmtctx.StatementContext) { sc.InDeleteStmt = true }), 16}, {newStmtCtx(func(sc *stmtctx.StatementContext) { sc.InSelectStmt = true }), 32}, - {newStmtCtx(func(sc *stmtctx.StatementContext) { sc.IgnoreTruncate = *atomic.NewBool(true) }), 1}, - {newStmtCtx(func(sc *stmtctx.StatementContext) { sc.TruncateAsWarning = true }), 2}, + {newStmtCtx(func(sc *stmtctx.StatementContext) { sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) }), 1}, + {newStmtCtx(func(sc *stmtctx.StatementContext) { sc.SetTypeFlags(sc.TypeFlags().WithTruncateAsWarning(true)) }), 2}, {newStmtCtx(func(sc *stmtctx.StatementContext) { sc.OverflowAsWarning = true }), 64}, {newStmtCtx(func(sc *stmtctx.StatementContext) { sc.IgnoreZeroInDate = true }), 128}, {newStmtCtx(func(sc *stmtctx.StatementContext) { sc.DividedByZeroAsWarning = true }), 256}, {newStmtCtx(func(sc *stmtctx.StatementContext) { sc.InLoadDataStmt = true }), 1024}, {newStmtCtx(func(sc *stmtctx.StatementContext) { sc.InSelectStmt = true - sc.TruncateAsWarning = true + sc.SetTypeFlags(sc.TypeFlags().WithTruncateAsWarning(true)) }), 34}, {newStmtCtx(func(sc *stmtctx.StatementContext) { sc.DividedByZeroAsWarning = true - sc.IgnoreTruncate = *atomic.NewBool(true) + sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) }), 257}, {newStmtCtx(func(sc *stmtctx.StatementContext) { sc.InUpdateStmt = true diff --git a/pkg/store/mockstore/mockcopr/analyze.go b/pkg/store/mockstore/mockcopr/analyze.go index d4b19736b9ad2..87cabb184163b 100644 --- a/pkg/store/mockstore/mockcopr/analyze.go +++ b/pkg/store/mockstore/mockcopr/analyze.go @@ -83,7 +83,12 @@ func (h coprHandler) handleAnalyzeIndexReq(req *coprocessor.Request, analyzeReq execDetail: new(execDetail), hdStatus: tablecodec.HandleNotNeeded, } - statsBuilder := statistics.NewSortedBuilder(flagsToStatementContext(analyzeReq.Flags), analyzeReq.IdxReq.BucketSize, 0, types.NewFieldType(mysql.TypeBlob), statistics.Version1) + + tz, err := timeutil.ConstructTimeZone("", int(analyzeReq.TimeZoneOffset)) + if err != nil { + return nil, errors.Trace(err) + } + statsBuilder := statistics.NewSortedBuilder(flagsAndTzToStatementContext(analyzeReq.Flags, tz), analyzeReq.IdxReq.BucketSize, 0, types.NewFieldType(mysql.TypeBlob), statistics.Version1) var cms *statistics.CMSketch if analyzeReq.IdxReq.CmsketchDepth != nil && analyzeReq.IdxReq.CmsketchWidth != nil { cms = statistics.NewCMSketch(*analyzeReq.IdxReq.CmsketchDepth, *analyzeReq.IdxReq.CmsketchWidth) @@ -128,12 +133,12 @@ type analyzeColumnsExec struct { } func (h coprHandler) handleAnalyzeColumnsReq(req *coprocessor.Request, analyzeReq *tipb.AnalyzeReq) (_ *coprocessor.Response, err error) { - sc := flagsToStatementContext(analyzeReq.Flags) tz, err := timeutil.ConstructTimeZone("", int(analyzeReq.TimeZoneOffset)) if err != nil { return nil, errors.Trace(err) } - sc.SetTimeZone(tz) + + sc := flagsAndTzToStatementContext(analyzeReq.Flags, tz) evalCtx := &evalContext{sc: sc} columns := analyzeReq.ColReq.ColumnsInfo diff --git a/pkg/store/mockstore/mockcopr/cop_handler_dag.go b/pkg/store/mockstore/mockcopr/cop_handler_dag.go index ce459c4b44bb0..1b176602ce3fd 100644 --- a/pkg/store/mockstore/mockcopr/cop_handler_dag.go +++ b/pkg/store/mockstore/mockcopr/cop_handler_dag.go @@ -103,12 +103,11 @@ func (h coprHandler) buildDAGExecutor(req *coprocessor.Request) (*dagContext, ex return nil, nil, nil, errors.Trace(err) } - sc := flagsToStatementContext(dagReq.Flags) tz, err := timeutil.ConstructTimeZone(dagReq.TimeZoneName, int(dagReq.TimeZoneOffset)) if err != nil { return nil, nil, nil, errors.Trace(err) } - sc.SetTimeZone(tz) + sc := flagsAndTzToStatementContext(dagReq.Flags, tz) ctx := &dagContext{ dagReq: dagReq, @@ -128,13 +127,6 @@ func (h coprHandler) buildDAGExecutor(req *coprocessor.Request) (*dagContext, ex return ctx, e, dagReq, err } -// constructTimeZone constructs timezone by name first. When the timezone name -// is set, the daylight saving problem must be considered. Otherwise the -// timezone offset in seconds east of UTC is used to constructed the timezone. -func constructTimeZone(name string, offset int) (*time.Location, error) { - return timeutil.ConstructTimeZone(name, offset) -} - func (h coprHandler) buildExec(ctx *dagContext, curr *tipb.Executor) (executor, *tipb.Executor, error) { var currExec executor var err error @@ -466,18 +458,10 @@ func (e *evalContext) decodeRelatedColumnVals(relatedColOffsets []int, value [][ return nil } -// flagsToStatementContext creates a StatementContext from a `tipb.SelectRequest.Flags`. -func flagsToStatementContext(flags uint64) *stmtctx.StatementContext { - sc := stmtctx.NewStmtCtx() - sc.IgnoreTruncate.Store((flags & model.FlagIgnoreTruncate) > 0) - sc.TruncateAsWarning = (flags & model.FlagTruncateAsWarning) > 0 - sc.InInsertStmt = (flags & model.FlagInInsertStmt) > 0 - sc.InSelectStmt = (flags & model.FlagInSelectStmt) > 0 - sc.InDeleteStmt = (flags & model.FlagInUpdateOrDeleteStmt) > 0 - sc.OverflowAsWarning = (flags & model.FlagOverflowAsWarning) > 0 - sc.IgnoreZeroInDate = (flags & model.FlagIgnoreZeroInDate) > 0 - sc.DividedByZeroAsWarning = (flags & model.FlagDividedByZeroAsWarning) > 0 - // TODO set FlagInSetOprStmt, +// flagsAndTzToStatementContext creates a StatementContext from a `tipb.SelectRequest.Flags`. +func flagsAndTzToStatementContext(flags uint64, tz *time.Location) *stmtctx.StatementContext { + sc := new(stmtctx.StatementContext) + sc.InitFromPBFlagAndTz(flags, tz) return sc } diff --git a/pkg/store/mockstore/mockcopr/executor.go b/pkg/store/mockstore/mockcopr/executor.go index 9b86ca5bda76a..78d8687a17355 100644 --- a/pkg/store/mockstore/mockcopr/executor.go +++ b/pkg/store/mockstore/mockcopr/executor.go @@ -414,7 +414,7 @@ func evalBool(exprs []expression.Expression, row []types.Datum, ctx *stmtctx.Sta return false, nil } - isBool, err := data.ToBool(ctx) + isBool, err := data.ToBool(ctx.TypeCtx) isBool, err = expression.HandleOverflowOnSelection(ctx, isBool, err) if err != nil { return false, errors.Trace(err) diff --git a/pkg/store/mockstore/unistore/cophandler/BUILD.bazel b/pkg/store/mockstore/unistore/cophandler/BUILD.bazel index 16c2131d60083..2a6e691a3cbaa 100644 --- a/pkg/store/mockstore/unistore/cophandler/BUILD.bazel +++ b/pkg/store/mockstore/unistore/cophandler/BUILD.bazel @@ -76,6 +76,7 @@ go_test( "//pkg/util/codec", "//pkg/util/collate", "//pkg/util/rowcodec", + "//pkg/util/timeutil", "@com_github_pingcap_badger//:badger", "@com_github_pingcap_badger//y", "@com_github_pingcap_kvproto//pkg/coprocessor", diff --git a/pkg/store/mockstore/unistore/cophandler/analyze.go b/pkg/store/mockstore/unistore/cophandler/analyze.go index 770507c5fd35e..2b58c57d4ecb2 100644 --- a/pkg/store/mockstore/unistore/cophandler/analyze.go +++ b/pkg/store/mockstore/unistore/cophandler/analyze.go @@ -88,7 +88,9 @@ func handleAnalyzeIndexReq(dbReader *dbreader.DBReader, rans []kv.KeyRange, anal if analyzeReq.IdxReq.Version != nil { statsVer = *analyzeReq.IdxReq.Version } - sctx := flagsToStatementContext(analyzeReq.Flags) + + tz := time.FixedZone("UTC", int(analyzeReq.TimeZoneOffset)) + sctx := flagsAndTzToStatementContext(analyzeReq.Flags, tz) processor := &analyzeIndexProcessor{ sctx: sctx, colLen: int(analyzeReq.IdxReq.NumColumns), @@ -140,9 +142,11 @@ func handleAnalyzeCommonHandleReq(dbReader *dbreader.DBReader, rans []kv.KeyRang if analyzeReq.IdxReq.Version != nil { statsVer = int(*analyzeReq.IdxReq.Version) } + + tz := time.FixedZone("UTC", int(analyzeReq.TimeZoneOffset)) processor := &analyzeCommonHandleProcessor{ colLen: int(analyzeReq.IdxReq.NumColumns), - statsBuilder: statistics.NewSortedBuilder(flagsToStatementContext(analyzeReq.Flags), analyzeReq.IdxReq.BucketSize, 0, types.NewFieldType(mysql.TypeBlob), statsVer), + statsBuilder: statistics.NewSortedBuilder(flagsAndTzToStatementContext(analyzeReq.Flags, tz), analyzeReq.IdxReq.BucketSize, 0, types.NewFieldType(mysql.TypeBlob), statsVer), } if analyzeReq.IdxReq.CmsketchDepth != nil && analyzeReq.IdxReq.CmsketchWidth != nil { processor.cms = statistics.NewCMSketch(*analyzeReq.IdxReq.CmsketchDepth, *analyzeReq.IdxReq.CmsketchWidth) @@ -266,8 +270,8 @@ type analyzeColumnsExec struct { } func buildBaseAnalyzeColumnsExec(dbReader *dbreader.DBReader, rans []kv.KeyRange, analyzeReq *tipb.AnalyzeReq, startTS uint64) (*analyzeColumnsExec, *statistics.SampleBuilder, int64, error) { - sc := flagsToStatementContext(analyzeReq.Flags) - sc.SetTimeZone(time.FixedZone("UTC", int(analyzeReq.TimeZoneOffset))) + tz := time.FixedZone("UTC", int(analyzeReq.TimeZoneOffset)) + sc := flagsAndTzToStatementContext(analyzeReq.Flags, tz) evalCtx := &evalContext{sc: sc} columns := analyzeReq.ColReq.ColumnsInfo evalCtx.setColumnInfo(columns) @@ -372,8 +376,8 @@ func handleAnalyzeFullSamplingReq( analyzeReq *tipb.AnalyzeReq, startTS uint64, ) (*coprocessor.Response, error) { - sc := flagsToStatementContext(analyzeReq.Flags) - sc.SetTimeZone(time.FixedZone("UTC", int(analyzeReq.TimeZoneOffset))) + tz := time.FixedZone("UTC", int(analyzeReq.TimeZoneOffset)) + sc := flagsAndTzToStatementContext(analyzeReq.Flags, tz) evalCtx := &evalContext{sc: sc} columns := analyzeReq.ColReq.ColumnsInfo evalCtx.setColumnInfo(columns) @@ -527,7 +531,8 @@ func handleAnalyzeMixedReq(dbReader *dbreader.DBReader, rans []kv.KeyRange, anal if err != nil { return nil, err } - sctx := flagsToStatementContext(analyzeReq.Flags) + tz := time.FixedZone("UTC", int(analyzeReq.TimeZoneOffset)) + sctx := flagsAndTzToStatementContext(analyzeReq.Flags, tz) e := &analyzeMixedExec{ sctx: sctx, analyzeColumnsExec: *colExec, diff --git a/pkg/store/mockstore/unistore/cophandler/closure_exec.go b/pkg/store/mockstore/unistore/cophandler/closure_exec.go index 1ba3eaf9d51aa..94eda32c085ff 100644 --- a/pkg/store/mockstore/unistore/cophandler/closure_exec.go +++ b/pkg/store/mockstore/unistore/cophandler/closure_exec.go @@ -799,7 +799,7 @@ func (e *closureExecutor) processSelection(needCollectDetail bool) (gotRow bool, if d.IsNull() { gotRow = false } else { - isTrue, err := d.ToBool(e.sc) + isTrue, err := d.ToBool(e.sc.TypeCtx) isTrue, err = expression.HandleOverflowOnSelection(e.sc, isTrue, err) if err != nil { return false, errors.Trace(err) diff --git a/pkg/store/mockstore/unistore/cophandler/cop_handler.go b/pkg/store/mockstore/unistore/cophandler/cop_handler.go index c0e9d67842630..43daa0716978d 100644 --- a/pkg/store/mockstore/unistore/cophandler/cop_handler.go +++ b/pkg/store/mockstore/unistore/cophandler/cop_handler.go @@ -294,19 +294,19 @@ func buildDAG(reader *dbreader.DBReader, lockStore *lockstore.MemStore, req *cop if err != nil { return nil, nil, errors.Trace(err) } - sc := flagsToStatementContext(dagReq.Flags) + var tz *time.Location switch dagReq.TimeZoneName { case "": - sc.SetTimeZone(time.FixedZone("UTC", int(dagReq.TimeZoneOffset))) + tz = time.FixedZone("UTC", int(dagReq.TimeZoneOffset)) case "System": - sc.SetTimeZone(time.Local) + tz = time.Local default: - tz, err := time.LoadLocation(dagReq.TimeZoneName) + tz, err = time.LoadLocation(dagReq.TimeZoneName) if err != nil { return nil, nil, errors.Trace(err) } - sc.SetTimeZone(tz) } + sc := flagsAndTzToStatementContext(dagReq.Flags, tz) ctx := &dagContext{ evalContext: &evalContext{sc: sc}, dbReader: reader, @@ -421,17 +421,10 @@ func newRowDecoder(columnInfos []*tipb.ColumnInfo, fieldTps []*types.FieldType, return rowcodec.NewChunkDecoder(cols, pkCols, def, timeZone), nil } -// flagsToStatementContext creates a StatementContext from a `tipb.SelectRequest.Flags`. -func flagsToStatementContext(flags uint64) *stmtctx.StatementContext { - sc := stmtctx.NewStmtCtx() - sc.IgnoreTruncate.Store((flags & model.FlagIgnoreTruncate) > 0) - sc.TruncateAsWarning = (flags & model.FlagTruncateAsWarning) > 0 - sc.InInsertStmt = (flags & model.FlagInInsertStmt) > 0 - sc.InSelectStmt = (flags & model.FlagInSelectStmt) > 0 - sc.InDeleteStmt = (flags & model.FlagInUpdateOrDeleteStmt) > 0 - sc.OverflowAsWarning = (flags & model.FlagOverflowAsWarning) > 0 - sc.IgnoreZeroInDate = (flags & model.FlagIgnoreZeroInDate) > 0 - sc.DividedByZeroAsWarning = (flags & model.FlagDividedByZeroAsWarning) > 0 +// flagsAndTzToStatementContext creates a StatementContext from a `tipb.SelectRequest.Flags`. +func flagsAndTzToStatementContext(flags uint64, tz *time.Location) *stmtctx.StatementContext { + sc := new(stmtctx.StatementContext) + sc.InitFromPBFlagAndTz(flags, tz) return sc } diff --git a/pkg/store/mockstore/unistore/cophandler/cop_handler_test.go b/pkg/store/mockstore/unistore/cophandler/cop_handler_test.go index a3e7ff46811cd..41eb69bb3e283 100644 --- a/pkg/store/mockstore/unistore/cophandler/cop_handler_test.go +++ b/pkg/store/mockstore/unistore/cophandler/cop_handler_test.go @@ -38,6 +38,7 @@ import ( "github.com/pingcap/tidb/pkg/util/codec" "github.com/pingcap/tidb/pkg/util/collate" "github.com/pingcap/tidb/pkg/util/rowcodec" + "github.com/pingcap/tidb/pkg/util/timeutil" "github.com/pingcap/tipb/go-tipb" "github.com/stretchr/testify/require" ) @@ -177,8 +178,10 @@ func isPrefixNext(key []byte, expected []byte) bool { } // return a dag context according to dagReq and key ranges. -func newDagContext(store *testStore, keyRanges []kv.KeyRange, dagReq *tipb.DAGRequest, startTs uint64) *dagContext { - sc := flagsToStatementContext(dagReq.Flags) +func newDagContext(t require.TestingT, store *testStore, keyRanges []kv.KeyRange, dagReq *tipb.DAGRequest, startTs uint64) *dagContext { + tz, err := timeutil.ConstructTimeZone(dagReq.TimeZoneName, int(dagReq.TimeZoneOffset)) + require.NoError(t, err) + sc := flagsAndTzToStatementContext(dagReq.Flags, tz) txn := store.db.NewTransaction(false) dagCtx := &dagContext{ evalContext: &evalContext{sc: sc}, @@ -323,7 +326,7 @@ func TestPointGet(t *testing.T) { addTableScan(data.colInfos, tableID). setOutputOffsets([]uint32{0, 1}). build() - dagCtx := newDagContext(store, []kv.KeyRange{getTestPointRange(tableID, handle)}, + dagCtx := newDagContext(t, store, []kv.KeyRange{getTestPointRange(tableID, handle)}, dagRequest, dagRequestStartTs) chunks, rowCount, err := buildExecutorsAndExecute(dagCtx, dagRequest) require.Len(t, chunks, 0) @@ -337,7 +340,7 @@ func TestPointGet(t *testing.T) { addTableScan(data.colInfos, tableID). setOutputOffsets([]uint32{0, 1}). build() - dagCtx = newDagContext(store, []kv.KeyRange{getTestPointRange(tableID, handle)}, + dagCtx = newDagContext(t, store, []kv.KeyRange{getTestPointRange(tableID, handle)}, dagRequest, dagRequestStartTs) chunks, rowCount, err = buildExecutorsAndExecute(dagCtx, dagRequest) require.NoError(t, err) @@ -378,7 +381,7 @@ func TestClosureExecutor(t *testing.T) { setOutputOffsets([]uint32{0, 1}). build() - dagCtx := newDagContext(store, []kv.KeyRange{getTestPointRange(tableID, 1)}, + dagCtx := newDagContext(t, store, []kv.KeyRange{getTestPointRange(tableID, 1)}, dagRequest, dagRequestStartTs) _, rowCount, err := buildExecutorsAndExecute(dagCtx, dagRequest) require.NoError(t, err) @@ -407,7 +410,7 @@ func TestMppExecutor(t *testing.T) { setCollectRangeCounts(true). build() - dagCtx := newDagContext(store, []kv.KeyRange{getTestPointRange(tableID, 1)}, + dagCtx := newDagContext(t, store, []kv.KeyRange{getTestPointRange(tableID, 1)}, dagRequest, dagRequestStartTs) _, _, _, rowCount, _, err := buildAndRunMPPExecutor(dagCtx, dagRequest, 0) require.Equal(t, rowCount[0], int64(1)) @@ -576,6 +579,7 @@ func BenchmarkExecutors(b *testing.B) { build() dagCtx = newDagContext( + b, store, []kv.KeyRange{ { diff --git a/pkg/store/mockstore/unistore/cophandler/mpp.go b/pkg/store/mockstore/unistore/cophandler/mpp.go index b4f8ebfaae5bf..876301ad53972 100644 --- a/pkg/store/mockstore/unistore/cophandler/mpp.go +++ b/pkg/store/mockstore/unistore/cophandler/mpp.go @@ -34,6 +34,7 @@ import ( "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tidb/pkg/util/rowcodec" + "github.com/pingcap/tidb/pkg/util/timeutil" "github.com/pingcap/tipb/go-tipb" "go.uber.org/atomic" ) @@ -573,10 +574,11 @@ func HandleMPPDAGReq(dbReader *dbreader.DBReader, req *coprocessor.Request, mppC startTS: req.StartTs, keyRanges: req.Ranges, } + tz, err := timeutil.ConstructTimeZone(dagReq.TimeZoneName, int(dagReq.TimeZoneOffset)) builder := mppExecBuilder{ dbReader: dbReader, mppCtx: mppCtx, - sc: flagsToStatementContext(dagReq.Flags), + sc: flagsAndTzToStatementContext(dagReq.Flags, tz), dagReq: dagReq, dagCtx: dagCtx, } diff --git a/pkg/store/mockstore/unistore/cophandler/mpp_exec.go b/pkg/store/mockstore/unistore/cophandler/mpp_exec.go index 6ec94bcd43f87..12fc125b295bd 100644 --- a/pkg/store/mockstore/unistore/cophandler/mpp_exec.go +++ b/pkg/store/mockstore/unistore/cophandler/mpp_exec.go @@ -1133,7 +1133,7 @@ func (e *selExec) next() (*chunk.Chunk, error) { if d.IsNull() { passCheck = false } else { - isBool, err := d.ToBool(e.sc) + isBool, err := d.ToBool(e.sc.TypeCtx) if err != nil { return nil, errors.Trace(err) } diff --git a/pkg/table/column.go b/pkg/table/column.go index 15be521014ca8..d021732152b96 100644 --- a/pkg/table/column.go +++ b/pkg/table/column.go @@ -322,7 +322,7 @@ func CastValue(ctx sessionctx.Context, val types.Datum, col *model.ColumnInfo, r zap.Uint64("conn", ctx.GetSessionVars().ConnectionID), zap.Error(err)) } - err = sc.HandleTruncate(err) + err = sc.TypeCtx.HandleTruncate(err) err = sc.HandleOverflow(err, err) if forceIgnoreTruncate { diff --git a/pkg/tablecodec/tablecodec.go b/pkg/tablecodec/tablecodec.go index 8cfd7cfb112ce..1c3462b764c47 100644 --- a/pkg/tablecodec/tablecodec.go +++ b/pkg/tablecodec/tablecodec.go @@ -398,7 +398,7 @@ func flatten(sc *stmtctx.StatementContext, data types.Datum, ret *types.Datum) e return nil case types.KindBinaryLiteral, types.KindMysqlBit: // We don't need to handle errors here since the literal is ensured to be able to store in uint64 in convertToMysqlBit. - val, err := data.GetBinaryLiteral().ToInt(sc) + val, err := data.GetBinaryLiteral().ToInt(sc.TypeCtx) if err != nil { return errors.Trace(err) } diff --git a/pkg/types/binary_literal.go b/pkg/types/binary_literal.go index 5b36b1a013ab1..adb593f24c834 100644 --- a/pkg/types/binary_literal.go +++ b/pkg/types/binary_literal.go @@ -24,7 +24,6 @@ import ( "strings" "github.com/pingcap/errors" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" ) // BinaryLiteral is the internal type for storing bit / hex literal type. @@ -102,7 +101,7 @@ func (b BinaryLiteral) ToBitLiteralString(trimLeadingZero bool) string { } // ToInt returns the int value for the literal. -func (b BinaryLiteral) ToInt(sc *stmtctx.StatementContext) (uint64, error) { +func (b BinaryLiteral) ToInt(ctx Context) (uint64, error) { buf := trimLeadingZeroBytes(b) length := len(buf) if length == 0 { @@ -110,9 +109,7 @@ func (b BinaryLiteral) ToInt(sc *stmtctx.StatementContext) (uint64, error) { } if length > 8 { var err = ErrTruncatedWrongVal.FastGenByArgs("BINARY", b) - if sc != nil { - err = sc.HandleTruncate(err) - } + err = ctx.HandleTruncate(err) return math.MaxUint64, err } // Note: the byte-order is BigEndian. diff --git a/pkg/types/binary_literal_test.go b/pkg/types/binary_literal_test.go index a404818ab6478..44ea4ad7d070f 100644 --- a/pkg/types/binary_literal_test.go +++ b/pkg/types/binary_literal_test.go @@ -18,7 +18,6 @@ import ( "fmt" "testing" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/stretchr/testify/require" ) @@ -207,11 +206,11 @@ func TestBinaryLiteral(t *testing.T) { {"0x1010ffff8080ff12", 0x1010ffff8080ff12, false}, {"0x1010ffff8080ff12ff", 0xffffffffffffffff, true}, } - sc := stmtctx.NewStmtCtx() + ctx := DefaultNoWarningContext for _, item := range tbl { hex, err := ParseHexStr(item.Input) require.NoError(t, err) - intValue, err := hex.ToInt(sc) + intValue, err := hex.ToInt(ctx) if item.HasError { require.Error(t, err) } else { diff --git a/pkg/types/compare_test.go b/pkg/types/compare_test.go index 368d8c4914e30..b5d76497ec05e 100644 --- a/pkg/types/compare_test.go +++ b/pkg/types/compare_test.go @@ -147,7 +147,7 @@ func TestCompare(t *testing.T) { func compareForTest(a, b interface{}) (int, error) { sc := stmtctx.NewStmtCtx() - sc.IgnoreTruncate.Store(true) + sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) aDatum := NewDatum(a) bDatum := NewDatum(b) return aDatum.Compare(sc, &bDatum, collate.GetBinaryCollator()) @@ -169,7 +169,7 @@ func TestCompareDatum(t *testing.T) { {MinNotNullDatum(), MaxValueDatum(), -1}, } sc := stmtctx.NewStmtCtx() - sc.IgnoreTruncate.Store(true) + sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) for i, tt := range cmpTbl { ret, err := tt.lhs.Compare(sc, &tt.rhs, collate.GetBinaryCollator()) require.NoError(t, err) diff --git a/pkg/types/context.go b/pkg/types/context.go index 854110c5cca90..0182804aff83f 100644 --- a/pkg/types/context.go +++ b/pkg/types/context.go @@ -14,7 +14,9 @@ package types -import "github.com/pingcap/tidb/pkg/types/context" +import ( + "github.com/pingcap/tidb/pkg/types/context" +) // TODO: move a contents in `types/context/context.go` to this file after refactor finished. // Because package `types` has a dependency on `sessionctx/stmtctx`, we need a separate package `type/context` to define @@ -31,3 +33,6 @@ const StrictFlags = context.StrictFlags // NewContext creates a new `Context` var NewContext = context.NewContext + +// DefaultNoWarningContext is an alias of `DefaultNoWarningContext` +var DefaultNoWarningContext = context.DefaultNoWarningContext diff --git a/pkg/types/context/BUILD.bazel b/pkg/types/context/BUILD.bazel index f000bb3415639..5d9a1d17e6e49 100644 --- a/pkg/types/context/BUILD.bazel +++ b/pkg/types/context/BUILD.bazel @@ -2,10 +2,17 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "context", - srcs = ["context.go"], + srcs = [ + "context.go", + "truncate.go", + ], importpath = "github.com/pingcap/tidb/pkg/types/context", visibility = ["//visibility:public"], - deps = ["//pkg/util/intest"], + deps = [ + "//pkg/errno", + "//pkg/util/intest", + "@com_github_pingcap_errors//:errors", + ], ) go_test( diff --git a/pkg/types/context/context.go b/pkg/types/context/context.go index ec7e17438ed48..41e62ad1482c3 100644 --- a/pkg/types/context/context.go +++ b/pkg/types/context/context.go @@ -104,6 +104,32 @@ func (f Flags) WithSkipUTF8MB4Check(skip bool) Flags { return f &^ FlagSkipUTF8MB4Check } +// IgnoreTruncateErr indicates whether the flag `FlagIgnoreTruncateErr` is set +func (f Flags) IgnoreTruncateErr() bool { + return f&FlagIgnoreTruncateErr != 0 +} + +// WithIgnoreTruncateErr returns a new flags with `FlagIgnoreTruncateErr` set/unset according to the skip parameter +func (f Flags) WithIgnoreTruncateErr(ignore bool) Flags { + if ignore { + return f | FlagIgnoreTruncateErr + } + return f &^ FlagIgnoreTruncateErr +} + +// TruncateAsWarning indicates whether the flag `FlagTruncateAsWarning` is set +func (f Flags) TruncateAsWarning() bool { + return f&FlagTruncateAsWarning != 0 +} + +// WithTruncateAsWarning returns a new flags with `FlagTruncateAsWarning` set/unset according to the skip parameter +func (f Flags) WithTruncateAsWarning(warn bool) Flags { + if warn { + return f | FlagTruncateAsWarning + } + return f &^ FlagTruncateAsWarning +} + // Context provides the information when converting between different types. type Context struct { flags Flags @@ -164,3 +190,8 @@ func (c *Context) AppendWarning(err error) { func (c *Context) AppendWarningFunc() func(err error) { return c.appendWarningFn } + +// DefaultNoWarningContext is the context without any special configuration +var DefaultNoWarningContext = NewContext(StrictFlags, time.UTC, func(_ error) { + // the error is ignored +}) diff --git a/pkg/types/context/truncate.go b/pkg/types/context/truncate.go new file mode 100644 index 0000000000000..271c8ed4b1d16 --- /dev/null +++ b/pkg/types/context/truncate.go @@ -0,0 +1,53 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "github.com/pingcap/errors" + "github.com/pingcap/tidb/pkg/errno" +) + +// HandleTruncate ignores or returns the error based on the Context state. +func (c *Context) HandleTruncate(err error) error { + // TODO: At present we have not checked whether the error can be ignored or treated as warning. + // We will do that later, and then append WarnDataTruncated instead of the error itself. + if err == nil { + return nil + } + + err = errors.Cause(err) + if e, ok := err.(*errors.Error); !ok || + (e.Code() != errno.ErrTruncatedWrongValue && + e.Code() != errno.ErrDataTooLong && + e.Code() != errno.ErrTruncatedWrongValueForField && + e.Code() != errno.ErrWarnDataOutOfRange && + e.Code() != errno.ErrDataOutOfRange && + e.Code() != errno.ErrBadNumber && + e.Code() != errno.ErrWrongValueForType && + e.Code() != errno.ErrDatetimeFunctionOverflow && + e.Code() != errno.WarnDataTruncated && + e.Code() != errno.ErrIncorrectDatetimeValue) { + return err + } + + if c.Flags().IgnoreTruncateErr() { + return nil + } + if c.Flags().TruncateAsWarning() { + c.AppendWarning(err) + return nil + } + return err +} diff --git a/pkg/types/convert.go b/pkg/types/convert.go index 874b6c788f392..816c494ec4e70 100644 --- a/pkg/types/convert.go +++ b/pkg/types/convert.go @@ -276,9 +276,9 @@ func ConvertDecimalToUint(sc *stmtctx.StatementContext, d *MyDecimal, upperBound } // StrToInt converts a string to an integer at the best-effort. -func StrToInt(sc *stmtctx.StatementContext, str string, isFuncCast bool) (int64, error) { +func StrToInt(ctx Context, str string, isFuncCast bool) (int64, error) { str = strings.TrimSpace(str) - validPrefix, err := getValidIntPrefix(sc, str, isFuncCast) + validPrefix, err := getValidIntPrefix(ctx, str, isFuncCast) iVal, err1 := strconv.ParseInt(validPrefix, 10, 64) if err1 != nil { return iVal, ErrOverflow.GenWithStackByArgs("BIGINT", validPrefix) @@ -287,9 +287,9 @@ func StrToInt(sc *stmtctx.StatementContext, str string, isFuncCast bool) (int64, } // StrToUint converts a string to an unsigned integer at the best-effort. -func StrToUint(sc *stmtctx.StatementContext, str string, isFuncCast bool) (uint64, error) { +func StrToUint(ctx Context, str string, isFuncCast bool) (uint64, error) { str = strings.TrimSpace(str) - validPrefix, err := getValidIntPrefix(sc, str, isFuncCast) + validPrefix, err := getValidIntPrefix(ctx, str, isFuncCast) uVal := uint64(0) hasParseErr := false @@ -344,7 +344,7 @@ func StrToDuration(sc *stmtctx.StatementContext, str string, fsp int) (d Duratio d, _, err = ParseDuration(sc, str, fsp) if ErrTruncatedWrongVal.Equal(err) { - err = sc.HandleTruncate(err) + err = sc.TypeCtx.HandleTruncate(err) } return d, t, true, errors.Trace(err) } @@ -382,13 +382,13 @@ func NumberToDuration(number int64, fsp int) (Duration, error) { } // getValidIntPrefix gets prefix of the string which can be successfully parsed as int. -func getValidIntPrefix(sc *stmtctx.StatementContext, str string, isFuncCast bool) (string, error) { +func getValidIntPrefix(ctx Context, str string, isFuncCast bool) (string, error) { if !isFuncCast { - floatPrefix, err := getValidFloatPrefix(sc, str, isFuncCast) + floatPrefix, err := getValidFloatPrefix(ctx, str, isFuncCast) if err != nil { return floatPrefix, errors.Trace(err) } - return floatStrToIntStr(sc, floatPrefix, str) + return floatStrToIntStr(ctx, floatPrefix, str) } validLen := 0 @@ -411,7 +411,7 @@ func getValidIntPrefix(sc *stmtctx.StatementContext, str string, isFuncCast bool valid = "0" } if validLen == 0 || validLen != len(str) { - return valid, errors.Trace(sc.HandleTruncate(ErrTruncatedWrongVal.GenWithStackByArgs("INTEGER", str))) + return valid, errors.Trace(ctx.HandleTruncate(ErrTruncatedWrongVal.GenWithStackByArgs("INTEGER", str))) } return valid, nil } @@ -451,7 +451,7 @@ func roundIntStr(numNextDot byte, intStr string) string { // // This func will find serious overflow such as the len of intStr > 20 (without prefix `+/-`) // however, it will not check whether the intStr overflow BIGINT. -func floatStrToIntStr(sc *stmtctx.StatementContext, validFloat string, oriStr string) (intStr string, _ error) { +func floatStrToIntStr(ctx Context, validFloat string, oriStr string) (intStr string, _ error) { var dotIdx = -1 var eIdx = -1 for i := 0; i < len(validFloat); i++ { @@ -507,7 +507,7 @@ func floatStrToIntStr(sc *stmtctx.StatementContext, validFloat string, oriStr st // MaxUint64 has 20 decimal digits. // And the intCnt may contain the len of `+/-`, // so I use 21 here as the early detection. - sc.AppendWarning(ErrOverflow.GenWithStackByArgs("BIGINT", oriStr)) + ctx.AppendWarning(ErrOverflow.GenWithStackByArgs("BIGINT", oriStr)) return validFloat[:eIdx], nil } if intCnt <= 0 { @@ -541,15 +541,15 @@ func floatStrToIntStr(sc *stmtctx.StatementContext, validFloat string, oriStr st } // StrToFloat converts a string to a float64 at the best-effort. -func StrToFloat(sc *stmtctx.StatementContext, str string, isFuncCast bool) (float64, error) { +func StrToFloat(ctx Context, str string, isFuncCast bool) (float64, error) { str = strings.TrimSpace(str) - validStr, err := getValidFloatPrefix(sc, str, isFuncCast) + validStr, err := getValidFloatPrefix(ctx, str, isFuncCast) f, err1 := strconv.ParseFloat(validStr, 64) if err1 != nil { if err2, ok := err1.(*strconv.NumError); ok { // value will truncate to MAX/MIN if out of range. if err2.Err == strconv.ErrRange { - err1 = sc.HandleTruncate(ErrTruncatedWrongVal.GenWithStackByArgs("DOUBLE", str)) + err1 = ctx.HandleTruncate(ErrTruncatedWrongVal.GenWithStackByArgs("DOUBLE", str)) if math.IsInf(f, 1) { f = math.MaxFloat64 } else if math.IsInf(f, -1) { @@ -571,13 +571,13 @@ func ConvertJSONToInt64(sc *stmtctx.StatementContext, j BinaryJSON, unsigned boo func ConvertJSONToInt(sc *stmtctx.StatementContext, j BinaryJSON, unsigned bool, tp byte) (int64, error) { switch j.TypeCode { case JSONTypeCodeObject, JSONTypeCodeArray, JSONTypeCodeOpaque, JSONTypeCodeDate, JSONTypeCodeDatetime, JSONTypeCodeTimestamp, JSONTypeCodeDuration: - return 0, sc.HandleTruncate(ErrTruncatedWrongVal.GenWithStackByArgs("INTEGER", j.String())) + return 0, sc.TypeCtx.HandleTruncate(ErrTruncatedWrongVal.GenWithStackByArgs("INTEGER", j.String())) case JSONTypeCodeLiteral: switch j.Value[0] { case JSONLiteralFalse: return 0, nil case JSONLiteralNil: - return 0, sc.HandleTruncate(ErrTruncatedWrongVal.GenWithStackByArgs("INTEGER", j.String())) + return 0, sc.TypeCtx.HandleTruncate(ErrTruncatedWrongVal.GenWithStackByArgs("INTEGER", j.String())) default: return 1, nil } @@ -618,26 +618,26 @@ func ConvertJSONToInt(sc *stmtctx.StatementContext, j BinaryJSON, unsigned bool, case JSONTypeCodeString: str := string(hack.String(j.GetString())) if !unsigned { - r, e := StrToInt(sc, str, false) + r, e := StrToInt(sc.TypeCtxOrDefault(), str, false) return r, sc.HandleOverflow(e, e) } - u, err := StrToUint(sc, str, false) + u, err := StrToUint(sc.TypeCtxOrDefault(), str, false) return int64(u), sc.HandleOverflow(err, err) } return 0, errors.New("Unknown type code in JSON") } // ConvertJSONToFloat casts JSON into float64. -func ConvertJSONToFloat(sc *stmtctx.StatementContext, j BinaryJSON) (float64, error) { +func ConvertJSONToFloat(ctx Context, j BinaryJSON) (float64, error) { switch j.TypeCode { case JSONTypeCodeObject, JSONTypeCodeArray, JSONTypeCodeOpaque, JSONTypeCodeDate, JSONTypeCodeDatetime, JSONTypeCodeTimestamp, JSONTypeCodeDuration: - return 0, sc.HandleTruncate(ErrTruncatedWrongVal.GenWithStackByArgs("FLOAT", j.String())) + return 0, ctx.HandleTruncate(ErrTruncatedWrongVal.GenWithStackByArgs("FLOAT", j.String())) case JSONTypeCodeLiteral: switch j.Value[0] { case JSONLiteralFalse: return 0, nil case JSONLiteralNil: - return 0, sc.HandleTruncate(ErrTruncatedWrongVal.GenWithStackByArgs("FLOAT", j.String())) + return 0, ctx.HandleTruncate(ErrTruncatedWrongVal.GenWithStackByArgs("FLOAT", j.String())) default: return 1, nil } @@ -649,13 +649,13 @@ func ConvertJSONToFloat(sc *stmtctx.StatementContext, j BinaryJSON) (float64, er return j.GetFloat64(), nil case JSONTypeCodeString: str := string(hack.String(j.GetString())) - return StrToFloat(sc, str, false) + return StrToFloat(ctx, str, false) } return 0, errors.New("Unknown type code in JSON") } // ConvertJSONToDecimal casts JSON into decimal. -func ConvertJSONToDecimal(sc *stmtctx.StatementContext, j BinaryJSON) (*MyDecimal, error) { +func ConvertJSONToDecimal(ctx Context, j BinaryJSON) (*MyDecimal, error) { var err error = nil res := new(MyDecimal) switch j.TypeCode { @@ -679,7 +679,7 @@ func ConvertJSONToDecimal(sc *stmtctx.StatementContext, j BinaryJSON) (*MyDecima case JSONTypeCodeString: err = res.FromString(j.GetString()) } - err = sc.HandleTruncate(err) + err = ctx.HandleTruncate(err) if err != nil { return res, errors.Trace(err) } @@ -687,7 +687,7 @@ func ConvertJSONToDecimal(sc *stmtctx.StatementContext, j BinaryJSON) (*MyDecima } // getValidFloatPrefix gets prefix of string which can be successfully parsed as float. -func getValidFloatPrefix(sc *stmtctx.StatementContext, s string, isFuncCast bool) (valid string, err error) { +func getValidFloatPrefix(ctx Context, s string, isFuncCast bool) (valid string, err error) { if isFuncCast && s == "" { return "0", nil } @@ -735,7 +735,7 @@ func getValidFloatPrefix(sc *stmtctx.StatementContext, s string, isFuncCast bool valid = "0" } if validLen == 0 || validLen != len(s) { - err = errors.Trace(sc.HandleTruncate(ErrTruncatedWrongVal.GenWithStackByArgs("DOUBLE", s))) + err = errors.Trace(ctx.HandleTruncate(ErrTruncatedWrongVal.GenWithStackByArgs("DOUBLE", s))) } return valid, err } diff --git a/pkg/types/convert_test.go b/pkg/types/convert_test.go index b255e908a0f62..23d5883964e50 100644 --- a/pkg/types/convert_test.go +++ b/pkg/types/convert_test.go @@ -248,11 +248,11 @@ func TestConvertType(t *testing.T) { // Test Datum.ToDecimal with bad number. d := NewDatum("hello") - _, err = d.ToDecimal(sc) + _, err = d.ToDecimal(sc.TypeCtxOrDefault()) require.Truef(t, terror.ErrorEqual(err, ErrTruncatedWrongVal), "err %v", err) - sc.IgnoreTruncate.Store(true) - v, err = d.ToDecimal(sc) + sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) + v, err = d.ToDecimal(sc.TypeCtxOrDefault()) require.NoError(t, err) require.Equal(t, "0", v.(*MyDecimal).String()) @@ -421,7 +421,7 @@ func TestConvertToStringWithCheck(t *testing.T) { ft.SetCharset(tt.outputChs) inputDatum := NewStringDatum(tt.input) sc := stmtctx.NewStmtCtx() - flags := tt.newFlags(sc.TypeCtx.Flags()) + flags := tt.newFlags(sc.TypeFlags()) sc.SetTypeFlags(flags) outputDatum, err := inputDatum.ConvertTo(sc, ft) if len(tt.output) == 0 { @@ -472,9 +472,8 @@ func TestConvertToBinaryString(t *testing.T) { } func testStrToInt(t *testing.T, str string, expect int64, truncateAsErr bool, expectErr error) { - sc := stmtctx.NewStmtCtx() - sc.IgnoreTruncate.Store(!truncateAsErr) - val, err := StrToInt(sc, str, false) + ctx := DefaultNoWarningContext.WithFlags(StrictFlags.WithIgnoreTruncateErr(!truncateAsErr)) + val, err := StrToInt(ctx, str, false) if expectErr != nil { require.Truef(t, terror.ErrorEqual(err, expectErr), "err %v", err) } else { @@ -484,9 +483,8 @@ func testStrToInt(t *testing.T, str string, expect int64, truncateAsErr bool, ex } func testStrToUint(t *testing.T, str string, expect uint64, truncateAsErr bool, expectErr error) { - sc := stmtctx.NewStmtCtx() - sc.IgnoreTruncate.Store(!truncateAsErr) - val, err := StrToUint(sc, str, false) + ctx := DefaultNoWarningContext.WithFlags(StrictFlags.WithIgnoreTruncateErr(!truncateAsErr)) + val, err := StrToUint(ctx, str, false) if expectErr != nil { require.Truef(t, terror.ErrorEqual(err, expectErr), "err %v", err) } else { @@ -496,9 +494,8 @@ func testStrToUint(t *testing.T, str string, expect uint64, truncateAsErr bool, } func testStrToFloat(t *testing.T, str string, expect float64, truncateAsErr bool, expectErr error) { - sc := stmtctx.NewStmtCtx() - sc.IgnoreTruncate.Store(!truncateAsErr) - val, err := StrToFloat(sc, str, false) + ctx := DefaultNoWarningContext.WithFlags(StrictFlags.WithIgnoreTruncateErr(!truncateAsErr)) + val, err := StrToFloat(ctx, str, false) if expectErr != nil { require.Truef(t, terror.ErrorEqual(err, expectErr), "err %v", err) } else { @@ -566,7 +563,7 @@ func testSelectUpdateDeleteEmptyStringError(t *testing.T) { {false, true}, } sc := stmtctx.NewStmtCtx() - sc.TruncateAsWarning = true + sc.SetTypeFlags(sc.TypeFlags().WithTruncateAsWarning(true)) for _, tc := range testCases { sc.InSelectStmt = tc.inSelect sc.InDeleteStmt = tc.inDelete @@ -574,15 +571,15 @@ func testSelectUpdateDeleteEmptyStringError(t *testing.T) { str := "" expect := 0 - val, err := StrToInt(sc, str, false) + val, err := StrToInt(sc.TypeCtxOrDefault(), str, false) require.NoError(t, err) require.Equal(t, int64(expect), val) - val1, err := StrToUint(sc, str, false) + val1, err := StrToUint(sc.TypeCtxOrDefault(), str, false) require.NoError(t, err) require.Equal(t, uint64(expect), val1) - val2, err := StrToFloat(sc, str, false) + val2, err := StrToFloat(sc.TypeCtxOrDefault(), str, false) require.NoError(t, err) require.Equal(t, float64(expect), val2) } @@ -605,7 +602,7 @@ func accept(t *testing.T, tp byte, value interface{}, unsigned bool, expected st d := NewDatum(value) sc := stmtctx.NewStmtCtx() sc.SetTimeZone(time.UTC) - sc.IgnoreTruncate.Store(true) + sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) casted, err := d.ConvertTo(sc, ft) require.NoErrorf(t, err, "%v", ft) if casted.IsNull() { @@ -887,11 +884,11 @@ func TestGetValidInt(t *testing.T) { {"123de", "123", true, true}, } sc := stmtctx.NewStmtCtx() - sc.TruncateAsWarning = true + sc.SetTypeFlags(sc.TypeFlags().WithTruncateAsWarning(true)) sc.InSelectStmt = true warningCount := 0 for i, tt := range tests { - prefix, err := getValidIntPrefix(sc, tt.origin, false) + prefix, err := getValidIntPrefix(sc.TypeCtxOrDefault(), tt.origin, false) require.NoError(t, err) require.Equal(t, tt.valid, prefix) if tt.signed { @@ -930,10 +927,10 @@ func TestGetValidInt(t *testing.T) { {"123e+", "123", true}, {"123de", "123", true}, } - sc.TruncateAsWarning = false + sc.SetTypeFlags(StrictFlags) sc.InSelectStmt = false for _, tt := range tests2 { - prefix, err := getValidIntPrefix(sc, tt.origin, false) + prefix, err := getValidIntPrefix(sc.TypeCtxOrDefault(), tt.origin, false) if tt.warning { require.True(t, terror.ErrorEqual(err, ErrTruncatedWrongVal)) } else { @@ -966,9 +963,9 @@ func TestGetValidFloat(t *testing.T) { {"9-3", "9"}, {"1001001\\u0000\\u0000\\u0000", "1001001"}, } - sc := stmtctx.NewStmtCtx() + ctx := DefaultNoWarningContext for _, tt := range tests { - prefix, _ := getValidFloatPrefix(sc, tt.origin, false) + prefix, _ := getValidFloatPrefix(ctx, tt.origin, false) require.Equal(t, tt.valid, prefix) _, err := strconv.ParseFloat(prefix, 64) require.NoError(t, err) @@ -994,7 +991,7 @@ func TestGetValidFloat(t *testing.T) { {"+999.9999e2", "+100000"}, } for _, tt := range tests2 { - str, err := floatStrToIntStr(sc, tt.origin, tt.origin) + str, err := floatStrToIntStr(ctx, tt.origin, tt.origin) require.NoError(t, err) require.Equalf(t, tt.expected, str, "%v, %v", tt.origin, tt.expected) } @@ -1108,10 +1105,11 @@ func TestConvertJSONToFloat(t *testing.T) { {in: "123.456hello", out: 123.456, ty: JSONTypeCodeString, err: true}, {in: "1234", out: 1234, ty: JSONTypeCodeString}, } + ctx := DefaultNoWarningContext for _, tt := range tests { j := CreateBinaryJSON(tt.in) require.Equal(t, tt.ty, j.TypeCode) - casted, err := ConvertJSONToFloat(stmtctx.NewStmtCtx(), j) + casted, err := ConvertJSONToFloat(ctx, j) if tt.err { require.Error(t, err, tt) } else { @@ -1136,10 +1134,11 @@ func TestConvertJSONToDecimal(t *testing.T) { {in: `false`, out: NewDecFromStringForTest("0")}, {in: `null`, out: NewDecFromStringForTest("0"), err: true}, } + ctx := DefaultNoWarningContext for _, tt := range tests { j, err := ParseBinaryJSONFromString(tt.in) require.NoError(t, err) - casted, err := ConvertJSONToDecimal(stmtctx.NewStmtCtx(), j) + casted, err := ConvertJSONToDecimal(ctx, j) errMsg := fmt.Sprintf("input: %v, casted: %v, out: %v, json: %#v", tt.in, casted, tt.out, j) if tt.err { require.Error(t, err, errMsg) diff --git a/pkg/types/datum.go b/pkg/types/datum.go index 034b07add7e53..13a3bfa59a584 100644 --- a/pkg/types/datum.go +++ b/pkg/types/datum.go @@ -631,6 +631,10 @@ func (d *Datum) SetValue(val interface{}, tp *types.FieldType) { // Compare compares datum to another datum. // Notes: don't rely on datum.collation to get the collator, it's tend to buggy. func (d *Datum) Compare(sc *stmtctx.StatementContext, ad *Datum, comparer collate.Collator) (int, error) { + typeCtx := DefaultNoWarningContext + if sc != nil { + typeCtx = sc.TypeCtx + } if d.k == KindMysqlJSON && ad.k != KindMysqlJSON { cmp, err := ad.Compare(sc, d, comparer) return cmp * -1, errors.Trace(err) @@ -654,11 +658,11 @@ func (d *Datum) Compare(sc *stmtctx.StatementContext, ad *Datum, comparer collat } return -1, nil case KindInt64: - return d.compareInt64(sc, ad.GetInt64()) + return d.compareInt64(typeCtx, ad.GetInt64()) case KindUint64: - return d.compareUint64(sc, ad.GetUint64()) + return d.compareUint64(typeCtx, ad.GetUint64()) case KindFloat32, KindFloat64: - return d.compareFloat64(sc, ad.GetFloat64()) + return d.compareFloat64(typeCtx, ad.GetFloat64()) case KindString: return d.compareString(sc, ad.GetString(), comparer) case KindBytes: @@ -668,11 +672,11 @@ func (d *Datum) Compare(sc *stmtctx.StatementContext, ad *Datum, comparer collat case KindMysqlDuration: return d.compareMysqlDuration(sc, ad.GetMysqlDuration()) case KindMysqlEnum: - return d.compareMysqlEnum(sc, ad.GetMysqlEnum(), comparer) + return d.compareMysqlEnum(typeCtx, ad.GetMysqlEnum(), comparer) case KindBinaryLiteral, KindMysqlBit: - return d.compareBinaryLiteral(sc, ad.GetBinaryLiteral4Cmp(), comparer) + return d.compareBinaryLiteral(typeCtx, ad.GetBinaryLiteral4Cmp(), comparer) case KindMysqlSet: - return d.compareMysqlSet(sc, ad.GetMysqlSet(), comparer) + return d.compareMysqlSet(typeCtx, ad.GetMysqlSet(), comparer) case KindMysqlJSON: return d.compareMysqlJSON(sc, ad.GetMysqlJSON()) case KindMysqlTime: @@ -682,7 +686,7 @@ func (d *Datum) Compare(sc *stmtctx.StatementContext, ad *Datum, comparer collat } } -func (d *Datum) compareInt64(sc *stmtctx.StatementContext, i int64) (int, error) { +func (d *Datum) compareInt64(ctx Context, i int64) (int, error) { switch d.k { case KindMaxValue: return 1, nil @@ -694,11 +698,11 @@ func (d *Datum) compareInt64(sc *stmtctx.StatementContext, i int64) (int, error) } return cmp.Compare(d.i, i), nil default: - return d.compareFloat64(sc, float64(i)) + return d.compareFloat64(ctx, float64(i)) } } -func (d *Datum) compareUint64(sc *stmtctx.StatementContext, u uint64) (int, error) { +func (d *Datum) compareUint64(ctx Context, u uint64) (int, error) { switch d.k { case KindMaxValue: return 1, nil @@ -710,11 +714,11 @@ func (d *Datum) compareUint64(sc *stmtctx.StatementContext, u uint64) (int, erro case KindUint64: return cmp.Compare(d.GetUint64(), u), nil default: - return d.compareFloat64(sc, float64(u)) + return d.compareFloat64(ctx, float64(u)) } } -func (d *Datum) compareFloat64(sc *stmtctx.StatementContext, f float64) (int, error) { +func (d *Datum) compareFloat64(ctx Context, f float64) (int, error) { switch d.k { case KindNull, KindMinNotNull: return -1, nil @@ -727,7 +731,7 @@ func (d *Datum) compareFloat64(sc *stmtctx.StatementContext, f float64) (int, er case KindFloat32, KindFloat64: return cmp.Compare(d.GetFloat64(), f), nil case KindString, KindBytes: - fVal, err := StrToFloat(sc, d.GetString(), false) + fVal, err := StrToFloat(ctx, d.GetString(), false) return cmp.Compare(fVal, f), errors.Trace(err) case KindMysqlDecimal: fVal, err := d.GetMysqlDecimal().ToFloat64() @@ -739,7 +743,7 @@ func (d *Datum) compareFloat64(sc *stmtctx.StatementContext, f float64) (int, er fVal := d.GetMysqlEnum().ToNumber() return cmp.Compare(fVal, f), nil case KindBinaryLiteral, KindMysqlBit: - val, err := d.GetBinaryLiteral4Cmp().ToInt(sc) + val, err := d.GetBinaryLiteral4Cmp().ToInt(ctx) fVal := float64(val) return cmp.Compare(fVal, f), errors.Trace(err) case KindMysqlSet: @@ -763,7 +767,7 @@ func (d *Datum) compareString(sc *stmtctx.StatementContext, s string, comparer c return comparer.Compare(d.GetString(), s), nil case KindMysqlDecimal: dec := new(MyDecimal) - err := sc.HandleTruncate(dec.FromString(hack.Slice(s))) + err := sc.TypeCtx.HandleTruncate(dec.FromString(hack.Slice(s))) return d.GetMysqlDecimal().Compare(dec), errors.Trace(err) case KindMysqlTime: dt, err := ParseDatetime(sc, s) @@ -778,11 +782,11 @@ func (d *Datum) compareString(sc *stmtctx.StatementContext, s string, comparer c case KindBinaryLiteral, KindMysqlBit: return comparer.Compare(d.GetBinaryLiteral4Cmp().ToString(), s), nil default: - fVal, err := StrToFloat(sc, s, false) + fVal, err := StrToFloat(sc.TypeCtxOrDefault(), s, false) if err != nil { return 0, errors.Trace(err) } - return d.compareFloat64(sc, fVal) + return d.compareFloat64(sc.TypeCtxOrDefault(), fVal) } } @@ -796,7 +800,7 @@ func (d *Datum) compareMysqlDecimal(sc *stmtctx.StatementContext, dec *MyDecimal return d.GetMysqlDecimal().Compare(dec), nil case KindString, KindBytes: dDec := new(MyDecimal) - err := sc.HandleTruncate(dDec.FromString(d.GetBytes())) + err := sc.TypeCtx.HandleTruncate(dDec.FromString(d.GetBytes())) return dDec.Compare(dec), errors.Trace(err) default: dVal, err := d.ConvertTo(sc, NewFieldType(mysql.TypeNewDecimal)) @@ -819,11 +823,11 @@ func (d *Datum) compareMysqlDuration(sc *stmtctx.StatementContext, dur Duration) dDur, _, err := ParseDuration(sc, d.GetString(), MaxFsp) return dDur.Compare(dur), errors.Trace(err) default: - return d.compareFloat64(sc, dur.Seconds()) + return d.compareFloat64(sc.TypeCtxOrDefault(), dur.Seconds()) } } -func (d *Datum) compareMysqlEnum(sc *stmtctx.StatementContext, enum Enum, comparer collate.Collator) (int, error) { +func (d *Datum) compareMysqlEnum(sc Context, enum Enum, comparer collate.Collator) (int, error) { switch d.k { case KindNull, KindMinNotNull: return -1, nil @@ -836,7 +840,7 @@ func (d *Datum) compareMysqlEnum(sc *stmtctx.StatementContext, enum Enum, compar } } -func (d *Datum) compareBinaryLiteral(sc *stmtctx.StatementContext, b BinaryLiteral, comparer collate.Collator) (int, error) { +func (d *Datum) compareBinaryLiteral(ctx Context, b BinaryLiteral, comparer collate.Collator) (int, error) { switch d.k { case KindNull, KindMinNotNull: return -1, nil @@ -847,16 +851,16 @@ func (d *Datum) compareBinaryLiteral(sc *stmtctx.StatementContext, b BinaryLiter case KindBinaryLiteral, KindMysqlBit: return comparer.Compare(d.GetBinaryLiteral4Cmp().ToString(), b.ToString()), nil default: - val, err := b.ToInt(sc) + val, err := b.ToInt(ctx) if err != nil { return 0, errors.Trace(err) } - result, err := d.compareFloat64(sc, float64(val)) + result, err := d.compareFloat64(ctx, float64(val)) return result, errors.Trace(err) } } -func (d *Datum) compareMysqlSet(sc *stmtctx.StatementContext, set Set, comparer collate.Collator) (int, error) { +func (d *Datum) compareMysqlSet(ctx Context, set Set, comparer collate.Collator) (int, error) { switch d.k { case KindNull, KindMinNotNull: return -1, nil @@ -865,7 +869,7 @@ func (d *Datum) compareMysqlSet(sc *stmtctx.StatementContext, set Set, comparer case KindString, KindBytes, KindMysqlEnum, KindMysqlSet: return comparer.Compare(d.GetString(), set.String()), nil default: - return d.compareFloat64(sc, set.ToNumber()) + return d.compareFloat64(ctx, set.ToNumber()) } } @@ -898,7 +902,7 @@ func (d *Datum) compareMysqlTime(sc *stmtctx.StatementContext, time Time) (int, if err != nil { return 0, errors.Trace(err) } - return d.compareFloat64(sc, fVal) + return d.compareFloat64(sc.TypeCtxOrDefault(), fVal) } } @@ -961,7 +965,7 @@ func (d *Datum) convertToFloat(sc *stmtctx.StatementContext, target *FieldType) case KindFloat32, KindFloat64: f = d.GetFloat64() case KindString, KindBytes: - f, err = StrToFloat(sc, d.GetString(), false) + f, err = StrToFloat(sc.TypeCtxOrDefault(), d.GetString(), false) case KindMysqlTime: f, err = d.GetMysqlTime().ToNumber().ToFloat64() case KindMysqlDuration: @@ -973,10 +977,10 @@ func (d *Datum) convertToFloat(sc *stmtctx.StatementContext, target *FieldType) case KindMysqlEnum: f = d.GetMysqlEnum().ToNumber() case KindBinaryLiteral, KindMysqlBit: - val, err1 := d.GetBinaryLiteral().ToInt(sc) + val, err1 := d.GetBinaryLiteral().ToInt(sc.TypeCtxOrDefault()) f, err = float64(val), err1 case KindMysqlJSON: - f, err = ConvertJSONToFloat(sc, d.GetMysqlJSON()) + f, err = ConvertJSONToFloat(sc.TypeCtxOrDefault(), d.GetMysqlJSON()) default: return invalidConv(d, target.GetType()) } @@ -1063,7 +1067,7 @@ func (d *Datum) convertToString(sc *stmtctx.StatementContext, target *FieldType) case KindMysqlBit: // https://github.com/pingcap/tidb/issues/31124. // Consider converting to uint first. - val, err := d.GetBinaryLiteral().ToInt(sc) + val, err := d.GetBinaryLiteral().ToInt(ctx) if err != nil { s = d.GetBinaryLiteral().ToString() } else { @@ -1075,7 +1079,7 @@ func (d *Datum) convertToString(sc *stmtctx.StatementContext, target *FieldType) return invalidConv(d, target.GetType()) } if err == nil { - s, err = ProduceStrWithSpecifiedTp(s, target, sc, true) + s, err = ProduceStrWithSpecifiedTp(s, target, ctx, true) } ret.SetString(s, target.GetCollate()) if target.GetCharset() == charset.CharsetBin { @@ -1086,7 +1090,7 @@ func (d *Datum) convertToString(sc *stmtctx.StatementContext, target *FieldType) // ProduceStrWithSpecifiedTp produces a new string according to `flen` and `chs`. Param `padZero` indicates // whether we should pad `\0` for `binary(flen)` type. -func ProduceStrWithSpecifiedTp(s string, tp *FieldType, sc *stmtctx.StatementContext, padZero bool) (_ string, err error) { +func ProduceStrWithSpecifiedTp(s string, tp *FieldType, ctx Context, padZero bool) (_ string, err error) { flen, chs := tp.GetFlen(), tp.GetCharset() if flen >= 0 { // overflowed stores the part of the string that is out of the length constraint, it is later checked to see if the @@ -1156,7 +1160,7 @@ func ProduceStrWithSpecifiedTp(s string, tp *FieldType, sc *stmtctx.StatementCon trimed := strings.TrimRight(overflowed, " \t\n\r") if len(trimed) == 0 && !IsBinaryStr(tp) && IsTypeChar(tp.GetType()) { if tp.GetType() == mysql.TypeVarchar { - sc.AppendWarning(ErrTruncated.GenWithStack("Data truncated, field len %d, data len %d", flen, characterLen)) + ctx.AppendWarning(ErrTruncated.GenWithStack("Data truncated, field len %d, data len %d", flen, characterLen)) } } else { err = ErrDataTooLong.GenWithStack("Data Too Long, field len %d, data len %d", flen, characterLen) @@ -1168,7 +1172,7 @@ func ProduceStrWithSpecifiedTp(s string, tp *FieldType, sc *stmtctx.StatementCon s = string(append([]byte(s), padding...)) } } - return s, errors.Trace(sc.HandleTruncate(err)) + return s, errors.Trace(ctx.HandleTruncate(err)) } func (d *Datum) convertToInt(sc *stmtctx.StatementContext, target *FieldType) (Datum, error) { @@ -1192,7 +1196,7 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) ( case KindFloat32, KindFloat64: val, err = ConvertFloatToUint(sc, d.GetFloat64(), upperBound, tp) case KindString, KindBytes: - uval, err1 := StrToUint(sc, d.GetString(), false) + uval, err1 := StrToUint(sc.TypeCtxOrDefault(), d.GetString(), false) if err1 != nil && ErrOverflow.Equal(err1) && !sc.ShouldIgnoreOverflowError() { return ret, errors.Trace(err1) } @@ -1226,7 +1230,7 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) ( case KindMysqlSet: val, err = ConvertFloatToUint(sc, d.GetMysqlSet().ToNumber(), upperBound, tp) case KindBinaryLiteral, KindMysqlBit: - val, err = d.GetBinaryLiteral().ToInt(sc) + val, err = d.GetBinaryLiteral().ToInt(sc.TypeCtxOrDefault()) if err == nil { val, err = ConvertUintToUint(val, upperBound, tp) } @@ -1457,11 +1461,11 @@ func (d *Datum) convertToMysqlDecimal(sc *stmtctx.StatementContext, target *Fiel case KindMysqlSet: err = dec.FromFloat64(d.GetMysqlSet().ToNumber()) case KindBinaryLiteral, KindMysqlBit: - val, err1 := d.GetBinaryLiteral().ToInt(sc) + val, err1 := d.GetBinaryLiteral().ToInt(sc.TypeCtxOrDefault()) err = err1 dec.FromUint(val) case KindMysqlJSON: - f, err1 := ConvertJSONToDecimal(sc, d.GetMysqlJSON()) + f, err1 := ConvertJSONToDecimal(sc.TypeCtxOrDefault(), d.GetMysqlJSON()) if err1 != nil { return ret, errors.Trace(err1) } @@ -1543,7 +1547,7 @@ func (d *Datum) ConvertToMysqlYear(sc *stmtctx.StatementContext, target *FieldTy case KindString, KindBytes: s := d.GetString() trimS := strings.TrimSpace(s) - y, err = StrToInt(sc, trimS, false) + y, err = StrToInt(sc.TypeCtxOrDefault(), trimS, false) if err != nil { ret.SetInt64(0) return ret, errors.Trace(err) @@ -1575,13 +1579,13 @@ func (d *Datum) ConvertToMysqlYear(sc *stmtctx.StatementContext, target *FieldTy return ret, errors.Trace(err) } -func (d *Datum) convertStringToMysqlBit(sc *stmtctx.StatementContext) (uint64, error) { +func (d *Datum) convertStringToMysqlBit(ctx Context) (uint64, error) { bitStr, err := ParseBitStr(BinaryLiteral(d.b).ToString()) if err != nil { // It cannot be converted to bit type, so we need to convert it to int type. - return BinaryLiteral(d.b).ToInt(sc) + return BinaryLiteral(d.b).ToInt(ctx) } - return bitStr.ToInt(sc) + return bitStr.ToInt(ctx) } func (d *Datum) convertToMysqlBit(sc *stmtctx.StatementContext, target *FieldType) (Datum, error) { @@ -1590,7 +1594,7 @@ func (d *Datum) convertToMysqlBit(sc *stmtctx.StatementContext, target *FieldTyp var err error switch d.k { case KindBytes: - uintValue, err = BinaryLiteral(d.b).ToInt(sc) + uintValue, err = BinaryLiteral(d.b).ToInt(sc.TypeCtxOrDefault()) case KindString: // For single bit value, we take string like "true", "1" as 1, and "false", "0" as 0, // this behavior is not documented in MySQL, but it behaves so, for more information, see issue #18681 @@ -1602,10 +1606,10 @@ func (d *Datum) convertToMysqlBit(sc *stmtctx.StatementContext, target *FieldTyp case "false", "0": uintValue = 0 default: - uintValue, err = d.convertStringToMysqlBit(sc) + uintValue, err = d.convertStringToMysqlBit(sc.TypeCtxOrDefault()) } } else { - uintValue, err = d.convertStringToMysqlBit(sc) + uintValue, err = d.convertStringToMysqlBit(sc.TypeCtxOrDefault()) } case KindInt64: // if input kind is int64 (signed), when trans to bit, we need to treat it as unsigned @@ -1740,7 +1744,7 @@ func (d *Datum) convertToMysqlJSON(_ *stmtctx.StatementContext, _ *FieldType) (r // ToBool converts to a bool. // We will use 1 for true, and 0 for false. -func (d *Datum) ToBool(sc *stmtctx.StatementContext) (int64, error) { +func (d *Datum) ToBool(ctx Context) (int64, error) { var err error isZero := false switch d.Kind() { @@ -1753,7 +1757,7 @@ func (d *Datum) ToBool(sc *stmtctx.StatementContext) (int64, error) { case KindFloat64: isZero = d.GetFloat64() == 0 case KindString, KindBytes: - iVal, err1 := StrToFloat(sc, d.GetString(), false) + iVal, err1 := StrToFloat(ctx, d.GetString(), false) isZero, err = iVal == 0, err1 case KindMysqlTime: @@ -1767,7 +1771,7 @@ func (d *Datum) ToBool(sc *stmtctx.StatementContext) (int64, error) { case KindMysqlSet: isZero = d.GetMysqlSet().ToNumber() == 0 case KindBinaryLiteral, KindMysqlBit: - val, err1 := d.GetBinaryLiteral().ToInt(sc) + val, err1 := d.GetBinaryLiteral().ToInt(ctx) isZero, err = val == 0, err1 case KindMysqlJSON: val := d.GetMysqlJSON() @@ -1788,7 +1792,7 @@ func (d *Datum) ToBool(sc *stmtctx.StatementContext) (int64, error) { } // ConvertDatumToDecimal converts datum to decimal. -func ConvertDatumToDecimal(sc *stmtctx.StatementContext, d Datum) (*MyDecimal, error) { +func ConvertDatumToDecimal(ctx Context, d Datum) (*MyDecimal, error) { dec := new(MyDecimal) var err error switch d.Kind() { @@ -1801,7 +1805,7 @@ func ConvertDatumToDecimal(sc *stmtctx.StatementContext, d Datum) (*MyDecimal, e case KindFloat64: err = dec.FromFloat64(d.GetFloat64()) case KindString: - err = sc.HandleTruncate(dec.FromString(d.GetBytes())) + err = ctx.HandleTruncate(dec.FromString(d.GetBytes())) case KindMysqlDecimal: *dec = *d.GetMysqlDecimal() case KindMysqlEnum: @@ -1809,11 +1813,11 @@ func ConvertDatumToDecimal(sc *stmtctx.StatementContext, d Datum) (*MyDecimal, e case KindMysqlSet: dec.FromUint(d.GetMysqlSet().Value) case KindBinaryLiteral, KindMysqlBit: - val, err1 := d.GetBinaryLiteral().ToInt(sc) + val, err1 := d.GetBinaryLiteral().ToInt(ctx) dec.FromUint(val) err = err1 case KindMysqlJSON: - f, err1 := ConvertJSONToDecimal(sc, d.GetMysqlJSON()) + f, err1 := ConvertJSONToDecimal(ctx, d.GetMysqlJSON()) if err1 != nil { return nil, errors.Trace(err1) } @@ -1825,21 +1829,21 @@ func ConvertDatumToDecimal(sc *stmtctx.StatementContext, d Datum) (*MyDecimal, e } // ToDecimal converts to a decimal. -func (d *Datum) ToDecimal(sc *stmtctx.StatementContext) (*MyDecimal, error) { +func (d *Datum) ToDecimal(ctx Context) (*MyDecimal, error) { switch d.Kind() { case KindMysqlTime: return d.GetMysqlTime().ToNumber(), nil case KindMysqlDuration: return d.GetMysqlDuration().ToNumber(), nil default: - return ConvertDatumToDecimal(sc, *d) + return ConvertDatumToDecimal(ctx, *d) } } // ToInt64 converts to a int64. func (d *Datum) ToInt64(sc *stmtctx.StatementContext) (int64, error) { if d.Kind() == KindMysqlBit { - uintVal, err := d.GetBinaryLiteral().ToInt(sc) + uintVal, err := d.GetBinaryLiteral().ToInt(sc.TypeCtxOrDefault()) return int64(uintVal), err } return d.toSignedInteger(sc, mysql.TypeLonglong) @@ -1858,7 +1862,7 @@ func (d *Datum) toSignedInteger(sc *stmtctx.StatementContext, tp byte) (int64, e case KindFloat64: return ConvertFloatToInt(d.GetFloat64(), lowerBound, upperBound, tp) case KindString, KindBytes: - iVal, err := StrToInt(sc, d.GetString(), false) + iVal, err := StrToInt(sc.TypeCtxOrDefault(), d.GetString(), false) iVal, err2 := ConvertIntToInt(iVal, lowerBound, upperBound, tp) if err == nil { err = err2 @@ -1911,7 +1915,7 @@ func (d *Datum) toSignedInteger(sc *stmtctx.StatementContext, tp byte) (int64, e case KindMysqlJSON: return ConvertJSONToInt(sc, d.GetMysqlJSON(), false, tp) case KindBinaryLiteral, KindMysqlBit: - val, err := d.GetBinaryLiteral().ToInt(sc) + val, err := d.GetBinaryLiteral().ToInt(sc.TypeCtxOrDefault()) if err != nil { return 0, errors.Trace(err) } @@ -1923,7 +1927,7 @@ func (d *Datum) toSignedInteger(sc *stmtctx.StatementContext, tp byte) (int64, e } // ToFloat64 converts to a float64 -func (d *Datum) ToFloat64(sc *stmtctx.StatementContext) (float64, error) { +func (d *Datum) ToFloat64(ctx Context) (float64, error) { switch d.Kind() { case KindInt64: return float64(d.GetInt64()), nil @@ -1934,9 +1938,9 @@ func (d *Datum) ToFloat64(sc *stmtctx.StatementContext) (float64, error) { case KindFloat64: return d.GetFloat64(), nil case KindString: - return StrToFloat(sc, d.GetString(), false) + return StrToFloat(ctx, d.GetString(), false) case KindBytes: - return StrToFloat(sc, string(d.GetBytes()), false) + return StrToFloat(ctx, string(d.GetBytes()), false) case KindMysqlTime: f, err := d.GetMysqlTime().ToNumber().ToFloat64() return f, errors.Trace(err) @@ -1951,10 +1955,10 @@ func (d *Datum) ToFloat64(sc *stmtctx.StatementContext) (float64, error) { case KindMysqlSet: return d.GetMysqlSet().ToNumber(), nil case KindBinaryLiteral, KindMysqlBit: - val, err := d.GetBinaryLiteral().ToInt(sc) + val, err := d.GetBinaryLiteral().ToInt(ctx) return float64(val), errors.Trace(err) case KindMysqlJSON: - f, err := ConvertJSONToFloat(sc, d.GetMysqlJSON()) + f, err := ConvertJSONToFloat(ctx, d.GetMysqlJSON()) return f, errors.Trace(err) default: return 0, errors.Errorf("cannot convert %v(type %T) to float64", d.GetValue(), d.GetValue()) diff --git a/pkg/types/datum_test.go b/pkg/types/datum_test.go index 4b6a4e576207e..a5fc7a57d7758 100644 --- a/pkg/types/datum_test.go +++ b/pkg/types/datum_test.go @@ -56,9 +56,8 @@ func TestDatum(t *testing.T) { func testDatumToBool(t *testing.T, in interface{}, res int) { datum := NewDatum(in) res64 := int64(res) - sc := stmtctx.NewStmtCtx() - sc.IgnoreTruncate.Store(true) - b, err := datum.ToBool(sc) + ctx := DefaultNoWarningContext.WithFlags(StrictFlags.WithIgnoreTruncateErr(true)) + b, err := datum.ToBool(ctx) require.NoError(t, err) require.Equal(t, res64, b) } @@ -108,16 +107,17 @@ func TestToBool(t *testing.T) { require.NoError(t, err) testDatumToBool(t, v, 1) d := NewDatum(&invalidMockType{}) - sc := stmtctx.NewStmtCtx() - sc.IgnoreTruncate.Store(true) - _, err = d.ToBool(sc) + ctx := DefaultNoWarningContext.WithFlags(StrictFlags.WithIgnoreTruncateErr(true)) + _, err = d.ToBool(ctx) require.Error(t, err) } func testDatumToInt64(t *testing.T, val interface{}, expect int64) { d := NewDatum(val) + sc := stmtctx.NewStmtCtx() - sc.IgnoreTruncate.Store(true) + sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) + b, err := d.ToInt64(sc) require.NoError(t, err) require.Equal(t, expect, b) @@ -153,7 +153,7 @@ func TestToInt64(t *testing.T) { func testDatumToUInt32(t *testing.T, val interface{}, expect uint32, hasError bool) { d := NewDatum(val) sc := stmtctx.NewStmtCtx() - sc.IgnoreTruncate.Store(true) + sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) ft := NewFieldType(mysql.TypeLong) ft.AddFlag(mysql.UnsignedFlag) @@ -205,7 +205,7 @@ func TestConvertToFloat(t *testing.T) { } sc := stmtctx.NewStmtCtx() - sc.IgnoreTruncate.Store(true) + sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) for _, testCase := range testCases { converted, err := testCase.d.ConvertTo(sc, NewFieldType(testCase.tp)) if testCase.errMsg == "" { @@ -310,7 +310,7 @@ func TestToBytes(t *testing.T) { {Datum{}, []byte{}}, } sc := stmtctx.NewStmtCtx() - sc.IgnoreTruncate.Store(true) + sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) for _, tt := range tests { bin, err := tt.a.ToBytes() require.NoError(t, err) @@ -359,7 +359,7 @@ func TestCloneDatum(t *testing.T) { } sc := stmtctx.NewStmtCtx() - sc.IgnoreTruncate.Store(true) + sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) for _, tt := range tests { tt1 := *tt.Clone() res, err := tt.Compare(sc, &tt1, collate.GetBinaryCollator()) @@ -413,7 +413,7 @@ func TestEstimatedMemUsage(t *testing.T) { func TestChangeReverseResultByUpperLowerBound(t *testing.T) { sc := stmtctx.NewStmtCtx() - sc.IgnoreTruncate.Store(true) + sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) sc.OverflowAsWarning = true // TODO: add more reserve convert tests for each pair of convert type. testData := []struct { @@ -538,7 +538,7 @@ func TestStringToMysqlBit(t *testing.T) { {NewStringDatum("b'0'"), []byte{0}}, } sc := stmtctx.NewStmtCtx() - sc.IgnoreTruncate.Store(true) + sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) tp := NewFieldType(mysql.TypeBit) tp.SetFlen(1) for _, tt := range tests { diff --git a/pkg/types/time.go b/pkg/types/time.go index 27c76ef1d7492..012f3c4e27df0 100644 --- a/pkg/types/time.go +++ b/pkg/types/time.go @@ -1045,7 +1045,7 @@ func parseDatetime(sc *stmtctx.StatementContext, str string, fsp int, isFloat bo l := len(seps[0]) // Values specified as numbers if isFloat { - numOfTime, err := StrToInt(sc, seps[0], false) + numOfTime, err := StrToInt(sc.TypeCtxOrDefault(), seps[0], false) if err != nil { return ZeroDatetime, errors.Trace(ErrWrongValue.GenWithStackByArgs(DateTimeStr, str)) } @@ -2045,7 +2045,7 @@ func ParseTimeFromNum(sc *stmtctx.StatementContext, num int64, tp byte, fsp int) // MySQL compatibility: 0 should not be converted to null, see #11203 if num == 0 { zt := NewTime(ZeroCoreTime, tp, DefaultFsp) - if sc != nil && sc.InCreateOrAlterStmt && !sc.TruncateAsWarning && sc.NoZeroDate { + if sc != nil && sc.InCreateOrAlterStmt && !sc.TypeFlags().TruncateAsWarning() && sc.NoZeroDate { switch tp { case mysql.TypeTimestamp: return zt, ErrTruncatedWrongVal.GenWithStackByArgs(TimestampStr, "0") diff --git a/pkg/util/codec/codec.go b/pkg/util/codec/codec.go index 1c9519280d8d4..cec50ac92cc51 100644 --- a/pkg/util/codec/codec.go +++ b/pkg/util/codec/codec.go @@ -110,7 +110,7 @@ func encode(sc *stmtctx.StatementContext, b []byte, vals []types.Datum, comparab b = append(b, decimalFlag) b, err = EncodeDecimal(b, vals[i].GetMysqlDecimal(), vals[i].Length(), vals[i].Frac()) if terror.ErrorEqual(err, types.ErrTruncated) { - err = sc.HandleTruncate(err) + err = sc.TypeCtx.HandleTruncate(err) } else if terror.ErrorEqual(err, types.ErrOverflow) { err = sc.HandleOverflow(err, err) } @@ -121,7 +121,7 @@ func encode(sc *stmtctx.StatementContext, b []byte, vals []types.Datum, comparab case types.KindMysqlBit, types.KindBinaryLiteral: // We don't need to handle errors here since the literal is ensured to be able to store in uint64 in convertToMysqlBit. var val uint64 - val, err = vals[i].GetBinaryLiteral().ToInt(sc) + val, err = vals[i].GetBinaryLiteral().ToInt(sc.TypeCtxOrDefault()) terror.Log(errors.Trace(err)) b = encodeUnsignedInt(b, val, comparable1) case types.KindMysqlJSON: @@ -167,7 +167,7 @@ func EstimateValueSize(sc *stmtctx.StatementContext, val types.Datum) (int, erro case types.KindMysqlSet: l = valueSizeOfUnsignedInt(val.GetMysqlSet().Value) case types.KindMysqlBit, types.KindBinaryLiteral: - val, err := val.GetBinaryLiteral().ToInt(sc) + val, err := val.GetBinaryLiteral().ToInt(sc.TypeCtxOrDefault()) terror.Log(errors.Trace(err)) l = valueSizeOfUnsignedInt(val) case types.KindMysqlJSON: @@ -381,7 +381,7 @@ func encodeHashChunkRowIdx(sc *stmtctx.StatementContext, row chunk.Row, tp *type case mysql.TypeBit: // We don't need to handle errors here since the literal is ensured to be able to store in uint64 in convertToMysqlBit. flag = uvarintFlag - v, err1 := types.BinaryLiteral(row.GetBytes(idx)).ToInt(sc) + v, err1 := types.BinaryLiteral(row.GetBytes(idx)).ToInt(sc.TypeCtxOrDefault()) terror.Log(errors.Trace(err1)) b = unsafe.Slice((*byte)(unsafe.Pointer(&v)), unsafe.Sizeof(v)) case mysql.TypeJSON: @@ -626,7 +626,7 @@ func HashChunkSelected(sc *stmtctx.StatementContext, h []hash.Hash64, chk *chunk } else { // We don't need to handle errors here since the literal is ensured to be able to store in uint64 in convertToMysqlBit. buf[0] = uvarintFlag - v, err1 := types.BinaryLiteral(column.GetBytes(i)).ToInt(sc) + v, err1 := types.BinaryLiteral(column.GetBytes(i)).ToInt(sc.TypeCtxOrDefault()) terror.Log(errors.Trace(err1)) b = unsafe.Slice((*byte)(unsafe.Pointer(&v)), sizeUint64) } @@ -1260,7 +1260,7 @@ func HashGroupKey(sc *stmtctx.StatementContext, n int, col *chunk.Column, buf [] buf[i] = append(buf[i], decimalFlag) buf[i], err = EncodeDecimal(buf[i], &ds[i], ft.GetFlen(), ft.GetDecimal()) if terror.ErrorEqual(err, types.ErrTruncated) { - err = sc.HandleTruncate(err) + err = sc.TypeCtx.HandleTruncate(err) } else if terror.ErrorEqual(err, types.ErrOverflow) { err = sc.HandleOverflow(err, err) } diff --git a/pkg/util/codec/codec_test.go b/pkg/util/codec/codec_test.go index f4931c8529013..3875c32a65ab4 100644 --- a/pkg/util/codec/codec_test.go +++ b/pkg/util/codec/codec_test.go @@ -720,12 +720,12 @@ func TestDecimal(t *testing.T) { } for _, decimalNums := range tblCmp { d1 := types.NewDatum(decimalNums.Arg1) - dec1, err := d1.ToDecimal(sc) + dec1, err := d1.ToDecimal(sc.TypeCtxOrDefault()) require.NoError(t, err) d1.SetMysqlDecimal(dec1) d2 := types.NewDatum(decimalNums.Arg2) - dec2, err := d2.ToDecimal(sc) + dec2, err := d2.ToDecimal(sc.TypeCtxOrDefault()) require.NoError(t, err) d2.SetMysqlDecimal(dec2) @@ -778,7 +778,7 @@ func TestDecimal(t *testing.T) { _, err = EncodeDecimal(nil, d, 12, 10) require.Truef(t, terror.ErrorEqual(err, types.ErrOverflow), "err %v", err) - sc.IgnoreTruncate.Store(true) + sc.SetTypeFlags(types.StrictFlags.WithIgnoreTruncateErr(true)) decimalDatum := types.NewDatum(d) decimalDatum.SetLength(20) decimalDatum.SetFrac(5) diff --git a/pkg/util/ranger/points.go b/pkg/util/ranger/points.go index 31063894d533d..a637dab5578ef 100644 --- a/pkg/util/ranger/points.go +++ b/pkg/util/ranger/points.go @@ -209,7 +209,7 @@ func (r *builder) buildFromConstant(expr *expression.Constant) []*point { return nil } - val, err := dt.ToBool(r.sc) + val, err := dt.ToBool(r.sc.TypeCtx) if err != nil { r.err = err return nil diff --git a/pkg/util/rowcodec/common.go b/pkg/util/rowcodec/common.go index e8651a53b5966..eb399b07ca088 100644 --- a/pkg/util/rowcodec/common.go +++ b/pkg/util/rowcodec/common.go @@ -338,7 +338,7 @@ func appendDatumForChecksum(buf []byte, dat *data.Datum, typ byte) (out []byte, out = binary.LittleEndian.AppendUint64(buf, dat.GetMysqlSet().Value) case mysql.TypeBit: // ticdc transforms a bit value as the following way, no need to handle truncate error here. - v, _ := dat.GetBinaryLiteral().ToInt(nil) + v, _ := dat.GetBinaryLiteral().ToInt(data.DefaultNoWarningContext) out = binary.LittleEndian.AppendUint64(buf, v) case mysql.TypeJSON: out = appendLengthValue(buf, []byte(dat.GetMysqlJSON().String())) diff --git a/pkg/util/rowcodec/encoder.go b/pkg/util/rowcodec/encoder.go index 92fbd566c96cb..14bb36425d219 100644 --- a/pkg/util/rowcodec/encoder.go +++ b/pkg/util/rowcodec/encoder.go @@ -194,7 +194,7 @@ func encodeValueDatum(sc *stmtctx.StatementContext, d *types.Datum, buffer []byt case types.KindBinaryLiteral, types.KindMysqlBit: // We don't need to handle errors here since the literal is ensured to be able to store in uint64 in convertToMysqlBit. var val uint64 - val, err = d.GetBinaryLiteral().ToInt(sc) + val, err = d.GetBinaryLiteral().ToInt(sc.TypeCtxOrDefault()) if err != nil { return } @@ -205,7 +205,7 @@ func encodeValueDatum(sc *stmtctx.StatementContext, d *types.Datum, buffer []byt buffer, err = codec.EncodeDecimal(buffer, d.GetMysqlDecimal(), d.Length(), d.Frac()) if err != nil && sc != nil { if terror.ErrorEqual(err, types.ErrTruncated) { - err = sc.HandleTruncate(err) + err = sc.TypeCtx.HandleTruncate(err) } else if terror.ErrorEqual(err, types.ErrOverflow) { err = sc.HandleOverflow(err, err) }