Skip to content

Commit

Permalink
Retry: introduce ErrorMatcher
Browse files Browse the repository at this point in the history
  • Loading branch information
iRevive committed Jul 28, 2024
1 parent a25babb commit ff5bc74
Show file tree
Hide file tree
Showing 2 changed files with 232 additions and 88 deletions.
104 changes: 94 additions & 10 deletions std/shared/src/main/scala/cats/effect/std/Retry.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package cats.effect.std

import cats.{~>, Monad, Semigroup, Show}
import cats.{~>, Applicative, Monad, Semigroup, Show}
import cats.effect.kernel.GenTemporal
import cats.syntax.apply._
import cats.syntax.flatMap._
Expand All @@ -25,14 +25,15 @@ import cats.syntax.functor._
import cats.syntax.monadError._

import scala.concurrent.duration._
import scala.reflect.{classTag, ClassTag}

/**
* Glossary:
* - individual delay - the delay between retries
* - cumulative delay - the total delay accumulated across all retries
*/
sealed trait Retry[F[_], E] {
import Retry.{Decision, Status}
import Retry.{Decision, ErrorMatcher, Status}

/**
* The name of the policy. The name is used for informative purposes.
Expand Down Expand Up @@ -66,7 +67,7 @@ sealed trait Retry[F[_], E] {
* {{{
* val timeoutExceptionOnly = Retry
* .exponential[IO, Throwable](1.second)
* .withErrorMatcher { case e: TimeoutException => IO.pure(true) }
* .withErrorMatcher(Retry.ErrorMatcher[IO, Throwable].only[TimeoutException])
*
* // will retry using exponential backoff strategy
* Retry.retry(timeoutExceptionOnly)(IO.raiseError(new TimeoutException("oops")))
Expand All @@ -78,7 +79,7 @@ sealed trait Retry[F[_], E] {
* @param matcher
* the matcher to use
*/
def withErrorMatcher(matcher: PartialFunction[E, F[Boolean]]): Retry[F, E]
def withErrorMatcher(matcher: ErrorMatcher[F, E]): Retry[F, E]

/**
* Sets the name for the policy. The name is used for informative purposes.
Expand Down Expand Up @@ -366,6 +367,91 @@ object Retry {
private final case class RetryImpl(delay: FiniteDuration) extends Retry
}

/**
* The error matcher decides whether the retry decision should be calculated for the raised
* error or not.
*/
sealed trait ErrorMatcher[F[_], -E] {
def matches(e: E): F[Boolean]
}

object ErrorMatcher {

/**
* Creates an error matcher that matches all errors.
*/
def all[F[_]: Applicative, E]: ErrorMatcher[F, E] =
new Impl[F, E]({ (_: E) => Applicative[F].pure(true) })

/**
* Creates an error matcher using the given `matcher` under the hood.
*
* @param matcher
* the matcher to use
*/
def matches[F[_]: Applicative, E](
matcher: PartialFunction[E, Boolean]
): ErrorMatcher[F, E] =
new Impl[F, E](matcher.andThen(b => Applicative[F].pure(b)))

/**
* Creates an error matcher using the given `matcher` under the hood.
*
* @param matcher
* the matcher to use
*/
def when[F[_]: Applicative, E](
matcher: PartialFunction[E, F[Boolean]]
): ErrorMatcher[F, E] =
new Impl[F, E](matcher)

/**
* A partially-applied constructor.
*
* @example
* {{{
* val onlyTimeoutException = ErrorMatcher[IO, Throwable].only[TimeoutException]
* val allButTimeoutException = ErrorMatcher[IO, Throwable].except[TimeoutException]
* }}}
*/
def apply[F[_], E]: ApplyPartiallyApplied[F, E] =
new ApplyPartiallyApplied(dummy = true)

final class ApplyPartiallyApplied[F[_], E](private val dummy: Boolean) extends AnyVal {

/**
* Creates a new error matcher the matches only errors of type `E1`.
*
* @example
* matches only `TimeoutException` errors:
* {{{
* val matcher = ErrorMatcher[IO, Throwable].only[TimeoutException]
* }}}
*/
def only[E1 <: E: ClassTag](implicit F: Applicative[F]): ErrorMatcher[F, E] =
matches { case _: E1 => true }

/**
* Creates a new error matcher the matches all errors except `E1`.
*
* @example
* matches all errors except the `TimeoutException` error:
* {{{
* val matcher = ErrorMatcher[IO, Throwable].except[TimeoutException]
* }}}
*/
def except[E1 <: E: ClassTag](implicit F: Applicative[F]): ErrorMatcher[F, E] =
matches { case e => !classTag[E1].runtimeClass.isInstance(e) }
}

private final class Impl[F[_]: Applicative, E](
matcher: PartialFunction[E, F[Boolean]]
) extends ErrorMatcher[F, E] {
def matches(e: E): F[Boolean] =
matcher.applyOrElse(e, (_: E) => Applicative[F].pure(false))
}
}

sealed trait BackoffMultiplier extends Product with Serializable
object BackoffMultiplier {

Expand Down Expand Up @@ -557,7 +643,7 @@ object Retry {
* the function to decide whether to continue
*/
def named[F[_]: Monad, E](name: String)(decider: (Status, E) => F[Decision]): Retry[F, E] =
RetryImpl(name, decider, { _: E => Monad[F].pure(true) })
RetryImpl(name, decider, ErrorMatcher.all)

implicit def retrySemigroup[F[_], E]: Semigroup[Retry[F, E]] =
(x, y) => x && y
Expand Down Expand Up @@ -593,13 +679,11 @@ object Retry {
private final case class RetryImpl[F[_]: Monad, E](
name: String,
decider: (Status, E) => F[Decision],
errorMatcher: PartialFunction[E, F[Boolean]]
errorMatcher: ErrorMatcher[F, E]
) extends Retry[F, E] {

def decide(status: Status, error: E): F[Decision] =
errorMatcher
.applyOrElse(error, (_: E) => Monad[F].pure(false))
.ifM(decider(status, error), Monad[F].pure(Decision.giveUp))
errorMatcher.matches(error).ifM(decider(status, error), Monad[F].pure(Decision.giveUp))

def and(other: Retry[F, E]): Retry[F, E] =
Retry.named(s"($name && ${other.name})") { (status, error) =>
Expand All @@ -619,7 +703,7 @@ object Retry {
}
}

def withErrorMatcher(matcher: PartialFunction[E, F[Boolean]]): Retry[F, E] =
def withErrorMatcher(matcher: ErrorMatcher[F, E]): Retry[F, E] =
copy(errorMatcher = matcher)

def withName(name: String): Retry[F, E] =
Expand Down
216 changes: 138 additions & 78 deletions tests/shared/src/test/scala/cats/effect/std/RetrySpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -67,84 +67,6 @@ class RetrySpec extends BaseSpec {
run(policy)(errorIO) must completeAs(expected)
}

"withErrorMatcher - retry only on matched errors" in ticked { implicit ticker =>
val maxRetries = 5
val delay = 1.second

val error = new Error1
val policy =
Retry.constantDelay[IO, Throwable](delay).withMaxRetries(maxRetries).withErrorMatcher {
case _: Error1 => IO.pure(true)
}

val expected = List(
RetryAttempt(Status(0, Duration.Zero), Decision.retry(delay), error),
RetryAttempt(Status(1, 1.second), Decision.retry(delay), error),
RetryAttempt(Status(2, 2.seconds), Decision.retry(delay), error),
RetryAttempt(Status(3, 3.seconds), Decision.retry(delay), error),
RetryAttempt(Status(4, 4.seconds), Decision.retry(delay), error),
RetryAttempt(Status(5, 5.seconds), Decision.giveUp, error)
)

run(policy)(IO.raiseError(error)) must completeAs(expected)
}

"withErrorMatcher - give up on mismatched errors" in ticked { implicit ticker =>
val maxRetries = 5
val delay = 1.second

val policy =
Retry.constantDelay[IO, Throwable](delay).withMaxRetries(maxRetries).withErrorMatcher {
case _: Error1 => IO.pure(true)
}

val expected = List(
RetryAttempt(Status(0, Duration.Zero), Decision.giveUp)
)

run(policy)(errorIO) must completeAs(expected)
}

"withErrorMatcher - keep the last matcher - give up on mismatched" in ticked { implicit t =>
val delay = 1.second
val maxRetries = 1

val policy =
Retry
.constantDelay[IO, Throwable](delay)
.withMaxRetries(maxRetries)
.withErrorMatcher { case _: Error1 => IO.pure(true) }
.withErrorMatcher { case _: Error2 => IO.pure(true) }

val error = new Error1
val expected = List(
RetryAttempt(Status(0, Duration.Zero), Decision.giveUp, error)
)

run(policy)(IO.raiseError(error)) must completeAs(expected)
}

"withErrorMatcher - keep the last matcher" in ticked { implicit ticker =>
val delay = 1.second
val maxRetries = 2

val policy =
Retry
.constantDelay[IO, Throwable](delay)
.withMaxRetries(maxRetries)
.withErrorMatcher { case _: Error1 => IO.pure(true) }
.withErrorMatcher { case _: Error2 => IO.pure(true) }

val error = new Error2
val expected = List(
RetryAttempt(Status(0, Duration.Zero), Decision.retry(delay), error),
RetryAttempt(Status(1, 1.second), Decision.retry(delay), error),
RetryAttempt(Status(2, 2.seconds), Decision.giveUp, error)
)

run(policy)(IO.raiseError(error)) must completeAs(expected)
}

"withCappedDelay - cap the individual delay" in ticked { implicit ticker =>
val maxRetries = 5
val delay = 2.second
Expand Down Expand Up @@ -270,6 +192,144 @@ class RetrySpec extends BaseSpec {

}

"Retry#withErrorMatcher" should {

"retry only on matched errors" in ticked { implicit ticker =>
val maxRetries = 2
val delay = 1.second

val error = new Error1
val policy =
Retry
.constantDelay[IO, Throwable](delay)
.withMaxRetries(maxRetries)
.withErrorMatcher(Retry.ErrorMatcher[IO, Throwable].only[Error1])

val expected = List(
RetryAttempt(Status(0, Duration.Zero), Decision.retry(delay), error),
RetryAttempt(Status(1, 1.second), Decision.retry(delay), error),
RetryAttempt(Status(2, 2.seconds), Decision.giveUp, error)
)

run(policy)(IO.raiseError(error)) must completeAs(expected)
}

"give up on mismatched errors" in ticked { implicit ticker =>
val maxRetries = 5
val delay = 1.second

val policy =
Retry
.constantDelay[IO, Throwable](delay)
.withMaxRetries(maxRetries)
.withErrorMatcher(Retry.ErrorMatcher[IO, Throwable].only[Error1])

val expected = List(
RetryAttempt(Status(0, Duration.Zero), Decision.giveUp)
)

run(policy)(errorIO) must completeAs(expected)
}

"keep the last matcher - give up on mismatched" in ticked { implicit ticker =>
val delay = 1.second
val maxRetries = 1

val policy =
Retry
.constantDelay[IO, Throwable](delay)
.withMaxRetries(maxRetries)
.withErrorMatcher(Retry.ErrorMatcher[IO, Throwable].only[Error1])
.withErrorMatcher(Retry.ErrorMatcher[IO, Throwable].only[Error2])

val error = new Error1
val expected = List(
RetryAttempt(Status(0, Duration.Zero), Decision.giveUp, error)
)

run(policy)(IO.raiseError(error)) must completeAs(expected)
}

"keep the last matcher - retry on matching errors" in ticked { implicit ticker =>
val delay = 1.second
val maxRetries = 2

val policy =
Retry
.constantDelay[IO, Throwable](delay)
.withMaxRetries(maxRetries)
.withErrorMatcher(Retry.ErrorMatcher[IO, Throwable].only[Error1])
.withErrorMatcher(Retry.ErrorMatcher[IO, Throwable].only[Error2])

val error = new Error2
val expected = List(
RetryAttempt(Status(0, Duration.Zero), Decision.retry(delay), error),
RetryAttempt(Status(1, 1.second), Decision.retry(delay), error),
RetryAttempt(Status(2, 2.seconds), Decision.giveUp, error)
)

run(policy)(IO.raiseError(error)) must completeAs(expected)
}

"ErrorMatcher.except - give up on 'excepted' errors" in ticked { implicit ticker =>
val maxRetries = 2
val delay = 1.second

val error = new Error1
val policy =
Retry
.constantDelay[IO, Throwable](delay)
.withMaxRetries(maxRetries)
.withErrorMatcher(Retry.ErrorMatcher[IO, Throwable].except[Error1])

val expected = List(
RetryAttempt(Status(0, Duration.Zero), Decision.giveUp, error)
)

run(policy)(IO.raiseError(error)) must completeAs(expected)
}

"ErrorMatcher.except - give up on subtypes" in ticked { implicit ticker =>
val maxRetries = 2
val delay = 1.second

val error = new Error1
val policy =
Retry
.constantDelay[IO, Throwable](delay)
.withMaxRetries(maxRetries)
.withErrorMatcher(Retry.ErrorMatcher[IO, Throwable].except[RuntimeException])

val expected = List(
RetryAttempt(Status(0, Duration.Zero), Decision.giveUp, error)
)

run(policy)(IO.raiseError(error)) must completeAs(expected)
}

"ErrorMatcher.except - recover on all errors but the 'excepted' one" in ticked {
implicit ticker =>
val maxRetries = 2
val delay = 1.second

val error = new Error2
val policy =
Retry
.constantDelay[IO, Throwable](delay)
.withMaxRetries(maxRetries)
.withErrorMatcher(Retry.ErrorMatcher[IO, Throwable].except[Error1])

val expected = List(
RetryAttempt(Status(0, Duration.Zero), Decision.retry(delay), error),
RetryAttempt(Status(1, 1.second), Decision.retry(delay), error),
RetryAttempt(Status(2, 2.seconds), Decision.giveUp, error)
)

run(policy)(IO.raiseError(error)) must completeAs(expected)
}

}

"Retry.exponentialBackoff" should {
// it's not random :)
val RandomNextDouble = 1.0
Expand Down

0 comments on commit ff5bc74

Please sign in to comment.