diff --git a/pkg/meta/model/BUILD.bazel b/pkg/meta/model/BUILD.bazel index 0dd924c3903ae..e9b16e0813c85 100644 --- a/pkg/meta/model/BUILD.bazel +++ b/pkg/meta/model/BUILD.bazel @@ -29,6 +29,7 @@ go_library( "//pkg/planner/cascades/base", "//pkg/util/intest", "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_failpoint//:failpoint", "@com_github_tikv_pd_client//http", "@org_uber_go_atomic//:atomic", ], diff --git a/pkg/meta/model/table.go b/pkg/meta/model/table.go index 27d4a8784b51a..d2d7fc21d8e9b 100644 --- a/pkg/meta/model/table.go +++ b/pkg/meta/model/table.go @@ -22,6 +22,7 @@ import ( "time" "unsafe" + "github.com/pingcap/failpoint" "github.com/pingcap/tidb/pkg/parser/auth" "github.com/pingcap/tidb/pkg/parser/duration" "github.com/pingcap/tidb/pkg/parser/model" @@ -1351,6 +1352,10 @@ func (t *TTLInfo) Clone() *TTLInfo { // Didn't set TTL_JOB_INTERVAL during upgrade and bootstrap because setting default value here is much simpler // and could avoid bugs blocking users from upgrading or bootstrapping the cluster. func (t *TTLInfo) GetJobInterval() (time.Duration, error) { + failpoint.Inject("overwrite-ttl-job-interval", func(val failpoint.Value) (time.Duration, error) { + return time.Duration(val.(int)), nil + }) + if len(t.JobInterval) == 0 { // This only happens when the table is created from 6.5 in which the `tidb_job_interval` is not introduced yet. // We use `OldDefaultTTLJobInterval` as the return value to ensure a consistent behavior for the diff --git a/pkg/ttl/ttlworker/config.go b/pkg/ttl/ttlworker/config.go index d3c7cbc5c5474..3c06c3795b43c 100644 --- a/pkg/ttl/ttlworker/config.go +++ b/pkg/ttl/ttlworker/config.go @@ -44,6 +44,13 @@ func getCheckJobInterval() time.Duration { return jobManagerLoopTickerInterval } +func getHeartbeatInterval() time.Duration { + failpoint.Inject("heartbeat-interval", func(val failpoint.Value) time.Duration { + return time.Duration(val.(int)) + }) + return jobManagerLoopTickerInterval +} + func getJobManagerLoopSyncTimerInterval() time.Duration { failpoint.Inject("sync-timer", func(val failpoint.Value) time.Duration { return time.Duration(val.(int)) @@ -86,11 +93,11 @@ func getTaskManagerLoopTickerInterval() time.Duration { return taskManagerLoopTickerInterval } -func getTaskManagerHeartBeatExpireInterval() time.Duration { - failpoint.Inject("task-manager-heartbeat-expire-interval", func(val failpoint.Value) time.Duration { +func getTaskManagerHeartBeatInterval() time.Duration { + failpoint.Inject("task-manager-heartbeat-interval", func(val failpoint.Value) time.Duration { return time.Duration(val.(int)) }) - return 2 * ttlTaskHeartBeatTickerInterval + return ttlTaskHeartBeatTickerInterval } func getCheckJobTriggeredInterval() time.Duration { @@ -100,6 +107,13 @@ func getCheckJobTriggeredInterval() time.Duration { return 2 * time.Second } +func getTTLGCInterval() time.Duration { + failpoint.Inject("gc-interval", func(val failpoint.Value) time.Duration { + return time.Duration(val.(int)) + }) + return ttlGCInterval +} + func getScanSplitCnt(store kv.Storage) int { tikvStore, ok := store.(tikv.Storage) if !ok { diff --git a/pkg/ttl/ttlworker/job_manager.go b/pkg/ttl/ttlworker/job_manager.go index 88fbdd1a6e84f..55dc9e770f3ca 100644 --- a/pkg/ttl/ttlworker/job_manager.go +++ b/pkg/ttl/ttlworker/job_manager.go @@ -66,7 +66,7 @@ const taskGCTemplate = `DELETE task FROM const ttlJobHistoryGCTemplate = `DELETE FROM mysql.tidb_ttl_job_history WHERE create_time < CURDATE() - INTERVAL 90 DAY` const ttlTableStatusGCWithoutIDTemplate = `DELETE FROM mysql.tidb_ttl_table_status WHERE (current_job_status IS NULL OR current_job_owner_hb_time < %?)` -const timeFormat = time.DateTime +var timeFormat = time.DateTime func insertNewTableIntoStatusSQL(tableID int64, parentTableID int64) (string, []any) { return insertNewTableIntoStatusTemplate, []any{tableID, parentTableID} @@ -86,7 +86,7 @@ func gcTTLTableStatusGCSQL(existIDs []int64, now time.Time) (string, []any) { existIDStrs = append(existIDStrs, strconv.Itoa(int(id))) } - hbExpireTime := now.Add(-jobManagerLoopTickerInterval * 2) + hbExpireTime := now.Add(-getHeartbeatInterval() * 2) args := []any{hbExpireTime.Format(timeFormat)} if len(existIDStrs) > 0 { return ttlTableStatusGCWithoutIDTemplate + fmt.Sprintf(` AND table_id NOT IN (%s)`, strings.Join(existIDStrs, ",")), args @@ -137,6 +137,10 @@ func NewJobManager(id string, sessPool util.SessionPool, store kv.Storage, etcdC manager.init(manager.jobLoop) manager.ctx = logutil.WithKeyValue(manager.ctx, "ttl-worker", "job-manager") + if intest.InTest { + // in test environment, in the same log there will be multiple ttl managers, so we need to distinguish them + manager.ctx = logutil.WithKeyValue(manager.ctx, "ttl-worker", id) + } manager.infoSchemaCache = cache.NewInfoSchemaCache(getUpdateInfoSchemaCacheInterval()) manager.tableStatusCache = cache.NewTableStatusCache(getUpdateTTLTableStatusCacheInterval()) @@ -181,15 +185,15 @@ func (m *JobManager) jobLoop() error { infoSchemaCacheUpdateTicker := time.Tick(m.infoSchemaCache.GetInterval()) tableStatusCacheUpdateTicker := time.Tick(m.tableStatusCache.GetInterval()) resizeWorkersTicker := time.Tick(getResizeWorkersInterval()) - gcTicker := time.Tick(ttlGCInterval) + gcTicker := time.Tick(getTTLGCInterval()) scheduleJobTicker := time.Tick(getCheckJobInterval()) jobCheckTicker := time.Tick(getCheckJobInterval()) - updateJobHeartBeatTicker := time.Tick(jobManagerLoopTickerInterval) + updateJobHeartBeatTicker := time.Tick(getHeartbeatInterval()) timerTicker := time.Tick(getJobManagerLoopSyncTimerInterval()) scheduleTaskTicker := time.Tick(getTaskManagerLoopTickerInterval()) - updateTaskHeartBeatTicker := time.Tick(ttlTaskHeartBeatTickerInterval) + updateTaskHeartBeatTicker := time.Tick(getTaskManagerHeartBeatInterval()) taskCheckTicker := time.Tick(getTaskManagerLoopCheckTaskInterval()) checkScanTaskFinishedTicker := time.Tick(getTaskManagerLoopTickerInterval()) @@ -732,7 +736,7 @@ func (m *JobManager) couldLockJob(tableStatus *cache.TableStatus, table *cache.P hbTime := tableStatus.CurrentJobOwnerHBTime // jobManagerLoopTickerInterval is used to do heartbeat periodically. // Use twice the time to detect the heartbeat timeout. - hbTimeout := jobManagerLoopTickerInterval * 2 + hbTimeout := getHeartbeatInterval() * 2 if interval := getUpdateTTLTableStatusCacheInterval() * 2; interval > hbTimeout { // tableStatus is get from the cache which may contain stale data. // So if cache update interval > heartbeat interval, use the cache update interval instead. diff --git a/pkg/ttl/ttlworker/job_manager_integration_test.go b/pkg/ttl/ttlworker/job_manager_integration_test.go index f271c4518f5ff..03d27ace8f70c 100644 --- a/pkg/ttl/ttlworker/job_manager_integration_test.go +++ b/pkg/ttl/ttlworker/job_manager_integration_test.go @@ -18,6 +18,7 @@ import ( "context" "encoding/json" "fmt" + "math/rand/v2" "strconv" "strings" "sync" @@ -1619,3 +1620,179 @@ func TestJobHeartBeatFailNotBlockOthers(t *testing.T) { fmt.Sprintf("%d %s", testTable1.Meta().ID, now.Add(-2*time.Hour).In(tkTZ).Format(time.DateTime)), fmt.Sprintf("%d %s", testTable2.Meta().ID, now.In(tkTZ).Format(time.DateTime)))) } + +var _ fault = &faultWithProbability{} + +type faultWithProbability struct { + percent float64 +} + +func (f *faultWithProbability) shouldFault(sql string) bool { + return rand.Float64() < f.percent +} + +func newFaultWithProbability(percent float64) *faultWithProbability { + return &faultWithProbability{percent: percent} +} + +func accelerateHeartBeat(t *testing.T, tk *testkit.TestKit) func() { + tk.MustExec("ALTER TABLE mysql.tidb_ttl_table_status MODIFY COLUMN current_job_owner_hb_time TIMESTAMP(6)") + tk.MustExec("ALTER TABLE mysql.tidb_ttl_task MODIFY COLUMN owner_hb_time TIMESTAMP(6)") + ttlworker.SetTimeFormat(time.DateTime + ".999999") + + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/ttl/ttlworker/heartbeat-interval", fmt.Sprintf("return(%d)", time.Millisecond*100))) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/ttl/ttlworker/task-manager-heartbeat-interval", fmt.Sprintf("return(%d)", time.Millisecond*100))) + return func() { + tk.MustExec("ALTER TABLE mysql.tidb_ttl_table_status MODIFY COLUMN current_job_owner_hb_time TIMESTAMP") + tk.MustExec("ALTER TABLE mysql.tidb_ttl_task MODIFY COLUMN owner_hb_time TIMESTAMP") + ttlworker.SetTimeFormat(time.DateTime) + + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/ttl/ttlworker/heartbeat-interval")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/ttl/ttlworker/task-manager-heartbeat-interval")) + } +} + +func TestJobManagerWithFault(t *testing.T) { + // TODO: add a flag `-long` to enable this test + t.Skip("skip this test because it'll need to run for a long time") + + defer boostJobScheduleForTest(t)() + + store, dom := testkit.CreateMockStoreAndDomain(t) + waitAndStopTTLManager(t, dom) + + tk := testkit.NewTestKit(t, store) + defer accelerateHeartBeat(t, tk)() + tk.MustExec("set @@global.tidb_ttl_running_tasks=32") + + managerCount := 20 + testDuration := 10 * time.Minute + faultPercent := 0.5 + + leader := atomic.NewString("") + isLeaderFactory := func(id string) func() bool { + return func() bool { + return leader.Load() == id + } + } + + type managerWithPool struct { + m *ttlworker.JobManager + pool util.SessionPool + } + managers := make([]managerWithPool, 0, managerCount) + for i := 0; i < managerCount; i++ { + pool := wrapPoolForTest(dom.SysSessionPool()) + faultPool := newFaultSessionPool(pool) + + id := fmt.Sprintf("test-ttl-job-manager-%d", i) + m := ttlworker.NewJobManager(id, faultPool, store, nil, isLeaderFactory(id)) + managers = append(managers, managerWithPool{ + m: m, + pool: faultPool, + }) + + m.Start() + } + + stopTestCh := make(chan struct{}) + wg := &sync.WaitGroup{} + wg.Add(1) + + fault := newFaultWithFilter(func(sql string) bool { + // skip some local only sql, ref `getSession()` in `session.go` + if strings.HasPrefix(sql, "set tidb_") || strings.HasPrefix(sql, "set @@") || + strings.ToUpper(sql) == "COMMIT" || strings.ToUpper(sql) == "ROLLBACK" { + return false + } + + return true + }, newFaultWithProbability(faultPercent)) + go func() { + defer wg.Done() + + faultTicker := time.NewTicker(time.Second) + for { + select { + case <-stopTestCh: + // Recover all sessions + for _, m := range managers { + m.pool.(*faultSessionPool).setFault(nil) + } + + return + case <-faultTicker.C: + // Recover all sessions + for _, m := range managers { + m.pool.(*faultSessionPool).setFault(nil) + } + + faultCount := rand.Int() % managerCount + logutil.BgLogger().Info("inject fault", zap.Int("faultCount", faultCount)) + rand.Shuffle(managerCount, func(i, j int) { + managers[i], managers[j] = managers[j], managers[i] + }) + // the first non-faultt manager is the leader + leader.Store(managers[faultCount].m.ID()) + logutil.BgLogger().Info("set leader", zap.String("leader", leader.Load())) + for i := 0; i < faultCount; i++ { + m := managers[i] + logutil.BgLogger().Info("inject fault", zap.String("id", m.m.ID())) + m.pool.(*faultSessionPool).setFault(fault) + } + } + } + }() + + // run the workload goroutine + testStart := time.Now() + for time.Since(testStart) < testDuration { + // create a new table + tk.MustExec("use test") + tk.MustExec("DROP TABLE if exists t") + tk.MustExec("CREATE TABLE t (id INT PRIMARY KEY, created_at DATETIME) TTL = created_at + INTERVAL 1 HOUR TTL_ENABLE='OFF'") + tbl, err := dom.InfoSchema().TableByName(context.Background(), pmodel.NewCIStr("test"), pmodel.NewCIStr("t")) + require.NoError(t, err) + logutil.BgLogger().Info("create table", zap.Int64("table_id", tbl.Meta().ID)) + + // insert some data + for i := 0; i < 5; i++ { + tk.MustExec(fmt.Sprintf("INSERT INTO t VALUES (%d, '%s')", i, time.Now().Add(-time.Hour*2).Format(time.DateTime))) + } + for i := 0; i < 5; i++ { + tk.MustExec(fmt.Sprintf("INSERT INTO t VALUES (%d, '%s')", i+5, time.Now().Format(time.DateTime))) + } + + tk.MustExec("ALTER TABLE t TTL_ENABLE='ON'") + + start := time.Now() + require.Eventually(t, func() bool { + rows := tk.MustQuery("SELECT COUNT(*) FROM t").Rows() + if len(rows) == 1 && rows[0][0].(string) == "5" { + return true + } + + logutil.BgLogger().Info("get row count", zap.String("count", rows[0][0].(string))) + return false + }, time.Second*5, time.Millisecond*100) + + require.Eventually(t, func() bool { + rows := tk.MustQuery("SELECT current_job_state FROM mysql.tidb_ttl_table_status").Rows() + if len(rows) == 1 && rows[0][0].(string) == "" { + return true + } + + tableStatus := tk.MustQuery("SELECT * FROM mysql.tidb_ttl_table_status").String() + logutil.BgLogger().Info("get job state", zap.String("tidb_ttl_table_status", tableStatus)) + return false + }, time.Second*5, time.Millisecond*100) + + logutil.BgLogger().Info("finish workload", zap.Duration("duration", time.Since(start))) + } + + logutil.BgLogger().Info("test finished") + stopTestCh <- struct{}{} + close(stopTestCh) + + wg.Wait() +} diff --git a/pkg/ttl/ttlworker/job_manager_test.go b/pkg/ttl/ttlworker/job_manager_test.go index c40b83b589080..37102bc48863a 100644 --- a/pkg/ttl/ttlworker/job_manager_test.go +++ b/pkg/ttl/ttlworker/job_manager_test.go @@ -210,6 +210,11 @@ func (m *JobManager) ReportMetrics(se session.Session) { m.reportMetrics(se) } +// ID returns the id of JobManager +func (m *JobManager) ID() string { + return m.id +} + // CheckFinishedJob is an exported version of checkFinishedJob func (m *JobManager) CheckFinishedJob(se session.Session) { m.checkFinishedJob(se) @@ -695,3 +700,10 @@ func TestSplitCnt(t *testing.T) { } } } + +// SetTimeFormat sets the time format used by the test. +// Some tests require a greater precision than the default time format. We don't change it globally to avoid potential compatibility issues. +// Therefore, the format for most tests are also not changed, to make sure the tests can represent the real-world scenarios. +func SetTimeFormat(format string) { + timeFormat = format +} diff --git a/pkg/ttl/ttlworker/session_integration_test.go b/pkg/ttl/ttlworker/session_integration_test.go index c2842e303a17b..2d0a7b91f523d 100644 --- a/pkg/ttl/ttlworker/session_integration_test.go +++ b/pkg/ttl/ttlworker/session_integration_test.go @@ -35,7 +35,7 @@ import ( type fault interface { // shouldFault returns whether the session should fault this time. - shouldFault() bool + shouldFault(sql string) bool } var _ fault = &faultAfterCount{} @@ -46,7 +46,11 @@ type faultAfterCount struct { currentCount int } -func (f *faultAfterCount) shouldFault() bool { +func newFaultAfterCount(faultCount int) *faultAfterCount { + return &faultAfterCount{faultCount: faultCount} +} + +func (f *faultAfterCount) shouldFault(sql string) bool { if f.currentCount >= f.faultCount { return true } @@ -55,6 +59,23 @@ func (f *faultAfterCount) shouldFault() bool { return false } +type faultWithFilter struct { + filter func(string) bool + f fault +} + +func (f *faultWithFilter) shouldFault(sql string) bool { + if f.filter == nil || f.filter(sql) { + return f.f.shouldFault(sql) + } + + return false +} + +func newFaultWithFilter(filter func(string) bool, f fault) *faultWithFilter { + return &faultWithFilter{filter: filter, f: f} +} + // sessionWithFault is a session which will fail to execute SQL after successfully executing several SQLs. It's designed // to trigger every possible branch of returning error from `Execute` type sessionWithFault struct { @@ -97,19 +118,12 @@ func (s *sessionWithFault) ExecuteInternal(ctx context.Context, sql string, args } func (s *sessionWithFault) shouldFault(sql string) bool { - if s.fault.Load() == nil { + fault := s.fault.Load() + if fault == nil { return false } - // as a fault implementation may have side-effect, we should always call it before checking the SQL. - shouldFault := (*s.fault.Load()).shouldFault() - - // skip some local only sql, ref `getSession()` in `session.go` - if strings.HasPrefix(sql, "set tidb_") || strings.HasPrefix(sql, "set @@") { - return false - } - - return shouldFault + return (*fault).shouldFault(sql) } type faultSessionPool struct { @@ -144,6 +158,11 @@ func (f *faultSessionPool) Put(se pools.Resource) { } func (f *faultSessionPool) setFault(ft fault) { + if ft == nil { + f.fault.Store(nil) + return + } + f.fault.Store(&ft) } @@ -153,7 +172,14 @@ func TestGetSessionWithFault(t *testing.T) { pool := newFaultSessionPool(dom.SysSessionPool()) for i := 0; i < 50; i++ { - pool.setFault(&faultAfterCount{faultCount: i}) + pool.setFault(newFaultWithFilter(func(sql string) bool { + // skip some local only sql, ref `getSession()` in `session.go` + if strings.HasPrefix(sql, "set tidb_") || strings.HasPrefix(sql, "set @@") { + return false + } + return true + }, newFaultAfterCount(i))) + se, err := ttlworker.GetSessionForTest(pool) logutil.BgLogger().Info("get session", zap.Int("error after count", i), zap.Bool("session is nil", se == nil), zap.Bool("error is nil", err == nil)) require.True(t, se != nil || err != nil) diff --git a/pkg/ttl/ttlworker/task_manager.go b/pkg/ttl/ttlworker/task_manager.go index d8cedb6695849..2c7569cf988a1 100644 --- a/pkg/ttl/ttlworker/task_manager.go +++ b/pkg/ttl/ttlworker/task_manager.go @@ -118,8 +118,13 @@ type taskManager struct { } func newTaskManager(ctx context.Context, sessPool util.SessionPool, infoSchemaCache *cache.InfoSchemaCache, id string, store kv.Storage) *taskManager { + ctx = logutil.WithKeyValue(ctx, "ttl-worker", "task-manager") + if intest.InTest { + // in test environment, in the same log there will be multiple ttl managers, so we need to distinguish them + ctx = logutil.WithKeyValue(ctx, "ttl-worker", id) + } return &taskManager{ - ctx: logutil.WithKeyValue(ctx, "ttl-worker", "task-manager"), + ctx: ctx, sessPool: sessPool, id: id, @@ -374,7 +379,7 @@ loop: func (m *taskManager) peekWaitingScanTasks(se session.Session, now time.Time) ([]*cache.TTLTask, error) { intest.Assert(se.GetSessionVars().Location().String() == now.Location().String()) - sql, args := cache.PeekWaitingTTLTask(now.Add(-getTaskManagerHeartBeatExpireInterval())) + sql, args := cache.PeekWaitingTTLTask(now.Add(-2 * getTaskManagerHeartBeatInterval())) rows, err := se.ExecuteSQL(m.ctx, sql, args...) if err != nil { return nil, errors.Wrapf(err, "execute sql: %s", sql) @@ -412,7 +417,7 @@ func (m *taskManager) lockScanTask(se session.Session, task *cache.TTLTask, now if err != nil { return err } - if task.OwnerID != "" && !task.OwnerHBTime.Add(getTaskManagerHeartBeatExpireInterval()).Before(now) { + if task.OwnerID != "" && !task.OwnerHBTime.Add(2*getTaskManagerHeartBeatInterval()).Before(now) { return errors.WithStack(errAlreadyScheduled) }