diff --git a/main.go b/main.go index f1f6218..63b6aa6 100644 --- a/main.go +++ b/main.go @@ -45,6 +45,10 @@ type Config struct { // Default: Prefix string + // SkipOnError + // Default: false + SkipOnError bool + // Period Period time.Duration @@ -53,16 +57,22 @@ type Config struct { Filter func(*fiber.Ctx) bool // Key allows to use a custom handler to create custom keys - // Default: func(c *fiber.Ctx) string { - // return c.IP() + // Default: func(ctx *fiber.Ctx) string { + // return ctx.IP() // } Key func(*fiber.Ctx) string // Handler is called when a request hits the limit - // Default: func(c *fiber.Ctx) { - // c.Status(cfg.StatusCode).SendString(cfg.Message) + // Default: func(ctx *fiber.Ctx) { + // ctx.Status(cfg.StatusCode).SendString(cfg.Message) // } Handler func(*fiber.Ctx) + + // ErrHandler is called when a error happen inside go_limiiter lib + // Default: func(err error, ctx *fiber.Ctx) { + // ctx.Status(http.StatusInternalServerError).SendString(err.Error()) + // } + ErrHandler func(error, *fiber.Ctx) } // New ... @@ -72,14 +82,20 @@ func New(config Config) func(*fiber.Ctx) { } if config.Handler == nil { - config.Handler = func(c *fiber.Ctx) { - c.Status(config.StatusCode).SendString(config.Message) + config.Handler = func(ctx *fiber.Ctx) { + ctx.Status(config.StatusCode).SendString(config.Message) + } + } + + if config.ErrHandler == nil { + config.ErrHandler = func(err error, ctx *fiber.Ctx) { + ctx.Status(http.StatusInternalServerError).SendString(err.Error()) } } if config.Key == nil { - config.Key = func(c *fiber.Ctx) string { - return c.IP() + config.Key = func(ctx *fiber.Ctx) string { + return ctx.IP() } } @@ -122,35 +138,43 @@ func New(config Config) func(*fiber.Ctx) { // override default limiter prefix limiter.Prefix = config.Prefix - return func(c *fiber.Ctx) { + return func(ctx *fiber.Ctx) { // Filter request to skip middleware - if config.Filter != nil && config.Filter(c) { - c.Next() + if config.Filter != nil && config.Filter(ctx) { + ctx.Next() + return } - result, err := limiter.Allow(config.Key(c), limit) + result, err := limiter.Allow(config.Key(ctx), limit) // if we have error lets just pass the request if err != nil { - c.Next() + if config.SkipOnError { + ctx.Next() + + return + } + + config.ErrHandler(err, ctx) + return } // Check if hits exceed the max if !result.Allowed { // Call Handler func - config.Handler(c) + config.Handler(ctx) // Return response with Retry-After header // https://tools.ietf.org/html/rfc6584 - c.Set("Retry-After", strconv.FormatInt(time.Now().Add(result.RetryAfter).Unix(), 10)) + ctx.Set("Retry-After", strconv.FormatInt(time.Now().Add(result.RetryAfter).Unix(), 10)) return } // We can continue, update RateLimit headers - c.Set("X-RateLimit-Limit", strconv.Itoa(config.Max)) - c.Set("X-RateLimit-Remaining", strconv.FormatInt(result.Remaining, 10)) - c.Set("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(result.ResetAfter).Unix(), 10)) + ctx.Set("X-RateLimit-Limit", strconv.Itoa(config.Max)) + ctx.Set("X-RateLimit-Remaining", strconv.FormatInt(result.Remaining, 10)) + ctx.Set("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(result.ResetAfter).Unix(), 10)) // Bye! - c.Next() + ctx.Next() } }