Skip to content

Commit

Permalink
Merge pull request #31 from projectdiscovery/issue-30-race
Browse files Browse the repository at this point in the history
Using atomic for internal counters
  • Loading branch information
Mzack9999 authored May 16, 2023
2 parents 36a8d33 + 8cbb45c commit fd03f8b
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: 1.18
go-version: 1.19

- name: Check out code
uses: actions/checkout@v3
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/lint-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: 1.18
go-version: 1.19
- name: Run golangci-lint
uses: golangci/[email protected]
with:
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@

# Dependency directories (remove the comment below to include it)
# vendor/
.idea
.vscode
2 changes: 1 addition & 1 deletion example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (

func main() {
// create a rate limiter by passing context, max tasks/tokens , time interval
limiter := ratelimit.New(context.Background(), 5, time.Duration(10*time.Second))
limiter := ratelimit.New(context.Background(), 5, 10*time.Second)

save := time.Now()

Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/projectdiscovery/ratelimit

go 1.18
go 1.19

require (
github.com/projectdiscovery/utils v0.0.30
Expand Down
4 changes: 2 additions & 2 deletions keyratelimit.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ type MultiLimiter struct {
ctx context.Context
}

// Adds new bucket with key
// Add new bucket with key
func (m *MultiLimiter) Add(opts *Options) error {
if err := opts.Validate(); err != nil {
return err
Expand All @@ -54,7 +54,7 @@ func (m *MultiLimiter) Add(opts *Options) error {
} else {
rlimiter = New(m.ctx, opts.MaxCount, opts.Duration)
}
// ok if true if key already exists
// ok is true if key already exists
_, ok := m.limiters.LoadOrStore(opts.Key, rlimiter)
if ok {
return ErrKeyAlreadyExists.Msgf("key: %v", opts.Key)
Expand Down
38 changes: 21 additions & 17 deletions ratelimit.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@ package ratelimit
import (
"context"
"math"
"sync/atomic"
"time"
)

// equals to -1
var minusOne = ^uint32(0)

// Limiter allows a burst of request during the defined duration
type Limiter struct {
maxCount uint
count uint
maxCount uint32
count atomic.Uint32
ticker *time.Ticker
tokens chan struct{}
ctx context.Context
Expand All @@ -20,9 +24,9 @@ type Limiter struct {
func (limiter *Limiter) run(ctx context.Context) {
defer close(limiter.tokens)
for {
if limiter.count == 0 {
if limiter.count.Load() == 0 {
<-limiter.ticker.C
limiter.count = limiter.maxCount
limiter.count.Store(limiter.maxCount)
}
select {
case <-ctx.Done():
Expand All @@ -33,21 +37,21 @@ func (limiter *Limiter) run(ctx context.Context) {
limiter.ticker.Stop()
return
case limiter.tokens <- struct{}{}:
limiter.count--
limiter.count.Add(minusOne)
case <-limiter.ticker.C:
limiter.count = limiter.maxCount
limiter.count.Store(limiter.maxCount)
}
}
}

// Take one token from the bucket
func (rateLimiter *Limiter) Take() {
<-rateLimiter.tokens
func (limiter *Limiter) Take() {
<-limiter.tokens
}

// GetLimit returns current rate limit per given duration
func (ratelimiter *Limiter) GetLimit() uint {
return ratelimiter.maxCount
func (limiter *Limiter) GetLimit() uint {
return uint(limiter.maxCount)
}

// TODO: SleepandReset should be able to handle multiple calls without resetting multiple times
Expand All @@ -72,9 +76,9 @@ func (ratelimiter *Limiter) GetLimit() uint {
// }

// Stop the rate limiter canceling the internal context
func (ratelimiter *Limiter) Stop() {
if ratelimiter.cancelFunc != nil {
ratelimiter.cancelFunc()
func (limiter *Limiter) Stop() {
if limiter.cancelFunc != nil {
limiter.cancelFunc()
}
}

Expand All @@ -83,13 +87,13 @@ func New(ctx context.Context, max uint, duration time.Duration) *Limiter {
internalctx, cancel := context.WithCancel(context.TODO())

limiter := &Limiter{
maxCount: uint(max),
count: uint(max),
maxCount: uint32(max),
ticker: time.NewTicker(duration),
tokens: make(chan struct{}),
ctx: ctx,
cancelFunc: cancel,
}
limiter.count.Store(uint32(max))
go limiter.run(internalctx)

return limiter
Expand All @@ -100,13 +104,13 @@ func NewUnlimited(ctx context.Context) *Limiter {
internalctx, cancel := context.WithCancel(context.TODO())

limiter := &Limiter{
maxCount: math.MaxUint,
count: math.MaxUint,
maxCount: math.MaxUint32,
ticker: time.NewTicker(time.Millisecond),
tokens: make(chan struct{}),
ctx: ctx,
cancelFunc: cancel,
}
limiter.count.Store(math.MaxUint32)
go limiter.run(internalctx)

return limiter
Expand Down

0 comments on commit fd03f8b

Please sign in to comment.