diff --git a/core/shared/src/main/scala/cats/effect/IO.scala b/core/shared/src/main/scala/cats/effect/IO.scala index 7fa83b5176f..572a53fc739 100644 --- a/core/shared/src/main/scala/cats/effect/IO.scala +++ b/core/shared/src/main/scala/cats/effect/IO.scala @@ -662,14 +662,14 @@ sealed abstract class IO[+A] private () extends IOPlatform[A] { * @param policy * the policy to use * - * @param onRetry + * @param onError * the effect to invoke on every retry decision */ def retry( policy: Retry[IO, Throwable], - onRetry: (Retry.Status, Throwable, Retry.Decision) => IO[Unit] + onError: (Retry.Status, Throwable, Retry.Decision) => IO[Unit] ): IO[A] = - Retry.retry(policy, onRetry)(this) + Retry.retry(policy, onError)(this) /** * Inverse of `attempt` diff --git a/std/shared/src/main/scala/cats/effect/std/Retry.scala b/std/shared/src/main/scala/cats/effect/std/Retry.scala index 52befc03c24..20030aa20dd 100644 --- a/std/shared/src/main/scala/cats/effect/std/Retry.scala +++ b/std/shared/src/main/scala/cats/effect/std/Retry.scala @@ -486,7 +486,7 @@ object Retry { * @param policy * the policy to use * - * @param onRetry + * @param onError * the effect to invoke on every retry decision * * @param fa @@ -494,9 +494,9 @@ object Retry { */ def retry[F[_], A, E]( policy: Retry[F, E], - onRetry: (Status, E, Decision) => F[Unit] + onError: (Status, E, Decision) => F[Unit] )(fa: F[A])(implicit F: GenTemporal[F, E]): F[A] = - doRetry(policy, Some(onRetry))(fa) + doRetry(policy, Some(onError))(fa) /** * The return policy that always gives up. diff --git a/std/shared/src/main/scala/cats/effect/std/syntax/RetrySyntax.scala b/std/shared/src/main/scala/cats/effect/std/syntax/RetrySyntax.scala index 2c3ee4bf2f0..3edfc04018d 100644 --- a/std/shared/src/main/scala/cats/effect/std/syntax/RetrySyntax.scala +++ b/std/shared/src/main/scala/cats/effect/std/syntax/RetrySyntax.scala @@ -31,8 +31,8 @@ final class RetryOps[F[_], A] private[syntax] (private val fa: F[A]) extends Any def retry[E]( policy: Retry[F, E], - onRetry: (Retry.Status, E, Retry.Decision) => F[Unit] + onError: (Retry.Status, E, Retry.Decision) => F[Unit] )(implicit F: GenTemporal[F, E]): F[A] = - Retry.retry(policy, onRetry)(fa) + Retry.retry(policy, onError)(fa) } diff --git a/tests/shared/src/test/scala/cats/effect/std/RetrySpec.scala b/tests/shared/src/test/scala/cats/effect/std/RetrySpec.scala index 316c0da36e6..8be20e94ce8 100644 --- a/tests/shared/src/test/scala/cats/effect/std/RetrySpec.scala +++ b/tests/shared/src/test/scala/cats/effect/std/RetrySpec.scala @@ -17,8 +17,11 @@ package cats.effect.std import cats.{Hash, Show} -import cats.effect.{BaseSpec, IO, Ref} +import cats.data.EitherT +import cats.effect.{BaseSpec, IO, Ref, Temporal} +import cats.mtl.Handle import cats.syntax.applicative._ +import cats.syntax.flatMap._ import cats.syntax.functor._ import cats.syntax.semigroup._ @@ -453,6 +456,112 @@ class RetrySpec extends BaseSpec { } + "Retry MTL" should { + + sealed trait Errors + final class Error1 extends Errors + final class Error2 extends Errors + + type RetryAttempt = (Status, Decision, Errors) + + def mtlRetry[F[_], E, A]( + action: F[A], + policy: Retry[F, E], + onRetry: (Status, E, Decision) => F[Unit] + )(implicit F: Temporal[F], H: Handle[F, E]): F[A] = + F.tailRecM(Status.initial) { status => + H.attempt(action).flatMap { + case Left(error) => + policy + .decide(status, error) + .flatTap(decision => onRetry(status, error, decision)) + .flatMap { + case retry: Decision.Retry => + F.delayBy(F.pure(Left(status.withRetry(retry.delay))), retry.delay) + + case _: Decision.GiveUp => + H.raise(error) + } + + case Right(success) => + F.pure(Right(success)) + } + } + + implicit val outputHash: Hash[(Either[Errors, Unit], List[RetryAttempt])] = + Hash.fromUniversalHashCode + + implicit val outputShow: Show[(Either[Errors, Unit], List[RetryAttempt])] = + Show.fromToString + + "give up on mismatched errors" in ticked { implicit ticker => + type F[A] = EitherT[IO, Errors, A] + + val maxRetries = 2 + val delay = 1.second + + val error = new Error2 + val policy = Retry + .constantDelay[F, Errors](delay) + .withMaxRetries(maxRetries) + .withErrorMatcher(Retry.ErrorMatcher[F, Errors].only[Error1]) + + val expected: List[RetryAttempt] = List( + (Status(0, Duration.Zero), Decision.giveUp, error) + ) + + val io: F[Unit] = Handle[F, Errors].raise[Errors, Unit](error) + + val run = + for { + ref <- IO.ref(List.empty[RetryAttempt]) + result <- mtlRetry[F, Errors, Unit]( + io, + policy, + (s, e: Errors, d) => EitherT.liftF(ref.update(_ :+ (s, d, e))) + ).value + attempts <- ref.get + } yield (result, attempts) + + run must completeAs((Left(error), expected)) + } + + "retry only on matching errors" in ticked { implicit ticker => + type F[A] = EitherT[IO, Errors, A] + + val maxRetries = 2 + val delay = 1.second + + val error = new Error1 + val policy = Retry + .constantDelay[F, Errors](delay) + .withMaxRetries(maxRetries) + .withErrorMatcher(Retry.ErrorMatcher[F, Errors].only[Error1]) + + val expected: List[RetryAttempt] = List( + (Status(0, Duration.Zero), Decision.retry(delay), error), + (Status(1, 1.second), Decision.retry(delay), error), + (Status(2, 2.seconds), Decision.giveUp, error) + ) + + val io: F[Unit] = Handle[F, Errors].raise[Errors, Unit](error) + + val run = + for { + ref <- IO.ref(List.empty[RetryAttempt]) + result <- mtlRetry[F, Errors, Unit]( + io, + policy, + (s, e: Errors, d) => EitherT.liftF(ref.update(_ :+ (s, d, e))) + ).value + attempts <- ref.get + } yield (result, attempts) + + run must completeAs((Left(error), expected)) + } + + } + private def run[A](policy: Retry[IO, Throwable])(io: IO[A]): IO[List[RetryAttempt]] = for { ref <- IO.ref(List.empty[RetryAttempt])