diff --git a/iter/iter.go b/iter/iter.go index 124b4f9..4f22e43 100644 --- a/iter/iter.go +++ b/iter/iter.go @@ -1,10 +1,11 @@ package iter import ( + "context" "runtime" "sync/atomic" - "github.com/sourcegraph/conc" + "github.com/sourcegraph/conc/pool" ) // defaultMaxGoroutines returns the default maximum number of @@ -57,29 +58,69 @@ func ForEachIdx[T any](input []T, f func(int, *T)) { Iterator[T]{}.ForEachIdx(in // ForEachIdx is the same as ForEach except it also provides the // index of the element to the callback. func (iter Iterator[T]) ForEachIdx(input []T, f func(int, *T)) { + _ = iter.ForEachIdxCtx(context.Background(), input, func(_ context.Context, idx int, input *T) error { + f(idx, input) + return nil + }) +} + +// ForEachCtx is the same as ForEach except it also accepts a context +// that it uses to manages the execution of tasks. +// The context is cancelled on task failure and the first error is returned. +func ForEachCtx[T any](ctx context.Context, input []T, f func(context.Context, *T) error) error { + return Iterator[T]{}.ForEachCtx(ctx, input, f) +} + +// ForEachCtx is the same as ForEach except it also accepts a context +// that it uses to manages the execution of tasks. +// The context is cancelled on task failure and the first error is returned. +func (iter Iterator[T]) ForEachCtx(ctx context.Context, input []T, f func(context.Context, *T) error) error { + return iter.ForEachIdxCtx(ctx, input, func(innerctx context.Context, _ int, input *T) error { + return f(innerctx, input) + }) +} + +// ForEachIdxCtx is the same as ForEachIdx except it also accepts a context +// that it uses to manages the execution of tasks. +// The context is cancelled on task failure and the first error is returned. +func ForEachIdxCtx[T any](ctx context.Context, input []T, f func(context.Context, int, *T) error) error { + return Iterator[T]{}.ForEachIdxCtx(ctx, input, f) +} + +// ForEachIdxCtx is the same as ForEachIdx except it also accepts a context +// that it uses to manages the execution of tasks. +// The context is cancelled on task failure and the first error is returned. +func (iter Iterator[T]) ForEachIdxCtx(ctx context.Context, input []T, f func(context.Context, int, *T) error) error { if iter.MaxGoroutines == 0 { // iter is a value receiver and is hence safe to mutate iter.MaxGoroutines = defaultMaxGoroutines() } numInput := len(input) - if iter.MaxGoroutines > numInput { + if iter.MaxGoroutines > numInput && numInput > 0 { // No more concurrent tasks than the number of input items. iter.MaxGoroutines = numInput } var idx atomic.Int64 // Create the task outside the loop to avoid extra closure allocations. - task := func() { + task := func(innerctx context.Context) error { i := int(idx.Add(1) - 1) - for ; i < numInput; i = int(idx.Add(1) - 1) { - f(i, &input[i]) + for ; i < numInput && innerctx.Err() == nil; i = int(idx.Add(1) - 1) { + if err := f(innerctx, i, &input[i]); err != nil { + return err + } } + return innerctx.Err() // nil if the context was never cancelled } - var wg conc.WaitGroup + runner := pool.New(). + WithContext(ctx). + WithCancelOnError(). + WithFirstError(). + WithMaxGoroutines(iter.MaxGoroutines) for i := 0; i < iter.MaxGoroutines; i++ { - wg.Go(task) + runner.Go(task) } - wg.Wait() + return runner.Wait() } diff --git a/iter/iter_test.go b/iter/iter_test.go index 48fc8bb..12ad580 100644 --- a/iter/iter_test.go +++ b/iter/iter_test.go @@ -1,6 +1,8 @@ package iter_test import ( + "context" + "errors" "fmt" "strconv" "sync/atomic" @@ -72,16 +74,18 @@ func TestIterator(t *testing.T) { }) } -func TestForEachIdx(t *testing.T) { +func TestForEachIdxCtx(t *testing.T) { t.Parallel() + bgctx := context.Background() t.Run("empty", func(t *testing.T) { t.Parallel() f := func() { ints := []int{} - iter.ForEachIdx(ints, func(i int, val *int) { + err := iter.ForEachIdxCtx(bgctx, ints, func(ctx context.Context, i int, val *int) error { panic("this should never be called") }) + require.NoError(t, err) } require.NotPanics(t, f) }) @@ -90,9 +94,10 @@ func TestForEachIdx(t *testing.T) { t.Parallel() f := func() { ints := []int{1} - iter.ForEachIdx(ints, func(i int, val *int) { - panic("super bad thing happened") - }) + _ = iter.ForEachIdxCtx(bgctx, ints, + func(ctx context.Context, i int, val *int) error { + panic("super bad thing happened") + }) } require.Panics(t, f) }) @@ -100,23 +105,46 @@ func TestForEachIdx(t *testing.T) { t.Run("mutating inputs is fine", func(t *testing.T) { t.Parallel() ints := []int{1, 2, 3, 4, 5} - iter.ForEachIdx(ints, func(i int, val *int) { + err := iter.ForEachIdxCtx(bgctx, ints, func(ctx context.Context, i int, val *int) error { *val += 1 + return nil }) require.Equal(t, []int{2, 3, 4, 5, 6}, ints) + require.NoError(t, err) }) t.Run("huge inputs", func(t *testing.T) { t.Parallel() ints := make([]int, 10000) - iter.ForEachIdx(ints, func(i int, val *int) { + err := iter.ForEachIdxCtx(bgctx, ints, func(ctx context.Context, i int, val *int) error { *val = i + return nil }) expected := make([]int, 10000) for i := 0; i < 10000; i++ { expected[i] = i } require.Equal(t, expected, ints) + require.NoError(t, err) + }) + + err1 := errors.New("error1") + err2 := errors.New("error2") + + t.Run("first error is propagated", func(t *testing.T) { + t.Parallel() + ints := []int{1, 2, 3, 4, 5} + err := iter.ForEachIdxCtx(bgctx, ints, func(ctx context.Context, i int, val *int) error { + if *val == 3 { + return err1 + } + if *val == 4 { + return err2 + } + return nil + }) + require.ErrorIs(t, err, err1) + require.NotErrorIs(t, err, err2) }) } @@ -168,6 +196,41 @@ func TestForEach(t *testing.T) { }) } +func TestForEachCtx(t *testing.T) { + t.Parallel() + + bgctx := context.Background() + t.Run("mutating inputs is fine", func(t *testing.T) { + t.Parallel() + ints := []int{1, 2, 3, 4, 5} + err := iter.ForEachCtx(bgctx, ints, func(ctx context.Context, val *int) error { + *val += 1 + return nil + }) + require.Equal(t, []int{2, 3, 4, 5, 6}, ints) + require.NoError(t, err) + }) + + err1 := errors.New("error1") + err2 := errors.New("error2") + + t.Run("first error is propagated", func(t *testing.T) { + t.Parallel() + ints := []int{1, 2, 3, 4, 5} + err := iter.ForEachCtx(bgctx, ints, func(ctx context.Context, val *int) error { + if *val == 3 { + return err1 + } + if *val == 4 { + return err2 + } + return nil + }) + require.ErrorIs(t, err, err1) + require.NotErrorIs(t, err, err2) + }) +} + func BenchmarkForEach(b *testing.B) { for _, count := range []int{0, 1, 8, 100, 1000, 10000, 100000} { b.Run(strconv.Itoa(count), func(b *testing.B) { diff --git a/iter/map.go b/iter/map.go index af8c3b2..4373ddc 100644 --- a/iter/map.go +++ b/iter/map.go @@ -1,6 +1,7 @@ package iter import ( + "context" "errors" "sync" ) @@ -24,9 +25,8 @@ func Map[T, R any](input []T, f func(*T) R) []R { // // Map uses up to the configured Mapper's maximum number of goroutines. func (m Mapper[T, R]) Map(input []T, f func(*T) R) []R { - res := make([]R, len(input)) - Iterator[T](m).ForEachIdx(input, func(i int, t *T) { - res[i] = f(t) + res, _ := m.MapCtx(context.Background(), input, func(_ context.Context, t *T) (R, error) { + return f(t), nil }) return res } @@ -46,18 +46,39 @@ func MapErr[T, R any](input []T, f func(*T) (R, error)) ([]R, error) { // Map uses up to the configured Mapper's maximum number of goroutines. func (m Mapper[T, R]) MapErr(input []T, f func(*T) (R, error)) ([]R, error) { var ( - res = make([]R, len(input)) errMux sync.Mutex errs []error ) - Iterator[T](m).ForEachIdx(input, func(i int, t *T) { - var err error - res[i], err = f(t) + // MapErr handles its own errors by accumulating them as a multierror, ignoring the error from MapCtx which is only the first error + res, _ := m.MapCtx(context.Background(), input, func(ctx context.Context, t *T) (R, error) { + ires, err := f(t) if err != nil { errMux.Lock() errs = append(errs, err) errMux.Unlock() } + return ires, nil }) return res, errors.Join(errs...) } + +// MapCtx is the same as Map except it also accepts a context +// that it uses to manages the execution of tasks. +// The context is cancelled on task failure and the first error is returned. +func MapCtx[T, R any](ctx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) { + return Mapper[T, R]{}.MapCtx(ctx, input, f) +} + +// MapCtx is the same as Map except it also accepts a context +// that it uses to manages the execution of tasks. +// The context is cancelled on task failure and the first error is returned. +func (m Mapper[T, R]) MapCtx(ctx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) { + var ( + res = make([]R, len(input)) + ) + return res, Iterator[T](m).ForEachIdxCtx(ctx, input, func(innerctx context.Context, i int, t *T) error { + var err error + res[i], err = f(innerctx, t) + return err + }) +} diff --git a/iter/map_test.go b/iter/map_test.go index 5749e9a..4fc4674 100644 --- a/iter/map_test.go +++ b/iter/map_test.go @@ -1,6 +1,7 @@ package iter_test import ( + "context" "errors" "fmt" "testing" @@ -132,7 +133,7 @@ func TestMapErr(t *testing.T) { }) err1 := errors.New("error1") - err2 := errors.New("error1") + err2 := errors.New("error2") t.Run("error is propagated", func(t *testing.T) { t.Parallel() @@ -148,7 +149,7 @@ func TestMapErr(t *testing.T) { require.Equal(t, []int{1, 2, 3, 4, 5}, ints) }) - t.Run("multiple errors are propagated", func(t *testing.T) { + t.Run("first errors are propagated", func(t *testing.T) { t.Parallel() ints := []int{1, 2, 3, 4, 5} res, err := iter.MapErr(ints, func(val *int) (int, error) { @@ -166,17 +167,106 @@ func TestMapErr(t *testing.T) { require.Equal(t, []int{2, 3, 0, 0, 6}, res) require.Equal(t, []int{1, 2, 3, 4, 5}, ints) }) +} + +func TestMapCtx(t *testing.T) { + t.Parallel() + + bgctx := context.Background() + t.Run("empty", func(t *testing.T) { + t.Parallel() + f := func() { + ints := []int{} + res, err := iter.MapCtx(bgctx, ints, func(ctx context.Context, val *int) (int, error) { + panic("this should never be called") + }) + require.NoError(t, err) + require.Equal(t, ints, res) + } + require.NotPanics(t, f) + }) + + t.Run("panic is propagated", func(t *testing.T) { + t.Parallel() + f := func() { + ints := []int{1} + _, _ = iter.MapCtx(bgctx, ints, func(ctx context.Context, val *int) (int, error) { + panic("super bad thing happened") + }) + } + require.Panics(t, f) + }) + + t.Run("mutating inputs is fine, though not recommended", func(t *testing.T) { + t.Parallel() + ints := []int{1, 2, 3, 4, 5} + res, err := iter.MapCtx(bgctx, ints, func(ctx context.Context, val *int) (int, error) { + *val += 1 + return 0, nil + }) + require.NoError(t, err) + require.Equal(t, []int{2, 3, 4, 5, 6}, ints) + require.Equal(t, []int{0, 0, 0, 0, 0}, res) + }) + + t.Run("basic increment", func(t *testing.T) { + t.Parallel() + ints := []int{1, 2, 3, 4, 5} + res, err := iter.MapCtx(bgctx, ints, func(ctx context.Context, val *int) (int, error) { + return *val + 1, nil + }) + require.NoError(t, err) + require.Equal(t, []int{2, 3, 4, 5, 6}, res) + require.Equal(t, []int{1, 2, 3, 4, 5}, ints) + }) + + err1 := errors.New("error1") + err2 := errors.New("error2") + + t.Run("error is propagated", func(t *testing.T) { + t.Parallel() + ints := []int{1, 2, 3, 4, 5} + res, err := iter.MapCtx(bgctx, ints, func(ctx context.Context, val *int) (int, error) { + if *val == 3 { + return 0, err1 + } + return *val + 1, nil + }) + require.ErrorIs(t, err, err1) + require.NotErrorIs(t, err, err2) + require.Equal(t, []int{2, 3, 0, 0, 0}, res) + require.Equal(t, []int{1, 2, 3, 4, 5}, ints) + }) + + t.Run("first error is propagated", func(t *testing.T) { + t.Parallel() + ints := []int{1, 2, 3, 4, 5} + res, err := iter.MapCtx(bgctx, ints, func(ctx context.Context, val *int) (int, error) { + if *val == 3 { + return 0, err1 + } + if *val == 4 { + return 0, err2 + } + return *val + 1, nil + }) + require.ErrorIs(t, err, err1) + require.NotErrorIs(t, err, err2) + require.Equal(t, []int{2, 3, 0, 0, 0}, res) + require.Equal(t, []int{1, 2, 3, 4, 5}, ints) + }) t.Run("huge inputs", func(t *testing.T) { t.Parallel() ints := make([]int, 10000) - res := iter.Map(ints, func(val *int) int { - return 1 + res, err := iter.MapCtx(bgctx, ints, func(ctx context.Context, val *int) (int, error) { + return 1, nil }) expected := make([]int, 10000) for i := 0; i < 10000; i++ { expected[i] = 1 } require.Equal(t, expected, res) + require.NoError(t, err) }) }