Skip to content

Commit

Permalink
Reset timer when changing ticker duration (#18)
Browse files Browse the repository at this point in the history
* Reset timer when changing ticker duration

For an existing ticker.
Address linting issues.

* Add tests and fix race condition
  • Loading branch information
raphael authored Jul 3, 2024
1 parent bf9bf86 commit 12093f5
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 12 deletions.
22 changes: 16 additions & 6 deletions pool/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,13 +325,17 @@ func (node *Node) Shutdown(ctx context.Context) error {
node.wg.Wait()
node.lock.Lock()
defer node.lock.Unlock()
node.cleanup() // cleanup first then close maps
if err := node.cleanup(); err != nil { // cleanup first then close maps
node.logger.Error(fmt.Errorf("failed to cleanup: %w", err))
}
node.tickerMap.Close()
node.keepAliveMap.Close()
node.workerMap.Close()
node.shutdownMap.Close()
node.nodeReader.Close()
node.nodeStream.Destroy(ctx)
if err := node.nodeStream.Destroy(ctx); err != nil {
node.logger.Error(fmt.Errorf("failed to destroy node event stream: %w", err))
}
node.shutdown = true
node.logger.Info("shutdown")
return nil
Expand Down Expand Up @@ -368,7 +372,9 @@ func (node *Node) Close(ctx context.Context) error {
node.shutdownMap.Close()
}
node.nodeReader.Close()
node.nodeStream.Destroy(ctx)
if err := node.nodeStream.Destroy(ctx); err != nil {
node.logger.Error(fmt.Errorf("failed to destroy node event stream: %w", err))
}
node.closed = true
close(node.stop)
node.lock.Unlock()
Expand Down Expand Up @@ -469,7 +475,9 @@ func (node *Node) handleNodeEvents(c <-chan *streaming.Event) {
}
case <-node.stop:
node.nodeReader.Close()
node.nodeStream.Destroy(ctx)
if err := node.nodeStream.Destroy(ctx); err != nil {
node.logger.Error(fmt.Errorf("failed to destroy node event stream: %w", err))
}
return
}
}
Expand Down Expand Up @@ -639,7 +647,9 @@ func (node *Node) handleShutdownMapUpdate() {
for _, w := range node.localWorkers {
// Add to stream instead of calling stop directly to ensure that the
// worker is stopped only after all pending events have been processed.
w.stream.Add(context.Background(), evShutdown, []byte(requestingNode))
if _, err := w.stream.Add(context.Background(), evShutdown, []byte(requestingNode)); err != nil {
node.logger.Error(fmt.Errorf("failed to add shutdown event to worker stream %q: %w", workerStreamName(w.ID), err))
}
}
}

Expand Down Expand Up @@ -757,7 +767,7 @@ func (jh jumpHash) Hash(key string, numBuckets int64) int64 {
var j int64

jh.h.Reset()
io.WriteString(jh.h, key)
io.WriteString(jh.h, key) // nolint: errcheck
sum := jh.h.Sum64()

for j < numBuckets {
Expand Down
9 changes: 7 additions & 2 deletions pool/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,10 @@ func (sched *scheduler) startJobs(ctx context.Context, jobs []*JobParam) error {
sched.logger.Error(err, "failed to dispatch job", "job", job.Key)
continue
}
sched.jobMap.Set(ctx, job.Key, time.Now().String())
if _, err := sched.jobMap.Set(ctx, job.Key, time.Now().String()); err != nil {
sched.logger.Error(err, "failed to store job", "job", job.Key)
continue
}
}
return nil
}
Expand All @@ -164,7 +167,9 @@ func (sched *scheduler) stopJobs(ctx context.Context, plan *JobPlan) error {
sched.logger.Error(err, "failed to stop job", "job", key)
continue
}
sched.jobMap.Delete(ctx, key)
if _, err := sched.jobMap.Delete(ctx, key); err != nil {
sched.logger.Error(err, "failed to delete job", "job", key)
}
}
return nil
}
Expand Down
16 changes: 12 additions & 4 deletions pool/ticker.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,12 @@ func (node *Node) NewTicker(ctx context.Context, name string, d time.Duration, o
logger: logger,
}
if current, ok := node.tickerMap.Get(name); ok {
t.next = current
} else {
_, curd := deserialize(current)
if d == curd {
t.next = current
}
}
if t.next == "" {
next := serialize(time.Now().Add(d), d)
if _, err := t.tickerMap.Set(ctx, t.name, next); err != nil {
return nil, fmt.Errorf("failed to store tick and duration: %s", err)
Expand Down Expand Up @@ -94,7 +98,9 @@ func (t *Ticker) Reset(d time.Duration) {
func (t *Ticker) Stop() {
t.lock.Lock()
t.timer.Stop()
t.tickerMap.Delete(context.Background(), t.name)
if _, err := t.tickerMap.Delete(context.Background(), t.name); err != nil {
t.logger.Error(err, "msg", "failed to delete ticker")
}
t.tickerMap.Unsubscribe(t.mapch)
t.mapch = nil
t.lock.Unlock()
Expand All @@ -104,7 +110,9 @@ func (t *Ticker) Stop() {
// handleEvents handles events from the ticker timer and map.
func (t *Ticker) handleEvents() {
defer t.wg.Done()
t.lock.Lock()
ch := t.mapch
t.lock.Unlock()
for {
select {
case _, ok := <-ch:
Expand Down Expand Up @@ -168,7 +176,7 @@ func (t *Ticker) handleTick() {
// initTimer sets the timer to fire at the next tick.
func (t *Ticker) initTimer() {
next, _ := deserialize(t.next)
d := next.Sub(time.Now())
d := time.Until(next)
if d < 0 {
d = 0
}
Expand Down
45 changes: 45 additions & 0 deletions pool/ticker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ func TestNewTicker(t *testing.T) {
require.NotNil(t, ticker)
ts := <-ticker.C
assert.WithinDuration(t, now.Add(d), ts, time.Second, "invalid tick value")
next, dur := deserialize(ticker.next)
assert.WithinDuration(t, now.Add(d), next, time.Second, "invalid next tick value")
assert.Equal(t, d, dur, "invalid duration")
ticker.Stop()
var ok bool
timer := time.NewTimer(2 * d)
Expand All @@ -39,6 +42,48 @@ func TestNewTicker(t *testing.T) {
assert.NoError(t, node.Shutdown(ctx))
}

func TestReplaceTickerTimer(t *testing.T) {
var (
rdb = redis.NewClient(&redis.Options{Addr: "localhost:6379", Password: redisPwd})
tctx = testContext(t)
ctx = log.Context(tctx, log.WithOutput(io.Discard))
testName = strings.Replace(t.Name(), "/", "_", -1)
node = newTestNode(t, ctx, rdb, testName)
d1, d2 = 10 * time.Millisecond, 20 * time.Millisecond
)
now := time.Now()
ticker, err := node.NewTicker(ctx, testName, d1)
assert.NoError(t, err)
require.NotNil(t, ticker)
next, dur := deserialize(ticker.next)
assert.WithinDuration(t, now.Add(d1), next, time.Second, "invalid next tick value")
assert.Equal(t, d1, dur, "invalid duration")
ticker2, err := node.NewTicker(ctx, testName, d2)
assert.NoError(t, err)
require.NotNil(t, ticker2)
next, dur = deserialize(ticker2.next)
assert.WithinDuration(t, now.Add(d2), next, time.Second, "invalid next tick value")
assert.Equal(t, d2, dur, "invalid duration")
ticker.Stop()
ticker2.Stop()
var ok, ok2 bool
timer := time.NewTimer(2 * d1)
select {
case <-timer.C:
ok = true
case <-ticker.C:
}
timer2 := time.NewTimer(2 * d2)
select {
case <-timer2.C:
ok2 = true
case <-ticker2.C:
}
assert.True(t, ok, "ticker did not stop")
assert.True(t, ok2, "ticker2 did not stop")
assert.NoError(t, node.Shutdown(ctx))
}

func TestReset(t *testing.T) {
var (
rdb = redis.NewClient(&redis.Options{Addr: "localhost:6379", Password: redisPwd})
Expand Down

0 comments on commit 12093f5

Please sign in to comment.