Skip to content

Commit

Permalink
add Wait and WaitCtx convenience functions
Browse files Browse the repository at this point in the history
The common usage for `TryToFulfill` is to sleep in a loop. This code
adds a convenience `Wait` method and a corresponding `WaitCtx` which
respects context cancelation.
  • Loading branch information
RaduBerinde committed Jul 31, 2023
1 parent 182959a commit 2c0e679
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 1 deletion.
37 changes: 36 additions & 1 deletion token_bucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package tokenbucket

import (
"context"
"time"
)

Expand Down Expand Up @@ -50,7 +51,8 @@ func (tb *TokenBucket) Init(rate TokensPerSecond, burst Tokens) {
})
}

// Init the token bucket with a custom "Now" fuction.
// Init the token bucket with a custom "Now" function.
// Note that custom wait functions don't work with Wait and WaitCtx.
func (tb *TokenBucket) InitWithNowFn(rate TokensPerSecond, burst Tokens, nowFn func() time.Time) {
*tb = TokenBucket{
rate: rate,
Expand Down Expand Up @@ -137,6 +139,39 @@ func (tb *TokenBucket) TryToFulfill(amount Tokens) (fulfilled bool, tryAgainAfte
return true, 0
}

// Wait removes the given amount, waiting as long as necessary.
func (tb *TokenBucket) Wait(amount Tokens) {
for {
fulfilled, tryAgainAfter := tb.TryToFulfill(amount)
if fulfilled {
return
}
time.Sleep(tryAgainAfter)
}
}

// WaitCtx removes the given amount, waiting as long as necessary or until the
// context is canceled.
func (tb *TokenBucket) WaitCtx(ctx context.Context, amount Tokens) error {
// We want to check for context cancelation even if we don't need to wait.
select {
case <-ctx.Done():
return ctx.Err()
default:
}
for {
fulfilled, tryAgainAfter := tb.TryToFulfill(amount)
if fulfilled {
return nil
}
select {
case <-time.After(tryAgainAfter):
case <-ctx.Done():
return ctx.Err()
}
}
}

// Exhausted returns the cumulative duration over which this token bucket was
// exhausted. Exported only for metrics.
func (tb *TokenBucket) Exhausted() time.Duration {
Expand Down
35 changes: 35 additions & 0 deletions token_bucket_test.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
// G
// Copyright 2023 The Cockroach Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -15,6 +16,8 @@
package tokenbucket

import (
"context"
"strings"
"testing"
"time"
)
Expand Down Expand Up @@ -135,3 +138,35 @@ func TestTokenBucket(t *testing.T) {
// the positive.
checkExhausted(initialExhausted + (20+90)*time.Millisecond)
}

func TestWaitCtx(t *testing.T) {
var tb TokenBucket
tb.Init(1, 100)
// Drain the initial tokens.
if fulfilled, _ := tb.TryToFulfill(100); !fulfilled {
t.Fatalf("could not drain initial tokens")
}
waitResult := make(chan error, 1)
ctx, ctxCancel := context.WithCancel(context.Background())
go func() {
// This would take 100 seconds to return unless we cancel the context.
waitResult <- tb.WaitCtx(ctx, 100)
}()

time.Sleep(10 * time.Millisecond)
select {
case <-waitResult:
t.Fatal("WaitCtx terminated unexpectedly")
default:
}

ctxCancel()
select {
case err := <-waitResult:
if err == nil || !strings.Contains(err.Error(), "context canceled") {
t.Errorf("unexpected error from WaitCtx: %v", err)
}
case <-time.After(10 * time.Second):
t.Fatalf("WaitCtx did not return after context cancelation")
}
}

0 comments on commit 2c0e679

Please sign in to comment.