Skip to content

Commit

Permalink
change WaitTimeout to WaitWithContext
Browse files Browse the repository at this point in the history
  • Loading branch information
hslam committed Feb 8, 2021
1 parent b6917b1 commit 915979b
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 20 deletions.
18 changes: 6 additions & 12 deletions watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
package rpc

import (
"context"
"errors"
"sync"
"sync/atomic"
"time"
)

// ErrWatcherShutdown is returned when the watcher is shut down.
Expand Down Expand Up @@ -35,8 +35,8 @@ func freeEvent(e *event) {
type Watcher interface {
// Wait will return value when the key is triggered.
Wait() ([]byte, error)
// WaitTimeout acts like Wait but takes a timeout.
WaitTimeout(time.Duration) ([]byte, error)
// WaitWithContext acts like Wait but takes a context.
WaitWithContext(context.Context) ([]byte, error)
// Stop stops the watch.
Stop() error
}
Expand Down Expand Up @@ -89,22 +89,16 @@ func (w *watcher) Wait() (value []byte, err error) {
return
}

func (w *watcher) WaitTimeout(timeout time.Duration) (value []byte, err error) {
if timeout <= 0 {
return w.Wait()
}
timer := time.NewTimer(timeout)
func (w *watcher) WaitWithContext(ctx context.Context) (value []byte, err error) {
select {
case e := <-w.C:
timer.Stop()
w.triggerNext()
value = e.Value
err = e.Error
freeEvent(e)
case <-timer.C:
err = ErrTimeout
case <-ctx.Done():
err = ctx.Err()
case <-w.done:
timer.Stop()
err = ErrWatcherShutdown
}
return
Expand Down
32 changes: 24 additions & 8 deletions watcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package rpc

import (
"context"
"testing"
"time"
)
Expand All @@ -16,7 +17,7 @@ func TestWatcherTrigger(t *testing.T) {
watcher.trigger(e)
}
for i := byte(0); i < 255; i++ {
v, err := watcher.WaitTimeout(0)
v, err := watcher.Wait()
if err != nil {
t.Error(err)
} else if len(v) == 0 {
Expand All @@ -42,29 +43,44 @@ func TestWatcherTriggerTimeout(t *testing.T) {
watcher.trigger(e)
}
for i := byte(0); i < 255; i++ {
v, err := watcher.WaitTimeout(time.Minute)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
v, err := watcher.WaitWithContext(ctx)
if err != nil {
t.Error(err)
} else if len(v) == 0 {
t.Error("len == 0")
} else if v[0] != i {
t.Error("out of order")
}
cancel()
}
go func() {
close(watcher.done)
watcher.stop()
}()
_, err := watcher.WaitTimeout(time.Minute)
_, err := watcher.Wait()
if err != ErrWatcherShutdown {
t.Error(err)
}
}

func TestWatcherTriggerTimeoutErr(t *testing.T) {
watcher := &watcher{C: make(chan *event, 10), done: make(chan struct{}, 1)}
_, err := watcher.WaitTimeout(time.Millisecond * 1)
if err != ErrTimeout {
t.Error(err)
{
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*1)
_, err := watcher.WaitWithContext(ctx)
if err == nil {
t.Error()
}
cancel()
}
close(watcher.done)
watcher.stop()
{
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*1)
_, err := watcher.WaitWithContext(ctx)
if err != ErrWatcherShutdown {
t.Error(err)
}
cancel()
}

}

0 comments on commit 915979b

Please sign in to comment.