diff --git a/pool/node.go b/pool/node.go index 2004990..71acdd5 100644 --- a/pool/node.go +++ b/pool/node.go @@ -325,13 +325,17 @@ func (node *Node) Shutdown(ctx context.Context) error { node.wg.Wait() node.lock.Lock() defer node.lock.Unlock() - node.cleanup() // cleanup first then close maps + if err := node.cleanup(); err != nil { // cleanup first then close maps + node.logger.Error(fmt.Errorf("failed to cleanup: %w", err)) + } node.tickerMap.Close() node.keepAliveMap.Close() node.workerMap.Close() node.shutdownMap.Close() node.nodeReader.Close() - node.nodeStream.Destroy(ctx) + if err := node.nodeStream.Destroy(ctx); err != nil { + node.logger.Error(fmt.Errorf("failed to destroy node event stream: %w", err)) + } node.shutdown = true node.logger.Info("shutdown") return nil @@ -368,7 +372,9 @@ func (node *Node) Close(ctx context.Context) error { node.shutdownMap.Close() } node.nodeReader.Close() - node.nodeStream.Destroy(ctx) + if err := node.nodeStream.Destroy(ctx); err != nil { + node.logger.Error(fmt.Errorf("failed to destroy node event stream: %w", err)) + } node.closed = true close(node.stop) node.lock.Unlock() @@ -469,7 +475,9 @@ func (node *Node) handleNodeEvents(c <-chan *streaming.Event) { } case <-node.stop: node.nodeReader.Close() - node.nodeStream.Destroy(ctx) + if err := node.nodeStream.Destroy(ctx); err != nil { + node.logger.Error(fmt.Errorf("failed to destroy node event stream: %w", err)) + } return } } @@ -639,7 +647,9 @@ func (node *Node) handleShutdownMapUpdate() { for _, w := range node.localWorkers { // Add to stream instead of calling stop directly to ensure that the // worker is stopped only after all pending events have been processed. - w.stream.Add(context.Background(), evShutdown, []byte(requestingNode)) + if _, err := w.stream.Add(context.Background(), evShutdown, []byte(requestingNode)); err != nil { + node.logger.Error(fmt.Errorf("failed to add shutdown event to worker stream %q: %w", workerStreamName(w.ID), err)) + } } } @@ -757,7 +767,7 @@ func (jh jumpHash) Hash(key string, numBuckets int64) int64 { var j int64 jh.h.Reset() - io.WriteString(jh.h, key) + io.WriteString(jh.h, key) // nolint: errcheck sum := jh.h.Sum64() for j < numBuckets { diff --git a/pool/scheduler.go b/pool/scheduler.go index 6ff192c..56de6d2 100644 --- a/pool/scheduler.go +++ b/pool/scheduler.go @@ -137,7 +137,10 @@ func (sched *scheduler) startJobs(ctx context.Context, jobs []*JobParam) error { sched.logger.Error(err, "failed to dispatch job", "job", job.Key) continue } - sched.jobMap.Set(ctx, job.Key, time.Now().String()) + if _, err := sched.jobMap.Set(ctx, job.Key, time.Now().String()); err != nil { + sched.logger.Error(err, "failed to store job", "job", job.Key) + continue + } } return nil } @@ -164,7 +167,9 @@ func (sched *scheduler) stopJobs(ctx context.Context, plan *JobPlan) error { sched.logger.Error(err, "failed to stop job", "job", key) continue } - sched.jobMap.Delete(ctx, key) + if _, err := sched.jobMap.Delete(ctx, key); err != nil { + sched.logger.Error(err, "failed to delete job", "job", key) + } } return nil } diff --git a/pool/ticker.go b/pool/ticker.go index 94930b8..4cba99d 100644 --- a/pool/ticker.go +++ b/pool/ticker.go @@ -53,8 +53,12 @@ func (node *Node) NewTicker(ctx context.Context, name string, d time.Duration, o logger: logger, } if current, ok := node.tickerMap.Get(name); ok { - t.next = current - } else { + _, curd := deserialize(current) + if d == curd { + t.next = current + } + } + if t.next == "" { next := serialize(time.Now().Add(d), d) if _, err := t.tickerMap.Set(ctx, t.name, next); err != nil { return nil, fmt.Errorf("failed to store tick and duration: %s", err) @@ -94,7 +98,9 @@ func (t *Ticker) Reset(d time.Duration) { func (t *Ticker) Stop() { t.lock.Lock() t.timer.Stop() - t.tickerMap.Delete(context.Background(), t.name) + if _, err := t.tickerMap.Delete(context.Background(), t.name); err != nil { + t.logger.Error(err, "msg", "failed to delete ticker") + } t.tickerMap.Unsubscribe(t.mapch) t.mapch = nil t.lock.Unlock() @@ -104,7 +110,9 @@ func (t *Ticker) Stop() { // handleEvents handles events from the ticker timer and map. func (t *Ticker) handleEvents() { defer t.wg.Done() + t.lock.Lock() ch := t.mapch + t.lock.Unlock() for { select { case _, ok := <-ch: @@ -168,7 +176,7 @@ func (t *Ticker) handleTick() { // initTimer sets the timer to fire at the next tick. func (t *Ticker) initTimer() { next, _ := deserialize(t.next) - d := next.Sub(time.Now()) + d := time.Until(next) if d < 0 { d = 0 } diff --git a/pool/ticker_test.go b/pool/ticker_test.go index 3e270eb..03c95c7 100644 --- a/pool/ticker_test.go +++ b/pool/ticker_test.go @@ -27,6 +27,9 @@ func TestNewTicker(t *testing.T) { require.NotNil(t, ticker) ts := <-ticker.C assert.WithinDuration(t, now.Add(d), ts, time.Second, "invalid tick value") + next, dur := deserialize(ticker.next) + assert.WithinDuration(t, now.Add(d), next, time.Second, "invalid next tick value") + assert.Equal(t, d, dur, "invalid duration") ticker.Stop() var ok bool timer := time.NewTimer(2 * d) @@ -39,6 +42,48 @@ func TestNewTicker(t *testing.T) { assert.NoError(t, node.Shutdown(ctx)) } +func TestReplaceTickerTimer(t *testing.T) { + var ( + rdb = redis.NewClient(&redis.Options{Addr: "localhost:6379", Password: redisPwd}) + tctx = testContext(t) + ctx = log.Context(tctx, log.WithOutput(io.Discard)) + testName = strings.Replace(t.Name(), "/", "_", -1) + node = newTestNode(t, ctx, rdb, testName) + d1, d2 = 10 * time.Millisecond, 20 * time.Millisecond + ) + now := time.Now() + ticker, err := node.NewTicker(ctx, testName, d1) + assert.NoError(t, err) + require.NotNil(t, ticker) + next, dur := deserialize(ticker.next) + assert.WithinDuration(t, now.Add(d1), next, time.Second, "invalid next tick value") + assert.Equal(t, d1, dur, "invalid duration") + ticker2, err := node.NewTicker(ctx, testName, d2) + assert.NoError(t, err) + require.NotNil(t, ticker2) + next, dur = deserialize(ticker2.next) + assert.WithinDuration(t, now.Add(d2), next, time.Second, "invalid next tick value") + assert.Equal(t, d2, dur, "invalid duration") + ticker.Stop() + ticker2.Stop() + var ok, ok2 bool + timer := time.NewTimer(2 * d1) + select { + case <-timer.C: + ok = true + case <-ticker.C: + } + timer2 := time.NewTimer(2 * d2) + select { + case <-timer2.C: + ok2 = true + case <-ticker2.C: + } + assert.True(t, ok, "ticker did not stop") + assert.True(t, ok2, "ticker2 did not stop") + assert.NoError(t, node.Shutdown(ctx)) +} + func TestReset(t *testing.T) { var ( rdb = redis.NewClient(&redis.Options{Addr: "localhost:6379", Password: redisPwd})