diff --git a/watcher.go b/watcher.go index caa5a68..d2a7330 100644 --- a/watcher.go +++ b/watcher.go @@ -4,10 +4,10 @@ package rpc import ( + "context" "errors" "sync" "sync/atomic" - "time" ) // ErrWatcherShutdown is returned when the watcher is shut down. @@ -35,8 +35,8 @@ func freeEvent(e *event) { type Watcher interface { // Wait will return value when the key is triggered. Wait() ([]byte, error) - // WaitTimeout acts like Wait but takes a timeout. - WaitTimeout(time.Duration) ([]byte, error) + // WaitWithContext acts like Wait but takes a context. + WaitWithContext(context.Context) ([]byte, error) // Stop stops the watch. Stop() error } @@ -89,22 +89,16 @@ func (w *watcher) Wait() (value []byte, err error) { return } -func (w *watcher) WaitTimeout(timeout time.Duration) (value []byte, err error) { - if timeout <= 0 { - return w.Wait() - } - timer := time.NewTimer(timeout) +func (w *watcher) WaitWithContext(ctx context.Context) (value []byte, err error) { select { case e := <-w.C: - timer.Stop() w.triggerNext() value = e.Value err = e.Error freeEvent(e) - case <-timer.C: - err = ErrTimeout + case <-ctx.Done(): + err = ctx.Err() case <-w.done: - timer.Stop() err = ErrWatcherShutdown } return diff --git a/watcher_test.go b/watcher_test.go index fbe1993..ae8d360 100644 --- a/watcher_test.go +++ b/watcher_test.go @@ -4,6 +4,7 @@ package rpc import ( + "context" "testing" "time" ) @@ -16,7 +17,7 @@ func TestWatcherTrigger(t *testing.T) { watcher.trigger(e) } for i := byte(0); i < 255; i++ { - v, err := watcher.WaitTimeout(0) + v, err := watcher.Wait() if err != nil { t.Error(err) } else if len(v) == 0 { @@ -42,7 +43,8 @@ func TestWatcherTriggerTimeout(t *testing.T) { watcher.trigger(e) } for i := byte(0); i < 255; i++ { - v, err := watcher.WaitTimeout(time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + v, err := watcher.WaitWithContext(ctx) if err != nil { t.Error(err) } else if len(v) == 0 { @@ -50,11 +52,12 @@ func TestWatcherTriggerTimeout(t *testing.T) { } else if v[0] != i { t.Error("out of order") } + cancel() } go func() { - close(watcher.done) + watcher.stop() }() - _, err := watcher.WaitTimeout(time.Minute) + _, err := watcher.Wait() if err != ErrWatcherShutdown { t.Error(err) } @@ -62,9 +65,22 @@ func TestWatcherTriggerTimeout(t *testing.T) { func TestWatcherTriggerTimeoutErr(t *testing.T) { watcher := &watcher{C: make(chan *event, 10), done: make(chan struct{}, 1)} - _, err := watcher.WaitTimeout(time.Millisecond * 1) - if err != ErrTimeout { - t.Error(err) + { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*1) + _, err := watcher.WaitWithContext(ctx) + if err == nil { + t.Error() + } + cancel() } - close(watcher.done) + watcher.stop() + { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*1) + _, err := watcher.WaitWithContext(ctx) + if err != ErrWatcherShutdown { + t.Error(err) + } + cancel() + } + }