Skip to content

Commit

Permalink
Merge pull request #26 from bringg/fix_expr_time
Browse files Browse the repository at this point in the history
Add Registry for registering rate limit algorithm engines
  • Loading branch information
Shareed2k authored Apr 1, 2021
2 parents 3fb5c4c + a724a3c commit 9184186
Show file tree
Hide file tree
Showing 17 changed files with 205 additions and 241 deletions.
7 changes: 7 additions & 0 deletions algorithm/all/all.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package all

import (
_ "github.com/bringg/go_redis_ratelimit/algorithm/cloudflare"
_ "github.com/bringg/go_redis_ratelimit/algorithm/gcra"
_ "github.com/bringg/go_redis_ratelimit/algorithm/sliding_window"
)
43 changes: 25 additions & 18 deletions algorithm/cloudflare/cloudflare.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,51 +11,58 @@ import (
const AlgorithmName = "cloudflare"

type Cloudflare struct {
Limit algorithm.Limit
RDB algorithm.Rediser
RDB algorithm.Rediser
}

func init() {
algorithm.Register(&algorithm.RegInfo{
Name: AlgorithmName,
NewAlgorithm: NewAlgorithm,
})
}

key string
func NewAlgorithm(rdb algorithm.Rediser) (algorithm.Algorithm, error) {
return &Cloudflare{
RDB: rdb,
}, nil
}

func (c *Cloudflare) Allow() (*algorithm.Result, error) {
rate := c.Limit.GetRate() - 1
func (c *Cloudflare) Allow(key string, limit algorithm.Limit) (*algorithm.Result, error) {
rate := limit.GetRate() - 1
rateLimiter := ratelimiter.New(&RedisDataStore{
RDB: c.RDB,
}, rate, c.Limit.GetPeriod())
RDB: c.RDB,
ExpirationTime: 2 * limit.GetPeriod(),
}, rate, limit.GetPeriod())

limitStatus, err := rateLimiter.Check(c.key)
limitStatus, err := rateLimiter.Check(key)
if err != nil {
return nil, err
}

rateKey := mapKey(c.key, time.Now().UTC().Truncate(c.Limit.GetPeriod()))
rateKey := mapKey(key, time.Now().UTC().Truncate(limit.GetPeriod()))
currentRate := int64(limitStatus.CurrentRate)

if limitStatus.IsLimited {
return &algorithm.Result{
Limit: c.Limit,
Limit: limit,
Key: rateKey,
Allowed: false,
Remaining: 0,
RetryAfter: *limitStatus.LimitDuration,
ResetAfter: c.Limit.GetPeriod(),
ResetAfter: limit.GetPeriod(),
}, nil
}

if err := rateLimiter.Inc(c.key); err != nil {
if err := rateLimiter.Inc(key); err != nil {
return nil, err
}

return &algorithm.Result{
Limit: c.Limit,
Limit: limit,
Key: rateKey,
Allowed: true,
Remaining: rate - currentRate,
RetryAfter: 0,
ResetAfter: c.Limit.GetPeriod(),
ResetAfter: limit.GetPeriod(),
}, nil
}

func (c *Cloudflare) SetKey(key string) {
c.key = key
}
5 changes: 3 additions & 2 deletions algorithm/cloudflare/redis_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ import (
)

type RedisDataStore struct {
RDB algorithm.Rediser
RDB algorithm.Rediser
ExpirationTime time.Duration
}

func (s *RedisDataStore) Inc(key string, window time.Time) error {
Expand All @@ -20,7 +21,7 @@ func (s *RedisDataStore) Inc(key string, window time.Time) error {

if _, err := s.RDB.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, time.Since(window)+time.Second)
pipe.Expire(ctx, key, s.ExpirationTime)

return nil
}); err != nil {
Expand Down
35 changes: 20 additions & 15 deletions algorithm/gcra/gcra.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,33 +12,38 @@ import (
const AlgorithmName = "gcra"

type GCRA struct {
Limit algorithm.Limit
RDB algorithm.Rediser
limiter *redis_rate.Limiter
}

key string
func init() {
algorithm.Register(&algorithm.RegInfo{
Name: AlgorithmName,
NewAlgorithm: NewAlgorithm,
})
}

func (c *GCRA) Allow() (*algorithm.Result, error) {
res, err := redis_rate.NewLimiter(c.RDB).Allow(context.Background(), c.key, redis_rate.Limit{
Rate: int(c.Limit.GetRate()),
Period: c.Limit.GetPeriod(),
Burst: int(c.Limit.GetBurst()),
func NewAlgorithm(rdb algorithm.Rediser) (algorithm.Algorithm, error) {
return &GCRA{
limiter: redis_rate.NewLimiter(rdb),
}, nil
}

func (c *GCRA) Allow(key string, limit algorithm.Limit) (*algorithm.Result, error) {
res, err := c.limiter.Allow(context.Background(), key, redis_rate.Limit{
Rate: int(limit.GetRate()),
Period: limit.GetPeriod(),
Burst: int(limit.GetBurst()),
})
if err != nil {
return nil, err
}

return &algorithm.Result{
Limit: c.Limit,
Key: c.key,
Limit: limit,
Key: key,
Allowed: res.Allowed == 1,
Remaining: int64(res.Remaining),
RetryAfter: res.RetryAfter,
ResetAfter: res.ResetAfter,
}, nil
}

// SetKey _
func (c *GCRA) SetKey(key string) {
c.key = key
}
64 changes: 0 additions & 64 deletions algorithm/gcra/gcra_lua.go

This file was deleted.

37 changes: 26 additions & 11 deletions algorithm/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,35 @@ import (
"github.com/go-redis/redis/v8"
)

var Registry []*RegInfo

type (
Limit interface {
GetAlgorithm() string
GetBurst() int64
GetRate() int64
GetBurst() int64
GetAlgorithm() string
GetPeriod() time.Duration
}

Rediser interface {
TxPipeline() redis.Pipeliner
TxPipelined(ctx context.Context, fn func(pipe redis.Pipeliner) error) ([]redis.Cmder, error)
Del(ctx context.Context, keys ...string) *redis.IntCmd
Get(ctx context.Context, key string) *redis.StringCmd
Incr(ctx context.Context, key string) *redis.IntCmd
Eval(ctx context.Context, script string, keys []string, args ...interface{}) *redis.Cmd
EvalSha(ctx context.Context, sha1 string, keys []string, args ...interface{}) *redis.Cmd
ScriptExists(ctx context.Context, hashes ...string) *redis.BoolSliceCmd
ScriptLoad(ctx context.Context, script string) *redis.StringCmd
ZRangeByScoreWithScores(ctx context.Context, key string, opt *redis.ZRangeBy) *redis.ZSliceCmd
ZRemRangeByScore(ctx context.Context, key string, min string, max string) *redis.IntCmd
ZCard(ctx context.Context, key string) *redis.IntCmd
Get(ctx context.Context, key string) *redis.StringCmd
Del(ctx context.Context, keys ...string) *redis.IntCmd
ScriptLoad(ctx context.Context, script string) *redis.StringCmd
ScriptExists(ctx context.Context, hashes ...string) *redis.BoolSliceCmd
ZAdd(ctx context.Context, key string, members ...*redis.Z) *redis.IntCmd
Expire(ctx context.Context, key string, expiration time.Duration) *redis.BoolCmd
ZRemRangeByScore(ctx context.Context, key string, min string, max string) *redis.IntCmd
Eval(ctx context.Context, script string, keys []string, args ...interface{}) *redis.Cmd
EvalSha(ctx context.Context, sha1 string, keys []string, args ...interface{}) *redis.Cmd
TxPipelined(ctx context.Context, fn func(pipe redis.Pipeliner) error) ([]redis.Cmder, error)
ZRangeByScoreWithScores(ctx context.Context, key string, opt *redis.ZRangeBy) *redis.ZSliceCmd
}

Algorithm interface {
Allow(key string, limit Limit) (*Result, error)
}

Result struct {
Expand Down Expand Up @@ -60,4 +66,13 @@ type (
// until Limit and Remaining will be equal.
ResetAfter time.Duration
}

RegInfo struct {
Name string
NewAlgorithm func(rdb Rediser) (Algorithm, error)
}
)

func Register(info *RegInfo) {
Registry = append(Registry, info)
}
23 changes: 14 additions & 9 deletions algorithm/sliding_window/sliding_window.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,26 @@ import (
const AlgorithmName = "sliding_window"

type SlidingWindow struct {
Limit algorithm.Limit
RDB algorithm.Rediser
RDB algorithm.Rediser
}

key string
func init() {
algorithm.Register(&algorithm.RegInfo{
Name: AlgorithmName,
NewAlgorithm: NewAlgorithm,
})
}

func (c *SlidingWindow) SetKey(key string) {
c.key = key
func NewAlgorithm(rdb algorithm.Rediser) (algorithm.Algorithm, error) {
return &SlidingWindow{
RDB: rdb,
}, nil
}

func (c *SlidingWindow) Allow() (r *algorithm.Result, err error) {
limit := c.Limit
func (c *SlidingWindow) Allow(key string, limit algorithm.Limit) (r *algorithm.Result, err error) {
values := []interface{}{limit.GetRate(), limit.GetPeriod().Seconds()}

v, err := script2.Run(context.Background(), c.RDB, []string{c.key}, values...).Result()
v, err := script.Run(context.Background(), c.RDB, []string{key}, values...).Result()
if err != nil {
return nil, err
}
Expand All @@ -39,7 +44,7 @@ func (c *SlidingWindow) Allow() (r *algorithm.Result, err error) {

return &algorithm.Result{
Limit: limit,
Key: c.key,
Key: key,
Allowed: values[0].(int64) == 1,
Remaining: values[1].(int64),
RetryAfter: dur(retryAfter),
Expand Down
2 changes: 1 addition & 1 deletion algorithm/sliding_window/sliding_window_lua.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package sliding_window

import "github.com/go-redis/redis/v8"

var script2 = redis.NewScript(`
var script = redis.NewScript(`
-- this script has side-effects, so it requires replicate commands mode
redis.replicate_commands()
Expand Down
10 changes: 7 additions & 3 deletions examples/cloudflare/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (

"github.com/go-redis/redis/v8"

"github.com/bringg/go_redis_ratelimit"
limiter "github.com/bringg/go_redis_ratelimit"
"github.com/bringg/go_redis_ratelimit/algorithm/cloudflare"
)

Expand All @@ -15,10 +15,14 @@ func main() {
if err != nil {
log.Fatal(err)
}

client := redis.NewClient(option)
l, err := limiter.NewLimiter(client)
if err != nil {
log.Fatal(err)
}

limiter := go_redis_ratelimit.NewLimiter(client)
res, err := limiter.Allow("api_gateway:klu4ik", &go_redis_ratelimit.Limit{
res, err := l.Allow("api_gateway:klu4ik", &limiter.Limit{
Algorithm: cloudflare.AlgorithmName,
Rate: 10,
Period: 10 * time.Second,
Expand Down
6 changes: 5 additions & 1 deletion examples/gcra/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ func main() {
}
client := redis.NewClient(option)

limiter := go_redis_ratelimit.NewLimiter(client)
limiter, err := go_redis_ratelimit.NewLimiter(client)
if err != nil {
log.Fatal(err)
}

res, err := limiter.Allow("api_gateway:klu4ik", &go_redis_ratelimit.Limit{
Algorithm: gcra.AlgorithmName,
Rate: 10,
Expand Down
6 changes: 5 additions & 1 deletion examples/sliding_window/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ func main() {
}
client := redis.NewClient(option)

limiter := go_redis_ratelimit.NewLimiter(client)
limiter, err := go_redis_ratelimit.NewLimiter(client)
if err != nil {
log.Fatal(err)
}

res, err := limiter.Allow("api_gateway:klu4ik", &go_redis_ratelimit.Limit{
Algorithm: sliding_window.AlgorithmName,
Rate: 10,
Expand Down
Loading

0 comments on commit 9184186

Please sign in to comment.