From 0b045fbfc66281b7591b4fe314f3fd69908bb94d Mon Sep 17 00:00:00 2001 From: Radu Berinde Date: Mon, 31 Jul 2023 09:11:07 -0700 Subject: [PATCH] add Wait and WaitCtx convenience functions The common usage for `TryToFulfill` is to sleep in a loop. This code adds a convenience `Wait` method and a corresponding `WaitCtx` which respects context cancelation. --- token_bucket.go | 37 ++++++++++++++++++++++++++++++++++++- token_bucket_test.go | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 1 deletion(-) diff --git a/token_bucket.go b/token_bucket.go index 79f2a37..40e33d9 100644 --- a/token_bucket.go +++ b/token_bucket.go @@ -15,6 +15,7 @@ package tokenbucket import ( + "context" "time" ) @@ -50,7 +51,8 @@ func (tb *TokenBucket) Init(rate TokensPerSecond, burst Tokens) { }) } -// Init the token bucket with a custom "Now" fuction. +// Init the token bucket with a custom "Now" function. +// Note that custom wait functions don't work with Wait and WaitCtx. func (tb *TokenBucket) InitWithNowFn(rate TokensPerSecond, burst Tokens, nowFn func() time.Time) { *tb = TokenBucket{ rate: rate, @@ -137,6 +139,39 @@ func (tb *TokenBucket) TryToFulfill(amount Tokens) (fulfilled bool, tryAgainAfte return true, 0 } +// Wait removes the given amount, waiting as long as necessary. +func (tb *TokenBucket) Wait(amount Tokens) { + for { + fulfilled, tryAgainAfter := tb.TryToFulfill(amount) + if fulfilled { + return + } + time.Sleep(tryAgainAfter) + } +} + +// WaitCtx removes the given amount, waiting as long as necessary or until the +// context is canceled. +func (tb *TokenBucket) WaitCtx(ctx context.Context, amount Tokens) error { + // We want to check for context cancelation even if we don't need to wait. + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + for { + fulfilled, tryAgainAfter := tb.TryToFulfill(amount) + if fulfilled { + return nil + } + select { + case <-time.After(tryAgainAfter): + case <-ctx.Done(): + return ctx.Err() + } + } +} + // Exhausted returns the cumulative duration over which this token bucket was // exhausted. Exported only for metrics. func (tb *TokenBucket) Exhausted() time.Duration { diff --git a/token_bucket_test.go b/token_bucket_test.go index a6bfd3e..d591947 100644 --- a/token_bucket_test.go +++ b/token_bucket_test.go @@ -15,6 +15,8 @@ package tokenbucket import ( + "context" + "strings" "testing" "time" ) @@ -135,3 +137,35 @@ func TestTokenBucket(t *testing.T) { // the positive. checkExhausted(initialExhausted + (20+90)*time.Millisecond) } + +func TestWaitCtx(t *testing.T) { + var tb TokenBucket + tb.Init(1, 100) + // Drain the initial tokens. + if fulfilled, _ := tb.TryToFulfill(100); !fulfilled { + t.Fatalf("could not drain initial tokens") + } + waitResult := make(chan error, 1) + ctx, ctxCancel := context.WithCancel(context.Background()) + go func() { + // This would take 100 seconds to return unless we cancel the context. + waitResult <- tb.WaitCtx(ctx, 100) + }() + + time.Sleep(10 * time.Millisecond) + select { + case <-waitResult: + t.Fatal("WaitCtx terminated unexpectedly") + default: + } + + ctxCancel() + select { + case err := <-waitResult: + if err == nil || !strings.Contains(err.Error(), "context canceled") { + t.Errorf("unexpected error from WaitCtx: %v", err) + } + case <-time.After(10 * time.Second): + t.Fatalf("WaitCtx did not return after context cancelation") + } +}