diff --git a/retry/retry.go b/retry/retry.go index 0d5b0b8..50dffde 100644 --- a/retry/retry.go +++ b/retry/retry.go @@ -1,37 +1,52 @@ package retry import ( + "errors" "time" ) -var _ error = errNoMoreRetry("") +var errNoMoreRetry error = errorNoMoreRetry("no more retry") -type errNoMoreRetry string +type errorNoMoreRetry string -func (err errNoMoreRetry) Error() string { +func (err errorNoMoreRetry) Error() string { return string(err) } +func (errorNoMoreRetry) Unwrap() error { + return errNoMoreRetry +} + // IsNoMoreRetry reports whether error is NoMoreRetry error. func IsNoMoreRetry(err error) bool { - _, ok := err.(errNoMoreRetry) - return ok + if e, ok := err.(interface{ Unwrap() []error }); ok { + for _, err := range e.Unwrap() { + if IsNoMoreRetry(err) { + return true + } + } + return false + } + return errors.Is(err, errNoMoreRetry) } // ErrNoMoreRetry tells function does no more retry. -func ErrNoMoreRetry(err string) error { return errNoMoreRetry(err) } +func ErrNoMoreRetry(err string) error { return errorNoMoreRetry(err) } // Do keeps retrying the function until no error is returned. -func Do(fn func() error, attempts, delay int) (err error) { +func Do(fn func() error, attempts, delay int) error { + var errs []error for i := 0; i < attempts; i++ { - if err = fn(); err == nil || IsNoMoreRetry(err) { - return + err := fn() + if err == nil { + return nil } - - if i < attempts-1 { + errs = append(errs, err) + if IsNoMoreRetry(err) { + break + } else if i < attempts-1 { time.Sleep(time.Second * time.Duration(delay)) } } - - return + return errors.Join(errs...) } diff --git a/retry/retry_test.go b/retry/retry_test.go index 24037e1..722b5bd 100644 --- a/retry/retry_test.go +++ b/retry/retry_test.go @@ -7,6 +7,9 @@ import ( ) func TestRetry(t *testing.T) { + if err := ErrNoMoreRetry("error"); !errors.Is(err, errNoMoreRetry) { + t.Error("expected err is errNoMoreRetry; got not") + } var i int if err := Do(func() error { defer func() { i++ }() @@ -23,8 +26,8 @@ func TestRetry(t *testing.T) { return errors.New("error" + strconv.Itoa(i)) }, 3, 1); err == nil { t.Error("expected non-nil error; got nil error") - } else if err.Error() != "error2" { - t.Errorf("expected error2; got %s", err) + } else if expect := "error0\nerror1\nerror2"; err.Error() != expect { + t.Errorf("expected %s; got %s", expect, err) } i = 0