diff --git a/breaker.go b/breaker.go index 767fadd..454ce24 100644 --- a/breaker.go +++ b/breaker.go @@ -1,34 +1,25 @@ package hoglet import ( + "context" "math" - "sync" "sync/atomic" "time" ) -// untypedCircuit is used to avoid type annotations when implementing a breaker. +// untypedCircuit is used to avoid type annotations when implementing breakers. type untypedCircuit interface { stateForCall() State setOpenedAt(int64) } +// observer is used to observe the result of a single wrapped call through the circuit breaker. type observer interface { - // observe is called after the wrapped function returns. If [Circuit.Do] returns a non-nil [Observable], it will be + // observe is called after the wrapped function returns. If [Circuit.Do] returns a non-nil [observer], it will be // called exactly once. observe(failure bool) } -// newObservableCall creates a new [Observable] that ensures it can only be observed a single time. -func newObservableCall(f func(bool)) observer { - o := sync.Once{} - return observableCall(func(failure bool) { - o.Do(func() { - f(failure) - }) - }) -} - // observableCall tracks a single call through the breaker. // It should be instantiated via [newObservableCall] to ensure the observer is only called once. type observableCall func(bool) @@ -80,16 +71,16 @@ func (e *EWMABreaker) connect(c untypedCircuit) { e.circuit = c } -func (e *EWMABreaker) observerForCall() observer { +func (e *EWMABreaker) observerForCall(_ context.Context) (observer, error) { state := e.circuit.stateForCall() if state == StateOpen { - return nil + return nil, ErrCircuitOpen } - return newObservableCall(func(failure bool) { + return observableCall(func(failure bool) { e.observe(state == StateHalfOpen, failure) - }) + }), nil } func (e *EWMABreaker) observe(halfOpen, failure bool) { @@ -164,16 +155,16 @@ func (s *SlidingWindowBreaker) connect(c untypedCircuit) { s.circuit = c } -func (s *SlidingWindowBreaker) observerForCall() observer { +func (s *SlidingWindowBreaker) observerForCall(_ context.Context) (observer, error) { state := s.circuit.stateForCall() if state == StateOpen { - return nil + return nil, ErrCircuitOpen } - return newObservableCall(func(failure bool) { + return observableCall(func(failure bool) { s.observe(state == StateHalfOpen, failure) - }) + }), nil } func (s *SlidingWindowBreaker) observe(halfOpen, failure bool) { diff --git a/breaker_test.go b/breaker_test.go index 25ae0ce..596b5b8 100644 --- a/breaker_test.go +++ b/breaker_test.go @@ -1,6 +1,7 @@ package hoglet import ( + "context" "math/rand" "testing" "time" @@ -12,16 +13,19 @@ import ( func TestEWMABreaker_zero_value_does_not_open(t *testing.T) { b := &EWMABreaker{} b.connect(&mockCircuit{}) - o := b.observerForCall() - require.NotNil(t, o) + o, err := b.observerForCall(context.TODO()) + require.NoError(t, err) o.observe(true) - assert.NotNil(t, b.observerForCall()) + _, err = b.observerForCall(context.TODO()) + assert.NoError(t, err) } func TestEWMABreaker_zero_value_does_not_panic(t *testing.T) { b := &EWMABreaker{} b.connect(&mockCircuit{}) - assert.NotPanics(t, func() { b.observerForCall() }) + assert.NotPanics(t, func() { + b.observerForCall(context.TODO()) // nolint: errcheck // we are just interested in the panic + }) } func TestBreaker_Observe_State(t *testing.T) { @@ -193,7 +197,7 @@ func TestBreaker_Observe_State(t *testing.T) { c.setState(StateHalfOpen) } failure := s.failureFunc(i) - o := b.observerForCall() + o, _ := b.observerForCall(context.TODO()) // nolint: errcheck // always observe switch b := b.(type) { case *EWMABreaker: @@ -213,10 +217,11 @@ func TestBreaker_Observe_State(t *testing.T) { } } } + _, err := b.observerForCall(context.TODO()) if tt.wantCall { - assert.NotNil(t, b.observerForCall()) + assert.NoError(t, err) } else { - assert.Nil(t, b.observerForCall()) + assert.ErrorIs(t, err, ErrCircuitOpen) } }) } diff --git a/error.go b/error.go index 49075e8..5a5e794 100644 --- a/error.go +++ b/error.go @@ -11,4 +11,8 @@ func (b Error) Error() string { return "hoglet: " + b.msg } -var ErrCircuitOpen = Error{msg: "breaker is open"} +var ( + ErrCircuitOpen = Error{msg: "breaker is open"} + ErrConcurrencyLimitReached = Error{msg: "concurrency limit reached"} + ErrWaitingForSlot = Error{msg: "waiting for slot"} +) diff --git a/go.mod b/go.mod index dd8e05d..21119b1 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/kr/pretty v0.3.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.9.0 // indirect + golang.org/x/sync v0.4.0 gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index c20e4a0..5f99b3f 100644 --- a/go.sum +++ b/go.sum @@ -21,6 +21,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ= +golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= diff --git a/hoglet.go b/hoglet.go index 856889f..66e98b3 100644 --- a/hoglet.go +++ b/hoglet.go @@ -3,6 +3,7 @@ package hoglet import ( "context" "errors" + "sync" "sync/atomic" "time" ) @@ -12,8 +13,7 @@ import ( // // A zero Circuit will panic, analogous to calling a nil function variable. Initialize with [NewCircuit]. type Circuit[IN, OUT any] struct { - f BreakableFunc[IN, OUT] - breaker Breaker + f BreakableFunc[IN, OUT] options // State @@ -29,8 +29,15 @@ type options struct { // halfOpenDelay is the duration the circuit will stay open before switching to the half-open state, where a // limited (~1) amount of calls are allowed that - if successful - may re-close the breaker. halfOpenDelay time.Duration + + // Usually, this is implemented by the breaker, but it can be overridden for testing purposes. + observerForCall observerFactory } +// observerFactory is a function that returns one observer for each call going through the circuit. +// It is used analogously to a http.Handler, allowing different plugins to "wrap" each execution. +type observerFactory func(context.Context) (observer, error) + // Breaker is the interface implemented by the different breakers, responsible for actually opening the circuit. // Each implementation behaves differently when deciding whether to open the breaker upon failure. type Breaker interface { @@ -38,27 +45,38 @@ type Breaker interface { connect(untypedCircuit) // observerForCall returns an observer for the incoming call. - // It is called exactly once per call to [Breaker.Do], before calling the wrapped function. + // It is called exactly once per call to [Circuit.Call], before calling the wrapped function. // If the breaker is open, it returns nil. - // If the breaker is closed, it returns a non-nil [Observable] that will be used to observe the result of the call. - observerForCall() observer + // If the breaker is closed, it returns a non-nil [observer] that will be used to observe the result of the call. + observerForCall(context.Context) (observer, error) } // BreakableFunc is the type of the function wrapped by a Breaker. type BreakableFunc[IN, OUT any] func(context.Context, IN) (OUT, error) +// dedupObservableCall wraps an [observer] ensuring it can only be observed a single time. +func dedupObservableCall(obs observer) observer { + o := sync.Once{} + return observableCall(func(failure bool) { + o.Do(func() { + obs.observe(failure) + }) + }) +} + // NewCircuit instantiates a new [Circuit] that wraps the provided function. See [Circuit.Call] for calling semantics. // A Circuit with a nil breaker is a noop wrapper around the provided function and will never open. func NewCircuit[IN, OUT any](f BreakableFunc[IN, OUT], breaker Breaker, opts ...Option) *Circuit[IN, OUT] { b := &Circuit[IN, OUT]{ - f: f, - breaker: breaker, + f: f, options: options{ - isFailure: defaultFailureCondition, + isFailure: defaultFailureCondition, + observerForCall: defaultObserver, }, } if breaker != nil { + b.observerForCall = breaker.observerForCall breaker.connect(b) } @@ -118,7 +136,7 @@ func (c *Circuit[IN, OUT]) setOpenedAt(i int64) { // Call calls the wrapped function if the circuit is closed and returns its result. If the circuit is open, it returns // [ErrCircuitOpen]. // -// The wrapped function is called synchronously, but possilble context errors are recorded as soon as they occur. This +// The wrapped function is called synchronously, but possible context errors are recorded as soon as they occur. This // ensures the circuit opens quickly, even if the wrapped function blocks. // // By default, all errors are considered failures (including [context.Canceled]), but this can be customized via @@ -126,11 +144,23 @@ func (c *Circuit[IN, OUT]) setOpenedAt(i int64) { // // Panics are observed as failures, but are not recovered (i.e.: they are "repanicked" instead). func (c *Circuit[IN, OUT]) Call(ctx context.Context, in IN) (out OUT, err error) { - obs := c.observerForCall() - if obs == nil { - return out, ErrCircuitOpen + if c.f == nil { + return out, nil + } + + obs, err := c.observerForCall(ctx) + if err != nil { + // Note: errors here are not "observed" and do not count towards the breaker's failure rate. + // This includes: + // - ErrCircuitOpen (so no "feedback loops") + // - ErrConcurrencyLimit (for blocking limited circuits) + // - context timeouts while blocked on concurrency limit + return out, err } + // ensure we dedup the final - potentially wrapped - observer. + obs = dedupObservableCall(obs) + ctx, cancel := context.WithCancelCause(ctx) defer cancel(internalCancellation) @@ -149,19 +179,11 @@ func (c *Circuit[IN, OUT]) Call(ctx context.Context, in IN) (out OUT, err error) return c.f(ctx, in) } -// observerForCall is a thin wrapper around the breaker to simplify the case where no breaker has been set. -func (c *Circuit[IN, OUT]) observerForCall() observer { - if c.breaker == nil { - return noopObserveable{} - } - return c.breaker.observerForCall() -} - // internalCancellation is used to distinguish between internal and external (to the lib) context cancellations. var internalCancellation = errors.New("internal cancellation") // observeCtx observes the given context for cancellation and records it as a failure. -// It assumes [Observable.Observe] is idempotent and deduplicates calls itself. +// It assumes [observer] is idempotent and deduplicates calls itself. func (c *Circuit[IN, OUT]) observeCtx(obs observer, ctx context.Context) { // We want to observe a context error as soon as possible to open the breaker, but at the same time we want to // keep the call to the wrapped function synchronous to avoid all pitfalls that come with asynchronicity. @@ -200,11 +222,16 @@ func (s State) String() string { } // defaultFailureCondition is the default failure condition used by [NewCircuit]. -// It consider any non-nil error a failure. +// It considers any non-nil error a failure. func defaultFailureCondition(err error) bool { return err != nil } +// defaultObserver is the default observer used by [NewCircuit]. +func defaultObserver(context.Context) (observer, error) { + return noopObserveable{}, nil +} + type noopObserveable struct{} func (noopObserveable) observe(bool) {} diff --git a/hoglet_test.go b/hoglet_test.go index 72f2d2c..4031328 100644 --- a/hoglet_test.go +++ b/hoglet_test.go @@ -83,11 +83,11 @@ type mockBreaker struct { } // observerForCall implements [Breaker] -func (mt *mockBreaker) observerForCall() observer { +func (mt *mockBreaker) observerForCall(context.Context) (observer, error) { if mt.open { - return nil + return nil, ErrCircuitOpen } else { - return &mockObservable{breaker: mt} + return &mockObservable{breaker: mt}, nil } } @@ -99,10 +99,12 @@ type mockObservable struct { once sync.Once } -// observe implements [Observer] +// observe implements [observer] func (mo *mockObservable) observe(failure bool) { mo.once.Do(func() { - mo.breaker.open = failure + if mo.breaker != nil { + mo.breaker.open = failure + } }) } diff --git a/limiter.go b/limiter.go new file mode 100644 index 0000000..78075dd --- /dev/null +++ b/limiter.go @@ -0,0 +1,38 @@ +package hoglet + +import ( + "context" + "fmt" + + "golang.org/x/sync/semaphore" +) + +func newLimiter(origFactory observerFactory, limit int64, block bool) observerFactory { + sem := semaphore.NewWeighted(limit) + + wrappedFactory := func(ctx context.Context) (observer, error) { + o, err := origFactory(ctx) + if err != nil { + return nil, err + } + return observableCall(func(b bool) { + defer sem.Release(1) + o.observe(b) + }), nil + } + + if block { + return func(ctx context.Context) (observer, error) { + if err := sem.Acquire(ctx, 1); err != nil { + return nil, fmt.Errorf("%w: %w", ErrWaitingForSlot, err) + } + return wrappedFactory(ctx) + } + } + return func(ctx context.Context) (observer, error) { + if !sem.TryAcquire(1) { + return nil, ErrConcurrencyLimitReached + } + return wrappedFactory(ctx) + } +} diff --git a/limiter_test.go b/limiter_test.go new file mode 100644 index 0000000..c0d858b --- /dev/null +++ b/limiter_test.go @@ -0,0 +1,130 @@ +package hoglet + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type mockPanickingObservable struct{} + +func (mo *mockPanickingObservable) observe(shouldPanic bool) { + // abuse the observer interface to signal a panic + if shouldPanic { + panic("mockObservable meant to panic") + } +} + +func Test_newLimiter(t *testing.T) { + orig := func() observerFactory { + return func(context.Context) (observer, error) { + return &mockPanickingObservable{}, nil + } + } + + type args struct { + limit int64 + block bool + } + tests := []struct { + name string + args args + calls int + cancel bool + wantPanicOn *int // which call to panic on (if at all) + wantErr error + }{ + { + name: "under limit", + args: args{limit: 1, block: false}, + calls: 0, + wantErr: nil, + }, + { + name: "over limit; non-blocking", + args: args{limit: 1, block: false}, + calls: 1, + wantErr: ErrConcurrencyLimitReached, + }, + { + name: "on limit; blocking", + args: args{limit: 1, block: true}, + calls: 1, + cancel: true, // cancel simulates a timeout in this case + wantErr: ErrWaitingForSlot, + }, + { + name: "cancelation releases with error", + args: args{limit: 1, block: true}, + calls: 1, + cancel: true, + wantErr: context.Canceled, + }, + { + name: "panic releases", + args: args{limit: 1, block: true}, + calls: 1, + cancel: false, + wantPanicOn: ptr(0), + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctxCalls, cancelCalls := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancelCalls() + + wgStart := &sync.WaitGroup{} + wgStop := &sync.WaitGroup{} + defer wgStop.Wait() + + of := newLimiter(orig(), tt.args.limit, tt.args.block) + for i := 0; i < tt.calls; i++ { + wantPanic := tt.wantPanicOn != nil && *tt.wantPanicOn == i + + f := func() { + defer wgStop.Done() + o, err := of(ctxCalls) + wgStart.Done() + require.NoError(t, err) + + <-ctxCalls.Done() + + o.observe(wantPanic) + } + + wgStart.Add(1) + wgStop.Add(1) + if wantPanic { + go assert.Panics(t, f) + } else { + go f() + } + } + + ctx, cancel := context.WithCancel(context.Background()) + + if tt.cancel { + cancel() + } else { + defer cancel() + } + + wgStart.Wait() // ensure all calls are started + + o, err := of(ctx) + assert.ErrorIs(t, err, tt.wantErr) + if tt.wantErr == nil { + assert.NotNil(t, o) + } + }) + } +} + +func ptr[T any](in T) *T { + return &in +} diff --git a/options.go b/options.go index 23d1955..a3f13e9 100644 --- a/options.go +++ b/options.go @@ -37,3 +37,14 @@ func WithFailureCondition(condition func(error) bool) Option { func IgnoreContextCancelation(err error) bool { return err != nil && err != context.Canceled } + +// WithConcurrencyLimit sets the maximum number of concurrent calls to the provided limit. If the limit is reached, the +// circuit's behavior depends on the blocking parameter: +// - it either returns [ErrConcurrencyLimitReached] immediately if blocking is false +// - or blocks until a slot is available if blocking is true, potentially returning [ErrWaitingForSlot] if the context +// is canceled or times out while waiting. +func WithConcurrencyLimit(limit int64, blocking bool) Option { + return optionFunc(func(b *options) { + b.observerForCall = newLimiter(b.observerForCall, limit, blocking) + }) +}