Skip to content

Commit

Permalink
Show completion estimate during backfill (#567)
Browse files Browse the repository at this point in the history
Instead of only showing the number of rows backfilled, show an estimate
of the total number of rows completed as a percentage.

For example:

```
1500 records complete... (12.34%)
```

It attempts to estimate the total number of rows but will fall back to a
full scan if the number of rows estimated is zero.

Closes #492
  • Loading branch information
ryanslade authored Jan 8, 2025
1 parent e6e2ee5 commit 952fb23
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 5 deletions.
16 changes: 14 additions & 2 deletions cmd/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,20 @@ func runMigrationFromFile(ctx context.Context, m *roll.Roll, fileName string, co

func runMigration(ctx context.Context, m *roll.Roll, migration *migrations.Migration, complete bool) error {
sp, _ := pterm.DefaultSpinner.WithText("Starting migration...").Start()
cb := func(n int64) {
sp.UpdateText(fmt.Sprintf("%d records complete...", n))
cb := func(n int64, total int64) {
var percent float64
if total > 0 {
percent = float64(n) / float64(total) * 100
}
if percent > 100 {
// This can happen if we're on the last batch
percent = 100
}
if total > 0 {
sp.UpdateText(fmt.Sprintf("%d records complete... (%.2f%%)", n, percent))
} else {
sp.UpdateText(fmt.Sprintf("%d records complete...", n))
}
}

err := m.Start(ctx, migration, cb)
Expand Down
34 changes: 34 additions & 0 deletions pkg/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const (

type DB interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
WithRetryableTransaction(ctx context.Context, f func(context.Context, *sql.Tx) error) error
Close() error
}
Expand Down Expand Up @@ -52,6 +53,28 @@ func (db *RDB) ExecContext(ctx context.Context, query string, args ...interface{
}
}

// QueryContext runs query through the underlying sql.DB and returns its rows.
// When the query fails with the Postgres lock-not-available error code
// (lock_timeout), it waits for the next backoff interval and tries again; any
// other error is returned to the caller unchanged. If ctx ends while waiting
// between attempts, the error from sleepCtx is returned instead.
//
// On success the caller owns the returned *sql.Rows.
func (db *RDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
	retry := backoff.New(maxBackoffDuration, backoffInterval)

	for {
		rows, queryErr := db.DB.QueryContext(ctx, query, args...)
		if queryErr == nil {
			return rows, nil
		}

		// Anything other than a lock timeout is not retryable.
		var pqErr *pq.Error
		if !errors.As(queryErr, &pqErr) || pqErr.Code != lockNotAvailableErrorCode {
			return nil, queryErr
		}

		// Lock timeout: back off, then loop around for another attempt.
		if sleepErr := sleepCtx(ctx, retry.Duration()); sleepErr != nil {
			return nil, sleepErr
		}
	}
}

// WithRetryableTransaction runs `f` in a transaction, retrying on lock_timeout errors.
func (db *RDB) WithRetryableTransaction(ctx context.Context, f func(context.Context, *sql.Tx) error) error {
b := backoff.New(maxBackoffDuration, backoffInterval)
Expand Down Expand Up @@ -95,3 +118,14 @@ func sleepCtx(ctx context.Context, d time.Duration) error {
return nil
}
}

// ScanFirstValue scans the first column of the first row of rows into dest,
// on the assumption that rows holds a single row with a single value.
//
// If rows contains no row at all, dest is left untouched and the result of
// rows.Err() is returned (nil when iteration simply found nothing).
//
// rows is always closed before ScanFirstValue returns. Only one call to Next
// is made, so without an explicit Close a non-empty result set would stay open
// and keep its database connection checked out of the pool. Callers may still
// call rows.Close themselves; closing twice is harmless.
func ScanFirstValue[T any](rows *sql.Rows, dest *T) error {
	defer rows.Close()

	if rows.Next() {
		if err := rows.Scan(dest); err != nil {
			return err
		}
	}
	return rows.Err()
}
48 changes: 48 additions & 0 deletions pkg/db/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/xataio/pgroll/internal/testutils"
Expand Down Expand Up @@ -61,6 +62,53 @@ func TestExecContextWhenContextCancelled(t *testing.T) {
})
}

// TestQueryContext checks that RDB.QueryContext keeps retrying a query that
// hits lock_timeout until the conflicting lock is released, and then returns
// usable rows.
func TestQueryContext(t *testing.T) {
	t.Parallel()

	testutils.WithConnectionToContainer(t, func(conn *sql.DB, connStr string) {
		ctx := context.Background()
		// create a table on which an exclusive lock is held for 2 seconds
		setupTableLock(t, connStr, 2*time.Second)

		// set the lock timeout to 100ms
		ensureLockTimeout(t, conn, 100)

		// execute a query that should retry until the lock is released
		rdb := &db.RDB{DB: conn}
		rows, err := rdb.QueryContext(ctx, "SELECT COUNT(*) FROM test")
		require.NoError(t, err)
		// Release the result set (and its connection) when the test body ends.
		defer rows.Close()

		var count int
		err = db.ScanFirstValue(rows, &count)
		assert.NoError(t, err)
		assert.Equal(t, 0, count)
	})
}

// TestQueryContextWhenContextCancelled checks that RDB.QueryContext stops
// retrying and returns an error once its context is cancelled while the
// conflicting lock is still held.
func TestQueryContextWhenContextCancelled(t *testing.T) {
	t.Parallel()

	testutils.WithConnectionToContainer(t, func(conn *sql.DB, connStr string) {
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		// create a table on which an exclusive lock is held for 2 seconds
		setupTableLock(t, connStr, 2*time.Second)

		// set the lock timeout to 100ms
		ensureLockTimeout(t, conn, 100)

		rdb := &db.RDB{DB: conn}

		// Cancel the context while the lock is still held. AfterFunc already
		// runs cancel asynchronously, so no extra goroutine is needed.
		timer := time.AfterFunc(500*time.Millisecond, cancel)
		defer timer.Stop()

		// the query keeps retrying on lock_timeout until the context is cancelled
		_, err := rdb.QueryContext(ctx, "SELECT COUNT(*) FROM test")
		// Errorf only asserts that err is non-nil; the string is the message
		// shown if the assertion fails, not a pattern matched against err.
		require.Errorf(t, err, "context canceled")
	})
}

func TestWithRetryableTransaction(t *testing.T) {
t.Parallel()

Expand Down
47 changes: 46 additions & 1 deletion pkg/migrations/backfill.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,18 @@ func Backfill(ctx context.Context, conn db.DB, table *schema.Table, batchSize in
return BackfillNotPossibleError{Table: table.Name}
}

total, err := getRowCount(ctx, conn, table.Name)
if err != nil {
return fmt.Errorf("get row count for %q: %w", table.Name, err)
}

// Create a batcher for the table.
b := newBatcher(table, batchSize)

// Update each batch of rows, invoking callbacks for each one.
for batch := 0; ; batch++ {
for _, cb := range cbs {
cb(int64(batch * batchSize))
cb(int64(batch*batchSize), total)
}

if err := b.updateBatch(ctx, conn); err != nil {
Expand All @@ -55,6 +60,46 @@ func Backfill(ctx context.Context, conn db.DB, table *schema.Table, batchSize in
return nil
}

// getRowCount returns the number of rows in tableName so that backfill
// progress can be reported against a total.
//
// It first looks up the n_live_tup estimate in pg_stat_user_tables for the
// current schema. If that estimate is positive it is returned as-is. If it is
// zero, or no statistics row is found for the table, it falls back to an exact
// `SELECT count(*)`, which requires a full table scan.
//
// Each result set is closed as soon as its value has been read, so that no
// database connection stays checked out across the follow-up queries or after
// the function returns.
func getRowCount(ctx context.Context, conn db.DB, tableName string) (int64, error) {
	// Try and get estimated row count
	var currentSchema string
	rows, err := conn.QueryContext(ctx, "select current_schema()")
	if err != nil {
		return 0, fmt.Errorf("getting current schema: %w", err)
	}
	err = db.ScanFirstValue(rows, &currentSchema)
	rows.Close() // done with this result set; Close is safe to call more than once
	if err != nil {
		return 0, fmt.Errorf("scanning current schema: %w", err)
	}

	var total int64
	rows, err = conn.QueryContext(ctx, `
SELECT n_live_tup AS estimate
FROM pg_stat_user_tables
WHERE schemaname = $1 AND relname = $2`, currentSchema, tableName)
	if err != nil {
		return 0, fmt.Errorf("getting row count estimate for %q: %w", tableName, err)
	}
	err = db.ScanFirstValue(rows, &total)
	rows.Close()
	if err != nil {
		return 0, fmt.Errorf("scanning row count estimate for %q: %w", tableName, err)
	}
	if total > 0 {
		return total, nil
	}

	// If the estimate is zero, fall back to full count
	// NOTE(review): tableName is interpolated without identifier quoting, so
	// mixed-case or otherwise unusual table names will not resolve here —
	// consider quoting it the same way the rest of the package does.
	rows, err = conn.QueryContext(ctx, fmt.Sprintf(`SELECT count(*) from %s`, tableName))
	if err != nil {
		return 0, fmt.Errorf("getting row count for %q: %w", tableName, err)
	}
	err = db.ScanFirstValue(rows, &total)
	rows.Close()
	if err != nil {
		return 0, fmt.Errorf("scanning row count for %q: %w", tableName, err)
	}

	return total, nil
}

// checkBackfill will return an error if the backfill operation is not supported.
func checkBackfill(table *schema.Table) error {
cols := getIdentityColumns(table)
Expand Down
2 changes: 1 addition & 1 deletion pkg/migrations/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/xataio/pgroll/pkg/schema"
)

type CallbackFn func(int64)
// CallbackFn is a progress callback. Backfill invokes it before updating each
// batch, passing the number of rows handled so far as done and the table's
// row count (which may be an estimate) as total.
type CallbackFn func(done, total int64)

// Operation is an operation that can be applied to a schema
type Operation interface {
Expand Down
2 changes: 1 addition & 1 deletion pkg/roll/execute_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@ func TestCallbacksAreInvokedOnMigrationStart(t *testing.T) {

// Define a mock callback
invoked := false
cb := func(n int64) { invoked = true }
cb := func(n, total int64) { invoked = true }

// Start a migration that requires a backfill
err = mig.Start(ctx, &migrations.Migration{
Expand Down

0 comments on commit 952fb23

Please sign in to comment.