diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 8fe61bd..a4f6ae4 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -15,7 +15,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v4 with: - go-version: 1.18 + go-version: 1.19 - name: Check out code uses: actions/checkout@v3 diff --git a/.github/workflows/lint-test.yml b/.github/workflows/lint-test.yml index 28fe75e..f52df99 100644 --- a/.github/workflows/lint-test.yml +++ b/.github/workflows/lint-test.yml @@ -13,7 +13,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v4 with: - go-version: 1.18 + go-version: 1.19 - name: Run golangci-lint uses: golangci/golangci-lint-action@v3.4.0 with: diff --git a/.gitignore b/.gitignore index 66fd13c..83540aa 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,5 @@ # Dependency directories (remove the comment below to include it) # vendor/ +.idea +.vscode \ No newline at end of file diff --git a/example/main.go b/example/main.go index 92a3520..e7d54fe 100644 --- a/example/main.go +++ b/example/main.go @@ -10,7 +10,7 @@ import ( func main() { // create a rate limiter by passing context, max tasks/tokens , time interval - limiter := ratelimit.New(context.Background(), 5, time.Duration(10*time.Second)) + limiter := ratelimit.New(context.Background(), 5, 10*time.Second) save := time.Now() diff --git a/go.mod b/go.mod index c29eb0e..6647ec7 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/projectdiscovery/ratelimit -go 1.18 +go 1.19 require ( github.com/projectdiscovery/utils v0.0.30 diff --git a/keyratelimit.go b/keyratelimit.go index 719b62d..5bb269e 100644 --- a/keyratelimit.go +++ b/keyratelimit.go @@ -43,7 +43,7 @@ type MultiLimiter struct { ctx context.Context } -// Adds new bucket with key +// Add new bucket with key func (m *MultiLimiter) Add(opts *Options) error { if err := opts.Validate(); err != nil { return err @@ -54,7 +54,7 @@ func (m *MultiLimiter) Add(opts *Options) error { } else { rlimiter = New(m.ctx, opts.MaxCount, opts.Duration) } - // ok if true if key already exists + // ok is true if key already exists _, ok := m.limiters.LoadOrStore(opts.Key, rlimiter) if ok { return ErrKeyAlreadyExists.Msgf("key: %v", opts.Key) diff --git a/ratelimit.go b/ratelimit.go index bc8b8a4..a480f65 100644 --- a/ratelimit.go +++ b/ratelimit.go @@ -3,13 +3,17 @@ package ratelimit import ( "context" "math" + "sync/atomic" "time" ) +// equals to -1 +var minusOne = ^uint32(0) + // Limiter allows a burst of request during the defined duration type Limiter struct { - maxCount uint - count uint + maxCount uint32 + count atomic.Uint32 ticker *time.Ticker tokens chan struct{} ctx context.Context @@ -20,9 +24,9 @@ type Limiter struct { func (limiter *Limiter) run(ctx context.Context) { defer close(limiter.tokens) for { - if limiter.count == 0 { + if limiter.count.Load() == 0 { <-limiter.ticker.C - limiter.count = limiter.maxCount + limiter.count.Store(limiter.maxCount) } select { case <-ctx.Done(): @@ -33,21 +37,21 @@ func (limiter *Limiter) run(ctx context.Context) { limiter.ticker.Stop() return case limiter.tokens <- struct{}{}: - limiter.count-- + limiter.count.Add(minusOne) case <-limiter.ticker.C: - limiter.count = limiter.maxCount + limiter.count.Store(limiter.maxCount) } } } // Take one token from the bucket -func (rateLimiter *Limiter) Take() { - <-rateLimiter.tokens +func (limiter *Limiter) Take() { + <-limiter.tokens } // GetLimit returns current rate limit per given duration -func (ratelimiter *Limiter) GetLimit() uint { - return ratelimiter.maxCount +func (limiter *Limiter) GetLimit() uint { + return uint(limiter.maxCount) } // TODO: SleepandReset should be able to handle multiple calls without resetting multiple times @@ -72,9 +76,9 @@ func (ratelimiter *Limiter) GetLimit() uint { // } // Stop the rate limiter canceling the internal context -func (ratelimiter *Limiter) Stop() { - if ratelimiter.cancelFunc != nil { - ratelimiter.cancelFunc() +func (limiter *Limiter) Stop() { + if limiter.cancelFunc != nil { + limiter.cancelFunc() } } @@ -83,13 +87,13 @@ func New(ctx context.Context, max uint, duration time.Duration) *Limiter { internalctx, cancel := context.WithCancel(context.TODO()) limiter := &Limiter{ - maxCount: uint(max), - count: uint(max), + maxCount: uint32(max), ticker: time.NewTicker(duration), tokens: make(chan struct{}), ctx: ctx, cancelFunc: cancel, } + limiter.count.Store(uint32(max)) go limiter.run(internalctx) return limiter @@ -100,13 +104,13 @@ func NewUnlimited(ctx context.Context) *Limiter { internalctx, cancel := context.WithCancel(context.TODO()) limiter := &Limiter{ - maxCount: math.MaxUint, - count: math.MaxUint, + maxCount: math.MaxUint32, ticker: time.NewTicker(time.Millisecond), tokens: make(chan struct{}), ctx: ctx, cancelFunc: cancel, } + limiter.count.Store(math.MaxUint32) go limiter.run(internalctx) return limiter