Skip to content

Commit

Permalink
Add case for rollback to empty cache
Browse files Browse the repository at this point in the history
  • Loading branch information
kajebiii committed Sep 22, 2023
1 parent 3df9381 commit 5da82d1
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 10 deletions.
30 changes: 20 additions & 10 deletions zio-cache/shared/src/main/scala/zio/cache/Cache.scala
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,8 @@ object Cache {
value = map.putIfAbsent(k, MapValue.Pending(new MapKey(k), promise))
}
val result = if (value eq null) {
lookupValueOf(in, promise)
val rollbackResultIfError = if (rollbackIfError) Right(None) else Left(())
lookupValueOf(in, promise, rollbackResultIfError)
} else {
value match {
case MapValue.Pending(_, promiseInProgress) =>
Expand All @@ -279,8 +280,7 @@ object Cache {
get(in)
} else {
// Only trigger the lookup if we're still the current value, `completedResult`
val rollbackResultIfError: Option[MapValue.Complete[Key, Error, Value]] =
if (rollbackIfError) Some(completedResult) else None
val rollbackResultIfError = if (rollbackIfError) Right(Some(completedResult)) else Left(())
lookupValueOf(in, promise, rollbackResultIfError).when {
map.replace(k, completedResult, MapValue.Refreshing(promise, completedResult))
}
Expand Down Expand Up @@ -309,7 +309,12 @@ object Cache {
private def lookupValueOf(
in: In,
promise: Promise[Error, Value],
rollbackResultIfError: Option[MapValue.Complete[Key, Error, Value]] = None
/**
* Left(()): Put the lookup result.
* Right(None): Remove key if there is a error.
* Right(Some(rollbackResult)): Rollback if there is a error.
*/
rollbackResultIfError: Either[Unit, Option[MapValue.Complete[Key, Error, Value]]] = Left(())
): IO[Error, Value] =
ZIO.suspendSucceed {
val key = keyBy(in)
Expand All @@ -320,12 +325,17 @@ object Cache {
val now = Unsafe.unsafe(implicit u => clock.unsafe.instant())
val entryStats = EntryStats(now)

rollbackResultIfError match {
case Some(rollbackResult) if exit.isFailure =>
map.put(key, rollbackResult)
case _ =>
map.put(key, MapValue.Complete(new MapKey(key), exit, entryStats, now.plus(timeToLive(exit))))
}
if (exit.isSuccess)
map.put(key, MapValue.Complete(new MapKey(key), exit, entryStats, now.plus(timeToLive(exit))))
else
rollbackResultIfError match {
case Left(()) =>
map.put(key, MapValue.Complete(new MapKey(key), exit, entryStats, now.plus(timeToLive(exit))))
case Right(None) =>
map.remove(key)
case Right(Some(rollbackResult)) =>
map.put(key, rollbackResult)
}

promise.done(exit) *> ZIO.done(exit)
}
Expand Down
27 changes: 27 additions & 0 deletions zio-cache/shared/src/test/scala/zio/cache/CacheSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,33 @@ object CacheSpec extends ZIOSpecDefault {
assert(val2)(isRight(equalTo(5))) &&
assert(val3)(isRight(equalTo(5))) &&
assert(val4)(isRight(equalTo(7)))
},
test("should update only if it is a value when the key doesn't exist in the cache") {

val error = new RuntimeException("Must be a multiple of 3")

def inc(n: Int) = n + 1

def retrieve(number: Ref[Int])(key: Int) =
number
.updateAndGet(inc)
.flatMap {
case n if n % 3 == 0 =>
ZIO.fail(error)
case n =>
ZIO.succeed(key * n)
}

val seed = 2
val key = 1
val cap = 30
for {
ref <- Ref.make(seed)
cache <- Cache.make(cap, Duration.Infinity, Lookup(retrieve(ref)))
count0 <- cache.size
_ <- ZIO.foreachDiscard(1 to cap)(key => cache.refreshValue(key).either)
count1 <- cache.size
} yield assertTrue(count0 == 0) && assertTrue(count1 == cap / 3 * 2)
}
),
test("size") {
Expand Down

0 comments on commit 5da82d1

Please sign in to comment.