Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add option to limit concurrency #4

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 12 additions & 21 deletions breaker.go
Original file line number Diff line number Diff line change
@@ -1,34 +1,25 @@
package hoglet

import (
"context"
"math"
"sync"
"sync/atomic"
"time"
)

// untypedCircuit is used to avoid type annotations when implementing a breaker.
// untypedCircuit is used to avoid type annotations when implementing breakers.
type untypedCircuit interface {
stateForCall() State
setOpenedAt(int64)
}

// observer is used to observe the result of a single wrapped call through the circuit breaker.
type observer interface {
// observe is called after the wrapped function returns. If [Circuit.Do] returns a non-nil [Observable], it will be
// observe is called after the wrapped function returns. If [Circuit.Do] returns a non-nil [observer], it will be
// called exactly once.
observe(failure bool)
}

// newObservableCall creates a new [Observable] that ensures it can only be observed a single time.
func newObservableCall(f func(bool)) observer {
o := sync.Once{}
return observableCall(func(failure bool) {
o.Do(func() {
f(failure)
})
})
}

// observableCall tracks a single call through the breaker.
// It should be instantiated via [newObservableCall] to ensure the observer is only called once.
type observableCall func(bool)
Expand Down Expand Up @@ -80,16 +71,16 @@ func (e *EWMABreaker) connect(c untypedCircuit) {
e.circuit = c
}

func (e *EWMABreaker) observerForCall() observer {
func (e *EWMABreaker) observerForCall(_ context.Context) (observer, error) {
state := e.circuit.stateForCall()

if state == StateOpen {
return nil
return nil, ErrCircuitOpen
}

return newObservableCall(func(failure bool) {
return observableCall(func(failure bool) {
e.observe(state == StateHalfOpen, failure)
})
}), nil
}

func (e *EWMABreaker) observe(halfOpen, failure bool) {
Expand Down Expand Up @@ -164,16 +155,16 @@ func (s *SlidingWindowBreaker) connect(c untypedCircuit) {
s.circuit = c
}

func (s *SlidingWindowBreaker) observerForCall() observer {
func (s *SlidingWindowBreaker) observerForCall(_ context.Context) (observer, error) {
state := s.circuit.stateForCall()

if state == StateOpen {
return nil
return nil, ErrCircuitOpen
}

return newObservableCall(func(failure bool) {
return observableCall(func(failure bool) {
s.observe(state == StateHalfOpen, failure)
})
}), nil
}

func (s *SlidingWindowBreaker) observe(halfOpen, failure bool) {
Expand Down
19 changes: 12 additions & 7 deletions breaker_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package hoglet

import (
"context"
"math/rand"
"testing"
"time"
Expand All @@ -12,16 +13,19 @@ import (
func TestEWMABreaker_zero_value_does_not_open(t *testing.T) {
b := &EWMABreaker{}
b.connect(&mockCircuit{})
o := b.observerForCall()
require.NotNil(t, o)
o, err := b.observerForCall(context.TODO())
require.NoError(t, err)
o.observe(true)
assert.NotNil(t, b.observerForCall())
_, err = b.observerForCall(context.TODO())
assert.NoError(t, err)
}

func TestEWMABreaker_zero_value_does_not_panic(t *testing.T) {
b := &EWMABreaker{}
b.connect(&mockCircuit{})
assert.NotPanics(t, func() { b.observerForCall() })
assert.NotPanics(t, func() {
b.observerForCall(context.TODO()) // nolint: errcheck // we are just interested in the panic
})
}

func TestBreaker_Observe_State(t *testing.T) {
Expand Down Expand Up @@ -193,7 +197,7 @@ func TestBreaker_Observe_State(t *testing.T) {
c.setState(StateHalfOpen)
}
failure := s.failureFunc(i)
o := b.observerForCall()
o, _ := b.observerForCall(context.TODO()) // nolint: errcheck // always observe

switch b := b.(type) {
case *EWMABreaker:
Expand All @@ -213,10 +217,11 @@ func TestBreaker_Observe_State(t *testing.T) {
}
}
}
_, err := b.observerForCall(context.TODO())
if tt.wantCall {
assert.NotNil(t, b.observerForCall())
assert.NoError(t, err)
} else {
assert.Nil(t, b.observerForCall())
assert.ErrorIs(t, err, ErrCircuitOpen)
}
})
}
Expand Down
6 changes: 5 additions & 1 deletion error.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,8 @@ func (b Error) Error() string {
return "hoglet: " + b.msg
}

var ErrCircuitOpen = Error{msg: "breaker is open"}
var (
ErrCircuitOpen = Error{msg: "breaker is open"}
ErrConcurrencyLimitReached = Error{msg: "concurrency limit reached"}
ErrWaitingForSlot = Error{msg: "waiting for slot"}
)
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ require (
github.com/kr/pretty v0.3.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.9.0 // indirect
golang.org/x/sync v0.4.0
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ=
golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
Expand Down
71 changes: 49 additions & 22 deletions hoglet.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package hoglet
import (
"context"
"errors"
"sync"
"sync/atomic"
"time"
)
Expand All @@ -12,8 +13,7 @@ import (
//
// A zero Circuit will panic, analogous to calling a nil function variable. Initialize with [NewCircuit].
type Circuit[IN, OUT any] struct {
f BreakableFunc[IN, OUT]
breaker Breaker
f BreakableFunc[IN, OUT]
options

// State
Expand All @@ -29,36 +29,54 @@ type options struct {
// halfOpenDelay is the duration the circuit will stay open before switching to the half-open state, where a
// limited (~1) amount of calls are allowed that - if successful - may re-close the breaker.
halfOpenDelay time.Duration

// Usually, this is implemented by the breaker, but it can be overridden for testing purposes.
observerForCall observerFactory
}

// observerFactory is a function that returns one observer for each call going through the circuit.
// It is used analogously to a http.Handler, allowing different plugins to "wrap" each execution.
type observerFactory func(context.Context) (observer, error)

// Breaker is the interface implemented by the different breakers, responsible for actually opening the circuit.
// Each implementation behaves differently when deciding whether to open the breaker upon failure.
type Breaker interface {
// connect is called to allow the breaker to actuate its parent circuit.
connect(untypedCircuit)

// observerForCall returns an observer for the incoming call.
// It is called exactly once per call to [Breaker.Do], before calling the wrapped function.
// It is called exactly once per call to [Circuit.Call], before calling the wrapped function.
// If the breaker is open, it returns nil.
// If the breaker is closed, it returns a non-nil [Observable] that will be used to observe the result of the call.
observerForCall() observer
// If the breaker is closed, it returns a non-nil [observer] that will be used to observe the result of the call.
observerForCall(context.Context) (observer, error)
}

// BreakableFunc is the type of the function wrapped by a Breaker.
type BreakableFunc[IN, OUT any] func(context.Context, IN) (OUT, error)

// dedupObservableCall wraps an [observer] ensuring it can only be observed a single time.
func dedupObservableCall(obs observer) observer {
o := sync.Once{}
return observableCall(func(failure bool) {
o.Do(func() {
obs.observe(failure)
})
})
}

// NewCircuit instantiates a new [Circuit] that wraps the provided function. See [Circuit.Call] for calling semantics.
// A Circuit with a nil breaker is a noop wrapper around the provided function and will never open.
func NewCircuit[IN, OUT any](f BreakableFunc[IN, OUT], breaker Breaker, opts ...Option) *Circuit[IN, OUT] {
b := &Circuit[IN, OUT]{
f: f,
breaker: breaker,
f: f,
options: options{
isFailure: defaultFailureCondition,
isFailure: defaultFailureCondition,
observerForCall: defaultObserver,
},
}

if breaker != nil {
b.observerForCall = breaker.observerForCall
breaker.connect(b)
}

Expand Down Expand Up @@ -118,19 +136,31 @@ func (c *Circuit[IN, OUT]) setOpenedAt(i int64) {
// Call calls the wrapped function if the circuit is closed and returns its result. If the circuit is open, it returns
// [ErrCircuitOpen].
//
// The wrapped function is called synchronously, but possilble context errors are recorded as soon as they occur. This
// The wrapped function is called synchronously, but possible context errors are recorded as soon as they occur. This
// ensures the circuit opens quickly, even if the wrapped function blocks.
//
// By default, all errors are considered failures (including [context.Canceled]), but this can be customized via
// [WithFailureCondition] and [IgnoreContextCancelation].
//
// Panics are observed as failures, but are not recovered (i.e.: they are "repanicked" instead).
func (c *Circuit[IN, OUT]) Call(ctx context.Context, in IN) (out OUT, err error) {
obs := c.observerForCall()
if obs == nil {
return out, ErrCircuitOpen
if c.f == nil {
return out, nil
}

obs, err := c.observerForCall(ctx)
if err != nil {
// Note: errors here are not "observed" and do not count towards the breaker's failure rate.
// This includes:
// - ErrCircuitOpen (so no "feedback loops")
// - ErrConcurrencyLimit (for blocking limited circuits)
// - context timeouts while blocked on concurrency limit
return out, err
}

// ensure we dedup the final - potentially wrapped - observer.
obs = dedupObservableCall(obs)
costela marked this conversation as resolved.
Show resolved Hide resolved

ctx, cancel := context.WithCancelCause(ctx)
defer cancel(internalCancellation)

Expand All @@ -149,19 +179,11 @@ func (c *Circuit[IN, OUT]) Call(ctx context.Context, in IN) (out OUT, err error)
return c.f(ctx, in)
}

// observerForCall is a thin wrapper around the breaker to simplify the case where no breaker has been set.
func (c *Circuit[IN, OUT]) observerForCall() observer {
if c.breaker == nil {
return noopObserveable{}
}
return c.breaker.observerForCall()
}

// internalCancellation is used to distinguish between internal and external (to the lib) context cancellations.
var internalCancellation = errors.New("internal cancellation")

// observeCtx observes the given context for cancellation and records it as a failure.
// It assumes [Observable.Observe] is idempotent and deduplicates calls itself.
// It assumes [observer] is idempotent and deduplicates calls itself.
func (c *Circuit[IN, OUT]) observeCtx(obs observer, ctx context.Context) {
// We want to observe a context error as soon as possible to open the breaker, but at the same time we want to
// keep the call to the wrapped function synchronous to avoid all pitfalls that come with asynchronicity.
Expand Down Expand Up @@ -200,11 +222,16 @@ func (s State) String() string {
}

// defaultFailureCondition is the default failure condition used by [NewCircuit].
// It consider any non-nil error a failure.
// It considers any non-nil error a failure.
func defaultFailureCondition(err error) bool {
return err != nil
}

// defaultObserver is the default observer used by [NewCircuit].
func defaultObserver(context.Context) (observer, error) {
return noopObserveable{}, nil
}

type noopObserveable struct{}

func (noopObserveable) observe(bool) {}
12 changes: 7 additions & 5 deletions hoglet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,11 @@ type mockBreaker struct {
}

// observerForCall implements [Breaker]
func (mt *mockBreaker) observerForCall() observer {
func (mt *mockBreaker) observerForCall(context.Context) (observer, error) {
if mt.open {
return nil
return nil, ErrCircuitOpen
} else {
return &mockObservable{breaker: mt}
return &mockObservable{breaker: mt}, nil
}
}

Expand All @@ -99,10 +99,12 @@ type mockObservable struct {
once sync.Once
}

// observe implements [Observer]
// observe implements [observer]
func (mo *mockObservable) observe(failure bool) {
mo.once.Do(func() {
mo.breaker.open = failure
if mo.breaker != nil {
mo.breaker.open = failure
}
})
}

Expand Down
38 changes: 38 additions & 0 deletions limiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package hoglet

import (
"context"
"fmt"

"golang.org/x/sync/semaphore"
)

func newLimiter(origFactory observerFactory, limit int64, block bool) observerFactory {
sem := semaphore.NewWeighted(limit)

wrappedFactory := func(ctx context.Context) (observer, error) {
o, err := origFactory(ctx)
if err != nil {
return nil, err
}
return observableCall(func(b bool) {
defer sem.Release(1)
o.observe(b)
}), nil
}

if block {
return func(ctx context.Context) (observer, error) {
if err := sem.Acquire(ctx, 1); err != nil {
return nil, fmt.Errorf("%w: %w", ErrWaitingForSlot, err)
}
return wrappedFactory(ctx)
}
}
return func(ctx context.Context) (observer, error) {
if !sem.TryAcquire(1) {
return nil, ErrConcurrencyLimitReached
}
return wrappedFactory(ctx)
}
}
Loading