diff --git a/extensions/prometheus/prometheus.go b/extensions/prometheus/prometheus.go index 709cf73..2a667c7 100644 --- a/extensions/prometheus/prometheus.go +++ b/extensions/prometheus/prometheus.go @@ -3,7 +3,6 @@ package hogprom import ( "context" "errors" - "fmt" "strconv" "time" @@ -13,95 +12,102 @@ import ( const ( namespace = "hoglet" + subsystem = "circuit" ) -// WithPrometheusMetrics returns a [hoglet.BreakerMiddleware] that registers prometheus metrics for the circuit. +// NewCollector returns a [hoglet.BreakerMiddleware] that exposes prometheus metrics for the circuit. +// It implements prometheus.Collector and can therefore be registered with a prometheus.Registerer. // -// ⚠️ Note: the provided name must be unique across all hoglet instances using the same registerer. -func WithPrometheusMetrics(circuitName string, reg prometheus.Registerer) hoglet.BreakerMiddleware { - return func(next hoglet.ObserverFactory) (hoglet.ObserverFactory, error) { - callDurations := prometheus.NewHistogramVec( - prometheus.HistogramOpts{ - Namespace: namespace, - Subsystem: "circuit", - Name: "call_durations_seconds", - Help: "Call durations in seconds", - ConstLabels: prometheus.Labels{ - "circuit": circuitName, - }, +// ⚠️ Note: the provided name must be unique across all hoglet instances ultimately registered to the same +// prometheus.Registerer. +func NewCollector(circuitName string) *Middleware { + callDurations := prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: namespace, + Subsystem: subsystem, + Name: "call_durations_seconds", + Help: "Call durations in seconds", + ConstLabels: prometheus.Labels{ + "circuit": circuitName, }, - []string{"success"}, - ) - - droppedCalls := prometheus.NewCounterVec( - prometheus.CounterOpts{ - Namespace: namespace, - Subsystem: "circuit", - Name: "dropped_calls_total", - Help: "Total number of calls with an open circuit (i.e.: calls that did not reach the wrapped function)", - ConstLabels: prometheus.Labels{ - "circuit": circuitName, - }, + }, + []string{"success"}, + ) + + droppedCalls := prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: namespace, + Subsystem: subsystem, + Name: "dropped_calls_total", + Help: "Total number of calls with an open circuit (i.e.: calls that did not reach the wrapped function)", + ConstLabels: prometheus.Labels{ + "circuit": circuitName, }, - []string{"cause"}, - ) - - inflightCalls := prometheus.NewGauge( - prometheus.GaugeOpts{ - Namespace: namespace, - Subsystem: "circuit", - Name: "inflight_calls_current", - Help: "Current number of calls in-flight", - ConstLabels: prometheus.Labels{ - "circuit": circuitName, - }, + }, + []string{"cause"}, + ) + + inflightCalls := prometheus.NewGauge( + prometheus.GaugeOpts{ + Namespace: namespace, + Subsystem: subsystem, + Name: "inflight_calls_current", + Help: "Current number of calls in-flight", + ConstLabels: prometheus.Labels{ + "circuit": circuitName, }, - ) - - for _, c := range []prometheus.Collector{ - callDurations, - droppedCalls, - inflightCalls, - } { - if err := reg.Register(c); err != nil { - return nil, fmt.Errorf("hoglet: registering collector: %w", err) - } - } - - return &prometheusObserverFactory{ - next: next, + }, + ) - timesource: wallclock{}, - - callDurations: callDurations, - droppedCalls: droppedCalls, - inflightCalls: inflightCalls, - }, nil + return &Middleware{ + callDurations: callDurations, + droppedCalls: droppedCalls, + inflightCalls: inflightCalls, } } -type prometheusObserverFactory struct { - next hoglet.ObserverFactory - - timesource timesource - +type Middleware struct { callDurations *prometheus.HistogramVec droppedCalls *prometheus.CounterVec inflightCalls prometheus.Gauge } -func (pos *prometheusObserverFactory) ObserverForCall(ctx context.Context, state hoglet.State) (hoglet.Observer, error) { - o, err := pos.next.ObserverForCall(ctx, state) +func (m Middleware) Collect(ch chan<- prometheus.Metric) { + m.callDurations.Collect(ch) + m.droppedCalls.Collect(ch) + m.inflightCalls.Collect(ch) +} + +func (m Middleware) Describe(ch chan<- *prometheus.Desc) { + prometheus.DescribeByCollect(m, ch) +} + +func (m Middleware) Wrap(of hoglet.ObserverFactory) (hoglet.ObserverFactory, error) { + return &wrappedMiddleware{ + Middleware: m, + next: of, + timesource: wallclock{}, + }, nil +} + +type wrappedMiddleware struct { + Middleware + next hoglet.ObserverFactory + timesource timesource +} + +func (wm *wrappedMiddleware) ObserverForCall(ctx context.Context, state hoglet.State) (hoglet.Observer, error) { + o, err := wm.next.ObserverForCall(ctx, state) if err != nil { - pos.droppedCalls.WithLabelValues(errToCause(err)).Inc() + wm.droppedCalls.WithLabelValues(errToCause(err)).Inc() return nil, err } - start := pos.timesource.Now() - pos.inflightCalls.Inc() + start := wm.timesource.Now() + wm.inflightCalls.Inc() return hoglet.ObserverFunc(func(b bool) { // invert failure → success to make the metric more intuitive - pos.callDurations.WithLabelValues(strconv.FormatBool(!b)).Observe(pos.timesource.Since(start).Seconds()) - pos.inflightCalls.Dec() + wm.callDurations.WithLabelValues(strconv.FormatBool(!b)).Observe(wm.timesource.Since(start).Seconds()) + wm.inflightCalls.Dec() o.Observe(b) }), nil } diff --git a/extensions/prometheus/prometheus_test.go b/extensions/prometheus/prometheus_test.go index 0ee5a3d..f713f8a 100644 --- a/extensions/prometheus/prometheus_test.go +++ b/extensions/prometheus/prometheus_test.go @@ -7,7 +7,6 @@ import ( "time" "github.com/exaring/hoglet" - "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/require" ) @@ -44,21 +43,20 @@ func (m mockTimesource) Since(t time.Time) time.Duration { } func TestWithPrometheusMetrics(t *testing.T) { - reg := prometheus.NewPedanticRegistry() - m := WithPrometheusMetrics("test", reg) - of, err := m(&mockObserverFactory{}) + m := NewCollector("test") + of, err := m.Wrap(&mockObserverFactory{}) require.NoError(t, err) mt := &mockTimesource{time.Now()} - of.(*prometheusObserverFactory).timesource = mt + of.(*wrappedMiddleware).timesource = mt inflightOut0 := `# HELP hoglet_circuit_inflight_calls_current Current number of calls in-flight # TYPE hoglet_circuit_inflight_calls_current gauge hoglet_circuit_inflight_calls_current{circuit="test"} 0 ` - if err := testutil.GatherAndCompare(reg, strings.NewReader(inflightOut0)); err != nil { + if err := testutil.CollectAndCompare(m, strings.NewReader(inflightOut0)); err != nil { t.Fatal(err) } @@ -72,7 +70,7 @@ func TestWithPrometheusMetrics(t *testing.T) { # TYPE hoglet_circuit_inflight_calls_current gauge hoglet_circuit_inflight_calls_current{circuit="test"} 0 ` - if err := testutil.GatherAndCompare(reg, strings.NewReader(droppedOut1)); err != nil { + if err := testutil.CollectAndCompare(m, strings.NewReader(droppedOut1)); err != nil { t.Fatal(err) } @@ -86,7 +84,7 @@ func TestWithPrometheusMetrics(t *testing.T) { # TYPE hoglet_circuit_inflight_calls_current gauge hoglet_circuit_inflight_calls_current{circuit="test"} 1 ` - if err := testutil.GatherAndCompare(reg, strings.NewReader(inflightOut1)); err != nil { + if err := testutil.CollectAndCompare(m, strings.NewReader(inflightOut1)); err != nil { t.Fatal(err) } @@ -118,7 +116,7 @@ func TestWithPrometheusMetrics(t *testing.T) { hoglet_circuit_inflight_calls_current{circuit="test"} 0 ` - if err := testutil.GatherAndCompare(reg, strings.NewReader(durationsOut1)); err != nil { + if err := testutil.CollectAndCompare(m, strings.NewReader(durationsOut1)); err != nil { t.Fatal(err) } } diff --git a/hoglet.go b/hoglet.go index 8afeb98..fb17a0a 100644 --- a/hoglet.go +++ b/hoglet.go @@ -53,8 +53,16 @@ type ObserverFactory interface { ObserverForCall(context.Context, State) (Observer, error) } -// BreakerMiddleware is a function that wraps an [ObserverFactory] and returns a new [ObserverFactory]. -type BreakerMiddleware func(ObserverFactory) (ObserverFactory, error) +// BreakerMiddleware wraps an [ObserverFactory] and returns a new [ObserverFactory]. +type BreakerMiddleware interface { + Wrap(ObserverFactory) (ObserverFactory, error) +} + +type BreakerMiddlewareFunc func(ObserverFactory) (ObserverFactory, error) + +func (f BreakerMiddlewareFunc) Wrap(of ObserverFactory) (ObserverFactory, error) { + return f(of) +} // WrappedFunc is the type of the function wrapped by a Breaker. type WrappedFunc[IN, OUT any] func(context.Context, IN) (OUT, error) diff --git a/limiter.go b/limiter.go index dbcb803..91b05d2 100644 --- a/limiter.go +++ b/limiter.go @@ -13,7 +13,7 @@ import ( // - or blocks until a slot is available if blocking is true, potentially returning [ErrWaitingForSlot]. The returned // error wraps the underlying cause (e.g. [context.Canceled] or [context.DeadlineExceeded]). func ConcurrencyLimiter(limit int64, block bool) BreakerMiddleware { - return func(next ObserverFactory) (ObserverFactory, error) { + return BreakerMiddlewareFunc(func(next ObserverFactory) (ObserverFactory, error) { cl := concurrencyLimiter{ sem: semaphore.NewWeighted(limit), next: next, @@ -26,7 +26,7 @@ func ConcurrencyLimiter(limit int64, block bool) BreakerMiddleware { return concurrencyLimiterNonBlocking{ concurrencyLimiter: cl, }, nil - } + }) } type concurrencyLimiter struct { diff --git a/limiter_test.go b/limiter_test.go index f0cfb62..586b465 100644 --- a/limiter_test.go +++ b/limiter_test.go @@ -84,7 +84,7 @@ func Test_ConcurrencyLimiter(t *testing.T) { defer wgStop.Wait() cl := hoglet.ConcurrencyLimiter(tt.args.limit, tt.args.block) - of, err := cl(mockObserverFactory{}) + of, err := cl.Wrap(mockObserverFactory{}) require.NoError(t, err) for i := 0; i < tt.calls; i++ { wantPanic := tt.wantPanicOn != nil && *tt.wantPanicOn == i diff --git a/options.go b/options.go index 6784c69..6910dee 100644 --- a/options.go +++ b/options.go @@ -52,7 +52,7 @@ func IgnoreContextCanceled(err error) bool { // middleware and should therefore be AFTER it in the parameter list. func WithBreakerMiddleware(bm BreakerMiddleware) Option { return optionFunc(func(o *options) error { - b, err := bm(o.observerFactory) + b, err := bm.Wrap(o.observerFactory) if err != nil { return fmt.Errorf("creating middleware: %w", err) }