feat(iter): Added context accepting variants of Map & ForEach #114

Open
wants to merge 18 commits into base: main
57 changes: 49 additions & 8 deletions iter/iter.go
@@ -1,10 +1,11 @@
package iter

import (
"context"
"runtime"
"sync/atomic"

"github.com/sourcegraph/conc"
"github.com/sourcegraph/conc/pool"
)

// defaultMaxGoroutines returns the default maximum number of
@@ -57,29 +58,69 @@ func ForEachIdx[T any](input []T, f func(int, *T)) { Iterator[T]{}.ForEachIdx(in
// ForEachIdx is the same as ForEach except it also provides the
// index of the element to the callback.
func (iter Iterator[T]) ForEachIdx(input []T, f func(int, *T)) {
_ = iter.ForEachIdxCtx(context.Background(), input, func(_ context.Context, idx int, input *T) error {
f(idx, input)
return nil
})
}

// ForEachCtx is the same as ForEach except it also accepts a context
// that it uses to manage the execution of tasks.
// The context is cancelled on task failure and the first error is returned.
func ForEachCtx[T any](octx context.Context, input []T, f func(context.Context, *T) error) error {
return Iterator[T]{}.ForEachCtx(octx, input, f)
}

// ForEachCtx is the same as ForEach except it also accepts a context
// that it uses to manage the execution of tasks.
// The context is cancelled on task failure and the first error is returned.
func (iter Iterator[T]) ForEachCtx(octx context.Context, input []T, f func(context.Context, *T) error) error {
return iter.ForEachIdxCtx(octx, input, func(ctx context.Context, _ int, input *T) error {
return f(ctx, input)
})
}

// ForEachIdxCtx is the same as ForEachIdx except it also accepts a context
// that it uses to manage the execution of tasks.
// The context is cancelled on task failure and the first error is returned.
func ForEachIdxCtx[T any](octx context.Context, input []T, f func(context.Context, int, *T) error) error {
return Iterator[T]{}.ForEachIdxCtx(octx, input, f)
}

// ForEachIdxCtx is the same as ForEachIdx except it also accepts a context
// that it uses to manage the execution of tasks.
// The context is cancelled on task failure and the first error is returned.
func (iter Iterator[T]) ForEachIdxCtx(octx context.Context, input []T, f func(context.Context, int, *T) error) error {
if iter.MaxGoroutines == 0 {
// iter is a value receiver and is hence safe to mutate
iter.MaxGoroutines = defaultMaxGoroutines()
}

numInput := len(input)
if iter.MaxGoroutines > numInput {
if iter.MaxGoroutines > numInput && numInput > 0 {
// No more concurrent tasks than the number of input items.
iter.MaxGoroutines = numInput
}

var idx atomic.Int64
// Create the task outside the loop to avoid extra closure allocations.
task := func() {
task := func(ctx context.Context) error {
i := int(idx.Add(1) - 1)
for ; i < numInput; i = int(idx.Add(1) - 1) {
f(i, &input[i])
for ; i < numInput && ctx.Err() == nil; i = int(idx.Add(1) - 1) {
if err := f(ctx, i, &input[i]); err != nil {
return err
}
}
return ctx.Err() // nil if the context was never cancelled
}

var wg conc.WaitGroup
runner := pool.New().
WithContext(octx).
WithCancelOnError().
WithFirstError().
WithMaxGoroutines(iter.MaxGoroutines)
Member commented:

I would prefer to not add a dependency on pool in the iter package. Because the iter package knows the number of inputs and the number of outputs in advance, it can be considerably more efficient than the pool package, which must work for an unbounded number of iterations.

I think it would be good to reconcile this PR with the patterns in #104. In particular, the FailFast flag should mean similar things whether iterating with ForEachErr or ForEachCtx.

Author replied:

The abstractions built into ContextPool offered a simple way to achieve the context-cancellation behavior I was looking for without having to duplicate code. Both iter and pool use a WaitGroup underneath. I'm not sure what efficiency gains come from having two separate implementations, but I'd be happy to switch it back to the original setup.

Additionally, as a further argument for using ContextPool, the FailFast behavior seems like it could be easily implemented by using the new bool to optionally call WithFirstError() on the underlying pool. Right now I call that by default assuming the caller is only interested in the first error when using a context.

I could add the *Err variants as extensions to what I have here that just return an error but don't require a context.

Member replied:

> the FailFast behavior seems like it could be easily implemented by using the new bool to optionally call WithFirstError() on the underlying pool

Related: #118

> The abstractions built into ContextPool offered a simple way to achieve the functionality I was looking for with regard to context cancellation without having to duplicate code

You might be right. The simplicity of using a Pool to back the iter package is likely more valuable than the minor efficiency gains we get from knowing the size of the set in advance (which basically boil down to allocating in advance and not needing a mutex).

Let me noodle on the design a bit and get back to ya. I'll probably open a draft that unifies this with #104 and the pool package in general, maybe just replacing the Iterator structs with configured pools.

for i := 0; i < iter.MaxGoroutines; i++ {
wg.Go(task)
runner.Go(task)
}
wg.Wait()
return runner.Wait()
}
77 changes: 70 additions & 7 deletions iter/iter_test.go
@@ -1,6 +1,8 @@
package iter

import (
"context"
"errors"
"fmt"
"strconv"
"sync/atomic"
@@ -70,16 +72,18 @@ func TestIterator(t *testing.T) {
})
}

func TestForEachIdx(t *testing.T) {
func TestForEachIdxCtx(t *testing.T) {
t.Parallel()

bgctx := context.Background()
t.Run("empty", func(t *testing.T) {
t.Parallel()
f := func() {
ints := []int{}
ForEachIdx(ints, func(i int, val *int) {
err := ForEachIdxCtx(bgctx, ints, func(ctx context.Context, i int, val *int) error {
panic("this should never be called")
})
require.NoError(t, err)
}
require.NotPanics(t, f)
})
@@ -88,33 +92,57 @@ func TestForEachIdx(t *testing.T) {
t.Parallel()
f := func() {
ints := []int{1}
ForEachIdx(ints, func(i int, val *int) {
panic("super bad thing happened")
})
ForEachIdxCtx(bgctx, ints,
func(ctx context.Context, i int, val *int) error {
panic("super bad thing happened")
})
}
require.Panics(t, f)
})

t.Run("mutating inputs is fine", func(t *testing.T) {
t.Parallel()
ints := []int{1, 2, 3, 4, 5}
ForEachIdx(ints, func(i int, val *int) {
err := ForEachIdxCtx(bgctx, ints, func(ctx context.Context, i int, val *int) error {
*val += 1
return nil
})
require.Equal(t, []int{2, 3, 4, 5, 6}, ints)
require.NoError(t, err)
})

t.Run("huge inputs", func(t *testing.T) {
t.Parallel()
ints := make([]int, 10000)
ForEachIdx(ints, func(i int, val *int) {
err := ForEachIdxCtx(bgctx, ints, func(ctx context.Context, i int, val *int) error {
*val = i
return nil
})
expected := make([]int, 10000)
for i := 0; i < 10000; i++ {
expected[i] = i
}
require.Equal(t, expected, ints)
require.NoError(t, err)
})

err1 := errors.New("error1")
err2 := errors.New("error2")

t.Run("first error is propagated", func(t *testing.T) {
t.Parallel()
ints := []int{1, 2, 3, 4, 5}
err := ForEachIdxCtx(bgctx, ints, func(ctx context.Context, i int, val *int) error {
if *val == 3 {
return err1
}
if *val == 4 {
return err2
}
return nil
})
require.ErrorIs(t, err, err1)
require.NotErrorIs(t, err, err2)
})
}

@@ -166,6 +194,41 @@ func TestForEach(t *testing.T) {
})
}

func TestForEachCtx(t *testing.T) {
t.Parallel()

bgctx := context.Background()
t.Run("mutating inputs is fine", func(t *testing.T) {
t.Parallel()
ints := []int{1, 2, 3, 4, 5}
err := ForEachCtx(bgctx, ints, func(ctx context.Context, val *int) error {
*val += 1
return nil
})
require.Equal(t, []int{2, 3, 4, 5, 6}, ints)
require.NoError(t, err)
})

err1 := errors.New("error1")
err2 := errors.New("error2")

t.Run("first error is propagated", func(t *testing.T) {
t.Parallel()
ints := []int{1, 2, 3, 4, 5}
err := ForEachCtx(bgctx, ints, func(ctx context.Context, val *int) error {
if *val == 3 {
return err1
}
if *val == 4 {
return err2
}
return nil
})
require.ErrorIs(t, err, err1)
require.NotErrorIs(t, err, err2)
})
}

func BenchmarkForEach(b *testing.B) {
for _, count := range []int{0, 1, 8, 100, 1000, 10000, 100000} {
b.Run(strconv.Itoa(count), func(b *testing.B) {
35 changes: 28 additions & 7 deletions iter/map.go
@@ -1,6 +1,7 @@
package iter

import (
"context"
"sync"

"github.com/sourcegraph/conc/internal/multierror"
@@ -25,9 +26,8 @@ func Map[T, R any](input []T, f func(*T) R) []R {
//
// Map uses up to the configured Mapper's maximum number of goroutines.
func (m Mapper[T, R]) Map(input []T, f func(*T) R) []R {
res := make([]R, len(input))
Iterator[T](m).ForEachIdx(input, func(i int, t *T) {
res[i] = f(t)
res, _ := m.MapErr(input, func(t *T) (R, error) {
return f(t), nil
})
return res
}
@@ -47,19 +47,40 @@ func MapErr[T, R any](input []T, f func(*T) (R, error)) ([]R, error) {
// Map uses up to the configured Mapper's maximum number of goroutines.
func (m Mapper[T, R]) MapErr(input []T, f func(*T) (R, error)) ([]R, error) {
var (
res = make([]R, len(input))
errMux sync.Mutex
errs error
)
Iterator[T](m).ForEachIdx(input, func(i int, t *T) {
var err error
res[i], err = f(t)
// MapErr handles its own errors by accumulating them as a multierror, ignoring the error from MapErrCtx
res, _ := m.MapErrCtx(context.Background(), input, func(ctx context.Context, t *T) (R, error) {
ires, err := f(t)
if err != nil {
errMux.Lock()
// TODO: use stdlib errors once multierrors land in go 1.20
errs = multierror.Join(errs, err)
errMux.Unlock()
}
return ires, nil
})
return res, errs
}

// MapErrCtx is the same as MapErr except it also accepts a context
// that it uses to manage the execution of tasks.
// The context is cancelled on task failure and the first error is returned.
func MapErrCtx[T, R any](octx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) {
return Mapper[T, R]{}.MapErrCtx(octx, input, f)
}

// MapErrCtx is the same as MapErr except it also accepts a context
// that it uses to manage the execution of tasks.
// The context is cancelled on task failure and the first error is returned.
func (m Mapper[T, R]) MapErrCtx(octx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) {
var (
res = make([]R, len(input))
)
return res, Iterator[T](m).ForEachIdxCtx(octx, input, func(ctx context.Context, i int, t *T) error {
var err error
res[i], err = f(ctx, t)
return err
})
}