From ebf989c5723f6db2ddfeb8f9ea17af2df9a47e69 Mon Sep 17 00:00:00 2001 From: Evgenii Zakharin Date: Sun, 13 Dec 2020 18:11:00 +0300 Subject: [PATCH] added middleware in consumer --- consumer.go | 40 ++++++++++++++++++++++++++++++++------- examples/consumer/main.go | 12 +++++++++++- 2 files changed, 44 insertions(+), 8 deletions(-) diff --git a/consumer.go b/consumer.go index 4cd684e..74fe5bc 100644 --- a/consumer.go +++ b/consumer.go @@ -9,24 +9,35 @@ import ( "github.com/streadway/amqp" ) +// MiddlewareFunc defines the handler +type MiddlewareFunc func(handler HandlerFunc) HandlerFunc + +type contextKey int + +// QueueNameKey key in context +const QueueNameKey contextKey = 0 + // Consumer struct type Consumer struct { Connection - queues map[string]*Queue - exchanges map[string]*Exchange - wg *sync.WaitGroup + queues map[string]*Queue + exchanges map[string]*Exchange + middlewares []MiddlewareFunc + wg *sync.WaitGroup } // NewConsumer returns a new Consumer struct func NewConsumer(uri string, logger Logger) *Consumer { + middlewares := []MiddlewareFunc{} exchanges := make(map[string]*Exchange) queues := make(map[string]*Queue) err := make(chan error) ctx, cancel := context.WithCancel(context.Background()) wg := &sync.WaitGroup{} return &Consumer{ - exchanges: exchanges, - queues: queues, + middlewares: middlewares, + exchanges: exchanges, + queues: queues, Connection: Connection{ uri: uri, err: err, @@ -96,6 +107,11 @@ func (c *Consumer) RegisterExchange(exchange *Exchange) { c.exchanges[exchange.Name] = exchange } +//RegisterMiddleware register middleware +func (c *Consumer) RegisterMiddleware(m ...MiddlewareFunc) { + c.middlewares = append(c.middlewares, m...) +} + func (c *Consumer) reconnect() error { if err := c.connect(); err != nil { return err @@ -200,6 +216,15 @@ func (c *Consumer) recoveryWorker(queue *Queue, workerNumber int, delivery *amqp go c.consumeWorker(queue, workerNumber) } +func buildChain(f HandlerFunc, m ...MiddlewareFunc) HandlerFunc { + // if our chain is done, use the original handler func + if len(m) == 0 { + return f + } + // otherwise nest the handler funcs + return m[0](buildChain(f, m[1:cap(m)]...)) +} + func (c *Consumer) consumeWorker(queue *Queue, workerNumber int) { defer c.wg.Done() @@ -209,8 +234,9 @@ func (c *Consumer) consumeWorker(queue *Queue, workerNumber int) { case delivery := <-queue.deliveries: defer c.recoveryWorker(queue, workerNumber, &delivery) - c.logger.Debugf("Got event: queue=%s, worker=%d", queue.Name, workerNumber) - if queue.handler(c.ctx, delivery) { + ctx := context.WithValue(c.ctx, QueueNameKey, queue.Name) + result := buildChain(queue.handler, c.middlewares...)(ctx, delivery) + if result { if err := delivery.Ack(false); err != nil { c.logger.Errorf("Falied ack %s", queue.Name) } diff --git a/examples/consumer/main.go b/examples/consumer/main.go index b384b53..5f46aea 100644 --- a/examples/consumer/main.go +++ b/examples/consumer/main.go @@ -13,8 +13,17 @@ import ( rmqclient "github.com/zaharinea/go-rmq-client" ) +func loggingMiddleware(handler rmqclient.HandlerFunc) rmqclient.HandlerFunc { + return func(ctx context.Context, msg amqp.Delivery) bool { + queueName := ctx.Value(rmqclient.QueueNameKey).(string) + fmt.Printf("start processing event: queue=%s, msg=%s\n", queueName, string(msg.Body)) + res := handler(ctx, msg) + fmt.Printf("end processing event: queue=%s, msg=%s\n", queueName, string(msg.Body)) + return res + } +} + func handler(ctx context.Context, msg amqp.Delivery) bool { - fmt.Printf("event: msg=%s\n", string(msg.Body)) time.Sleep(time.Second * 3) return true } @@ -38,6 +47,7 @@ func main() { "x-dead-letter-routing-key": "queue2-failed", }).SetRequeue(false).SetHandler(handler).SetCountWorkers(4) consumer.RegisterQueue(queue2, queue2Failed) + consumer.RegisterMiddleware(loggingMiddleware) consumer.Start()