Skip to content

Commit

Permalink
Retry: add MTL-specific tests
Browse files Browse the repository at this point in the history
  • Loading branch information
iRevive committed Jul 28, 2024
1 parent ff5bc74 commit 8af8301
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 52 deletions.
6 changes: 3 additions & 3 deletions core/shared/src/main/scala/cats/effect/IO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -662,14 +662,14 @@ sealed abstract class IO[+A] private () extends IOPlatform[A] {
* @param policy
* the policy to use
*
* @param onRetry
* @param onError
* the effect to invoke on every retry decision
*/
def retry(
policy: Retry[IO, Throwable],
onRetry: (Retry.Status, Throwable, Retry.Decision) => IO[Unit]
onError: (Retry.Status, Throwable, Retry.Decision) => IO[Unit]
): IO[A] =
Retry.retry(policy, onRetry)(this)
Retry.retry(policy, onError)(this)

/**
* Inverse of `attempt`
Expand Down
8 changes: 4 additions & 4 deletions std/shared/src/main/scala/cats/effect/std/Retry.scala
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ object Retry {
* Creates an error matcher that matches all errors.
*/
def all[F[_]: Applicative, E]: ErrorMatcher[F, E] =
new Impl[F, E]({ (_: E) => Applicative[F].pure(true) })
new Impl[F, E]({ case _: E => Applicative[F].pure(true) })

/**
* Creates an error matcher using the given `matcher` under the hood.
Expand Down Expand Up @@ -486,17 +486,17 @@ object Retry {
* @param policy
* the policy to use
*
* @param onRetry
* @param onError
* the effect to invoke on every retry decision
*
* @param fa
* the effect
*/
def retry[F[_], A, E](
policy: Retry[F, E],
onRetry: (Status, E, Decision) => F[Unit]
onError: (Status, E, Decision) => F[Unit]
)(fa: F[A])(implicit F: GenTemporal[F, E]): F[A] =
doRetry(policy, Some(onRetry))(fa)
doRetry(policy, Some(onError))(fa)

/**
* The return policy that always gives up.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ final class RetryOps[F[_], A] private[syntax] (private val fa: F[A]) extends Any

def retry[E](
policy: Retry[F, E],
onRetry: (Retry.Status, E, Retry.Decision) => F[Unit]
onError: (Retry.Status, E, Retry.Decision) => F[Unit]
)(implicit F: GenTemporal[F, E]): F[A] =
Retry.retry(policy, onRetry)(fa)
Retry.retry(policy, onError)(fa)

}
187 changes: 144 additions & 43 deletions tests/shared/src/test/scala/cats/effect/std/RetrySpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
package cats.effect.std

import cats.{Hash, Show}
import cats.effect.{BaseSpec, IO, Ref}
import cats.data.EitherT
import cats.effect.{BaseSpec, IO, Ref, Temporal}
import cats.mtl.Handle
import cats.syntax.applicative._
import cats.syntax.flatMap._
import cats.syntax.functor._
import cats.syntax.semigroup._

Expand Down Expand Up @@ -72,11 +75,10 @@ class RetrySpec extends BaseSpec {
val delay = 2.second
val capDelay = 1.second

val policy =
Retry
.constantDelay[IO, Throwable](delay)
.withCappedDelay(capDelay)
.withMaxRetries(maxRetries)
val policy = Retry
.constantDelay[IO, Throwable](delay)
.withCappedDelay(capDelay)
.withMaxRetries(maxRetries)

val expected = {
val retries = List.tabulate(maxRetries) { i =>
Expand Down Expand Up @@ -199,11 +201,10 @@ class RetrySpec extends BaseSpec {
val delay = 1.second

val error = new Error1
val policy =
Retry
.constantDelay[IO, Throwable](delay)
.withMaxRetries(maxRetries)
.withErrorMatcher(Retry.ErrorMatcher[IO, Throwable].only[Error1])
val policy = Retry
.constantDelay[IO, Throwable](delay)
.withMaxRetries(maxRetries)
.withErrorMatcher(Retry.ErrorMatcher[IO, Throwable].only[Error1])

val expected = List(
RetryAttempt(Status(0, Duration.Zero), Decision.retry(delay), error),
Expand All @@ -218,11 +219,10 @@ class RetrySpec extends BaseSpec {
val maxRetries = 5
val delay = 1.second

val policy =
Retry
.constantDelay[IO, Throwable](delay)
.withMaxRetries(maxRetries)
.withErrorMatcher(Retry.ErrorMatcher[IO, Throwable].only[Error1])
val policy = Retry
.constantDelay[IO, Throwable](delay)
.withMaxRetries(maxRetries)
.withErrorMatcher(Retry.ErrorMatcher[IO, Throwable].only[Error1])

val expected = List(
RetryAttempt(Status(0, Duration.Zero), Decision.giveUp)
Expand All @@ -235,12 +235,11 @@ class RetrySpec extends BaseSpec {
val delay = 1.second
val maxRetries = 1

val policy =
Retry
.constantDelay[IO, Throwable](delay)
.withMaxRetries(maxRetries)
.withErrorMatcher(Retry.ErrorMatcher[IO, Throwable].only[Error1])
.withErrorMatcher(Retry.ErrorMatcher[IO, Throwable].only[Error2])
val policy = Retry
.constantDelay[IO, Throwable](delay)
.withMaxRetries(maxRetries)
.withErrorMatcher(Retry.ErrorMatcher[IO, Throwable].only[Error1])
.withErrorMatcher(Retry.ErrorMatcher[IO, Throwable].only[Error2])

val error = new Error1
val expected = List(
Expand All @@ -254,12 +253,11 @@ class RetrySpec extends BaseSpec {
val delay = 1.second
val maxRetries = 2

val policy =
Retry
.constantDelay[IO, Throwable](delay)
.withMaxRetries(maxRetries)
.withErrorMatcher(Retry.ErrorMatcher[IO, Throwable].only[Error1])
.withErrorMatcher(Retry.ErrorMatcher[IO, Throwable].only[Error2])
val policy = Retry
.constantDelay[IO, Throwable](delay)
.withMaxRetries(maxRetries)
.withErrorMatcher(Retry.ErrorMatcher[IO, Throwable].only[Error1])
.withErrorMatcher(Retry.ErrorMatcher[IO, Throwable].only[Error2])

val error = new Error2
val expected = List(
Expand All @@ -276,11 +274,10 @@ class RetrySpec extends BaseSpec {
val delay = 1.second

val error = new Error1
val policy =
Retry
.constantDelay[IO, Throwable](delay)
.withMaxRetries(maxRetries)
.withErrorMatcher(Retry.ErrorMatcher[IO, Throwable].except[Error1])
val policy = Retry
.constantDelay[IO, Throwable](delay)
.withMaxRetries(maxRetries)
.withErrorMatcher(Retry.ErrorMatcher[IO, Throwable].except[Error1])

val expected = List(
RetryAttempt(Status(0, Duration.Zero), Decision.giveUp, error)
Expand All @@ -294,11 +291,10 @@ class RetrySpec extends BaseSpec {
val delay = 1.second

val error = new Error1
val policy =
Retry
.constantDelay[IO, Throwable](delay)
.withMaxRetries(maxRetries)
.withErrorMatcher(Retry.ErrorMatcher[IO, Throwable].except[RuntimeException])
val policy = Retry
.constantDelay[IO, Throwable](delay)
.withMaxRetries(maxRetries)
.withErrorMatcher(Retry.ErrorMatcher[IO, Throwable].except[RuntimeException])

val expected = List(
RetryAttempt(Status(0, Duration.Zero), Decision.giveUp, error)
Expand All @@ -313,11 +309,10 @@ class RetrySpec extends BaseSpec {
val delay = 1.second

val error = new Error2
val policy =
Retry
.constantDelay[IO, Throwable](delay)
.withMaxRetries(maxRetries)
.withErrorMatcher(Retry.ErrorMatcher[IO, Throwable].except[Error1])
val policy = Retry
.constantDelay[IO, Throwable](delay)
.withMaxRetries(maxRetries)
.withErrorMatcher(Retry.ErrorMatcher[IO, Throwable].except[Error1])

val expected = List(
RetryAttempt(Status(0, Duration.Zero), Decision.retry(delay), error),
Expand Down Expand Up @@ -453,6 +448,112 @@ class RetrySpec extends BaseSpec {

}

"Retry MTL" should {

sealed trait Errors
final class Error1 extends Errors
final class Error2 extends Errors

type RetryAttempt = (Status, Decision, Errors)

def mtlRetry[F[_], E, A](
action: F[A],
policy: Retry[F, E],
onRetry: (Status, E, Decision) => F[Unit]
)(implicit F: Temporal[F], H: Handle[F, E]): F[A] =
F.tailRecM(Status.initial) { status =>
H.attempt(action).flatMap {
case Left(error) =>
policy
.decide(status, error)
.flatTap(decision => onRetry(status, error, decision))
.flatMap {
case retry: Decision.Retry =>
F.delayBy(F.pure(Left(status.withRetry(retry.delay))), retry.delay)

case _: Decision.GiveUp =>
H.raise(error)
}

case Right(success) =>
F.pure(Right(success))
}
}

implicit val outputHash: Hash[(Either[Errors, Unit], List[RetryAttempt])] =
Hash.fromUniversalHashCode

implicit val outputShow: Show[(Either[Errors, Unit], List[RetryAttempt])] =
Show.fromToString

"give up on mismatched errors" in ticked { implicit ticker =>
type F[A] = EitherT[IO, Errors, A]

val maxRetries = 2
val delay = 1.second

val error = new Error2
val policy = Retry
.constantDelay[F, Errors](delay)
.withMaxRetries(maxRetries)
.withErrorMatcher(Retry.ErrorMatcher[F, Errors].only[Error1])

val expected: List[RetryAttempt] = List(
(Status(0, Duration.Zero), Decision.giveUp, error)
)

val io: F[Unit] = Handle[F, Errors].raise[Errors, Unit](error)

val run =
for {
ref <- IO.ref(List.empty[RetryAttempt])
result <- mtlRetry[F, Errors, Unit](
io,
policy,
(s, e: Errors, d) => EitherT.liftF(ref.update(_ :+ (s, d, e)))
).value
attempts <- ref.get
} yield (result, attempts)

run must completeAs((Left(error), expected))
}

"retry only on matching errors" in ticked { implicit ticker =>
type F[A] = EitherT[IO, Errors, A]

val maxRetries = 2
val delay = 1.second

val error = new Error1
val policy = Retry
.constantDelay[F, Errors](delay)
.withMaxRetries(maxRetries)
.withErrorMatcher(Retry.ErrorMatcher[F, Errors].only[Error1])

val expected: List[RetryAttempt] = List(
(Status(0, Duration.Zero), Decision.retry(delay), error),
(Status(1, 1.second), Decision.retry(delay), error),
(Status(2, 2.seconds), Decision.giveUp, error)
)

val io: F[Unit] = Handle[F, Errors].raise[Errors, Unit](error)

val run =
for {
ref <- IO.ref(List.empty[RetryAttempt])
result <- mtlRetry[F, Errors, Unit](
io,
policy,
(s, e: Errors, d) => EitherT.liftF(ref.update(_ :+ (s, d, e)))
).value
attempts <- ref.get
} yield (result, attempts)

run must completeAs((Left(error), expected))
}

}

private def run[A](policy: Retry[IO, Throwable])(io: IO[A]): IO[List[RetryAttempt]] =
for {
ref <- IO.ref(List.empty[RetryAttempt])
Expand Down

0 comments on commit 8af8301

Please sign in to comment.