diff --git a/internal/sequencer/script/script_test.go b/internal/sequencer/script/script_test.go index a4fc0a31..ee5d75dc 100644 --- a/internal/sequencer/script/script_test.go +++ b/internal/sequencer/script/script_test.go @@ -252,26 +252,12 @@ func TestUserScriptSequencer(t *testing.T) { } } -func requireStagingTblRowCnt( - ctx *stopper.Context, stagingTbl ident.Table, expectedCnt int, pool *types.StagingPool, -) error { - var res int - if err := pool.QueryRow(ctx, "SELECT COUNT(*) FROM %s", stagingTbl.Table()).Scan(&res); err != nil { - return errors.Wrapf(err, "failed to query rows for staging table %s", stagingTbl.Table()) - } - if res != expectedCnt { - return fmt.Errorf("expected %d rows, got %d", expectedCnt, res) - } - log.Infof("finished verify rows for staging table %s", stagingTbl.Table()) - return nil -} - func testUserScriptSequencer(t *testing.T, baseMode switcher.Mode) { r := require.New(t) // Create a basic test fixture. // TODO(janexing): I just randomly set the refresh delay. - fixture, err := all.NewFixtureWithRefresh(t, all.RefreshDelay(500*time.Millisecond)) + fixture, err := all.NewFixture(t) r.NoError(err) ctx := fixture.Context @@ -403,26 +389,27 @@ api.configureTable("t_2", { }, }, )) - // t.Logf("pushed for k %d", i) } - r.NoError(retryAttempt.Do(ctx, func() error { - for _, tgt := range tgts { - res, err := base.GetRowCount(ctx, fixture.StagingPool, stage.StagingTable(fixture.StagingDB.Schema(), tgt)) - if err != nil { - return err - } - if res != numEmits { - return fmt.Errorf("expected %d rows for table %s, got %d", numEmits, tgt, res) + if baseMode != switcher.ModeImmediate { + r.NoError(retryAttempt.Do(ctx, func() error { + for _, tgt := range tgts { + res, err := base.GetRowCountWithPredicate(ctx, fixture.StagingPool, stage.StagingTable(fixture.StagingDB.Schema(), tgt), fmt.Sprintf("nanos <= %d", numEmits)) + if err != nil { + return err + } + if res != numEmits { + return fmt.Errorf("expected %d rows for table %s, got %d", numEmits, tgt, res) + } } - } - return nil - }, func(err error) { - t.Logf("retrying checking staging tbl count: %s", err.Error()) - })) + return nil + }, func(err error) { + t.Logf("retrying checking staging tbl count: %s", err.Error()) + })) - require.NoError(t, checkpointGroup.Advance(ctx, checkpointGroup.TableGroup().Name, endTime)) - t.Logf("advanced bounds to %s", endTime) + require.NoError(t, checkpointGroup.Advance(ctx, checkpointGroup.TableGroup().Name, endTime)) + t.Logf("advanced bounds to %s", endTime) + } // Wait for (async) replication. for { @@ -485,11 +472,24 @@ api.configureTable("t_2", { log.Infof("[fakeSrc] pushed delete for %d", i+1) } - fakeSrc.PushTemporalBatch(ctx, &types.TemporalBatch{ - Time: endTime, - }) + if baseMode != switcher.ModeImmediate { + r.NoError(retryAttempt.Do(ctx, func() error { + res, err := base.GetRowCountWithPredicate(ctx, fixture.StagingPool, stage.StagingTable(fixture.StagingDB.Schema(), tgts[0]), "nanos >= 1000") + if err != nil { + return err + } + if res != numEmits { + return fmt.Errorf("expected %d rows for table %s, got %d", numEmits, tgts[0], res) + } + return nil + }, func(err error) { + t.Logf("retrying checking staging tbl count: %s", err.Error()) + })) - t.Logf("pushed for sentinel") + } + + require.NoError(t, checkpointGroup.Advance(ctx, checkpointGroup.TableGroup().Name, endTime)) + t.Logf("advanced bounds to %s", endTime) // Wait for (async) replication for the tables. progress, progressMade = stats.Get()