From e39f6534f9e7f1566a26fdeef4730681defd2d74 Mon Sep 17 00:00:00 2001 From: Julien Duchesne Date: Thu, 10 Oct 2024 12:10:09 -0400 Subject: [PATCH] `ReusableGoroutinesPool`: Add protection to `Close` There is a data race in Mimir where the pool can be closed while `Go` is still being called on it: https://github.com/grafana/mimir/blob/0c6070552517bda1ccb97b8fc84ca50c591a71f7/pkg/distributor/distributor.go#L528-L532 Rather than handling this in Mimir, this can be handled in the util directly. In Mimir, we'll be able to define this pool as `wp := concurrency.NewReusableGoroutinesPool(cfg.ReusableIngesterPushWorkers).WithClosedAction(concurrency.ErrorWhenClosed)` and we can check for errors wherever we call `GoErr` on it. --- concurrency/worker.go | 60 +++++++++++++++++++++++++++++++--- concurrency/worker_test.go | 66 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 122 insertions(+), 4 deletions(-) diff --git a/concurrency/worker.go b/concurrency/worker.go index f40f03348..b407d1f8c 100644 --- a/concurrency/worker.go +++ b/concurrency/worker.go @@ -1,5 +1,18 @@ package concurrency +import ( + "errors" + "sync" +) + +type ClosedAction int + +const ( + PanicWhenClosed ClosedAction = iota + ErrorWhenClosed + SpawnNewGoroutineWhenClosed +) + // NewReusableGoroutinesPool creates a new worker pool with the given size. // These workers will run the workloads passed through Go() calls. // If all workers are busy, Go() will spawn a new goroutine to run the workload. @@ -17,22 +30,61 @@ func NewReusableGoroutinesPool(size int) *ReusableGoroutinesPool { return p } +func (p *ReusableGoroutinesPool) WithClosedAction(action ClosedAction) *ReusableGoroutinesPool { + p.closedAction = action + return p +} + type ReusableGoroutinesPool struct { - jobs chan func() + jobsMu sync.Mutex + closed bool + closedAction ClosedAction + jobs chan func() } // Go will run the given function in a worker of the pool. -// If all workers are busy, Go() will spawn a new goroutine to run the workload. 
+// For backward compatibility, the error returned by GoErr is discarded; a panic caused by PanicWhenClosed still propagates. func (p *ReusableGoroutinesPool) Go(f func()) { + _ = p.GoErr(f) +} + +// GoErr will run the given function in a worker of the pool. +// If all workers are busy, GoErr() will spawn a new goroutine to run the workload. +// If the pool is closed, the configured ClosedAction decides the outcome: PanicWhenClosed panics, ErrorWhenClosed drops the workload and returns an error, and SpawnNewGoroutineWhenClosed runs the workload in a new goroutine and still returns an error. +func (p *ReusableGoroutinesPool) GoErr(f func()) error { + p.jobsMu.Lock() + defer p.jobsMu.Unlock() + + if p.closed { + switch p.closedAction { + case PanicWhenClosed: + panic("tried to run a workload on a closed ReusableGoroutinesPool. Use a different ClosedAction to avoid this panic.") + case ErrorWhenClosed: + msg := "tried to run a workload on a closed ReusableGoroutinesPool, dropping the workload" + return errors.New(msg) + case SpawnNewGoroutineWhenClosed: + msg := "tried to run a workload on a closed ReusableGoroutinesPool, spawning a new goroutine to run the workload" + go f() + return errors.New(msg) + } + } + select { case p.jobs <- f: default: go f() } + + return nil } // Close stops the workers of the pool. -// No new Do() calls should be performed after calling Close(). +// No new Go() calls should be performed after calling Close(). // Close does NOT wait for all jobs to finish, it is the caller's responsibility to ensure that in the provided workloads. // Close is intended to be used in tests to ensure that no goroutines are leaked. 
-func (p *ReusableGoroutinesPool) Close() { close(p.jobs) } +func (p *ReusableGoroutinesPool) Close() { + p.jobsMu.Lock() + defer p.jobsMu.Unlock() + p.closed = true + close(p.jobs) +} diff --git a/concurrency/worker_test.go b/concurrency/worker_test.go index 338062055..5e7b035c4 100644 --- a/concurrency/worker_test.go +++ b/concurrency/worker_test.go @@ -4,10 +4,12 @@ import ( "regexp" "runtime" "strings" + "sync" "testing" "time" "github.com/stretchr/testify/require" + "go.uber.org/atomic" ) func TestReusableGoroutinesPool(t *testing.T) { @@ -59,3 +61,67 @@ func TestReusableGoroutinesPool(t *testing.T) { } t.Fatalf("expected %d goroutines after closing, got %d", 0, countGoroutines()) } + +func TestReusableGoroutinesPool_ClosedActionPanic(t *testing.T) { + w := NewReusableGoroutinesPool(2) + + runCount, panicked, _ := causePoolFailure(t, w, 10) + + require.NotZero(t, runCount, "expected at least one run") + require.Less(t, runCount, 10, "expected less than 10 runs") + require.True(t, panicked, "expected panic") +} + +func TestReusableGoroutinesPool_ClosedActionError(t *testing.T) { + w := NewReusableGoroutinesPool(2).WithClosedAction(ErrorWhenClosed) + + runCount, panicked, errors := causePoolFailure(t, w, 10) + + require.NotZero(t, runCount, "expected at least one run") + require.Less(t, runCount, 10, "expected less than 10 runs") + require.False(t, panicked, "expected no panic") + require.NotZero(t, len(errors), "expected errors") + require.Less(t, len(errors), 10, "expected less than 10 errors. 
Some workloads were submitted before close.") +} + +func TestReusableGoroutinesPool_ClosedActionSpawn(t *testing.T) { + w := NewReusableGoroutinesPool(2).WithClosedAction(SpawnNewGoroutineWhenClosed) + + runCount, panicked, errors := causePoolFailure(t, w, 10) + + require.Equal(t, runCount, 10, "expected all workloads to run") + require.False(t, panicked, "expected no panic") + require.NotZero(t, len(errors), "expected errors") + require.Less(t, len(errors), 10, "expected less than 10 errors. Some workloads were submitted before close.") +} + +func causePoolFailure(t *testing.T, w *ReusableGoroutinesPool, maxMsgCount int) (runCount int, panicked bool, errors []error) { + t.Helper() + + var runCountAtomic atomic.Int32 + + var testWG sync.WaitGroup + testWG.Add(1) + go func() { + defer testWG.Done() + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + for i := 0; i < maxMsgCount; i++ { + err := w.GoErr(func() { + runCountAtomic.Add(1) + }) + if err != nil { + errors = append(errors, err) + } + time.Sleep(10 * time.Millisecond) + } + }() + time.Sleep(10 * time.Millisecond) + w.Close() // close the pool + testWG.Wait() // wait for the test to finish + + return int(runCountAtomic.Load()), panicked, errors +}