Skip to content

Commit

Permalink
Merge pull request #351 from takapi327/refactor/2025-01-Fixed-DBIO-wa…
Browse files Browse the repository at this point in the history
…rning

Refactor/2025 01 fixed dbio warning
  • Loading branch information
takapi327 authored Jan 2, 2025
2 parents cb32ef0 + 772cef5 commit 12f3b56
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 58 deletions.
203 changes: 154 additions & 49 deletions module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

package ldbc.dsl

import cats.{ Functor, Monad, MonadError }
import cats.*
import cats.syntax.all.*

import cats.effect.*
Expand All @@ -24,100 +24,205 @@ import ldbc.dsl.logging.*
* @tparam T
* The result type of the query
*/
trait DBIO[F[_]: Temporal, T]:
trait DBIO[F[_], T]:

private[ldbc] def execute(connection: Connection[F])(using logHandler: LogHandler[F]): F[T]
/**
* The function that actually executes the query.
*
* @param connection
* The connection to the database
* @param logHandler
* The log handler
* @return
* The result of the query
*/
def run(connection: Connection[F])(using logHandler: LogHandler[F]): F[T]

/**
* Functions for managing the processing of connections in a read-only manner.
*/
def readOnly(connection: Connection[F])(using LogHandler[F]): F[T] =
connection.setReadOnly(true) *> execute(connection) <* connection.setReadOnly(false)
def readOnly(connection: Connection[F])(using LogHandler[F]): F[T]

/**
* Functions to manage the processing of connections for writing.
*/
def commit(connection: Connection[F])(using LogHandler[F]): F[T] =
connection.setReadOnly(false) *> connection.setAutoCommit(true) *> execute(connection)
def commit(connection: Connection[F])(using LogHandler[F]): F[T]

/**
* Functions to manage the processing of connections, always rolling back.
*/
def rollback(connection: Connection[F])(using LogHandler[F]): F[T] =
connection.setReadOnly(false) *> connection.setAutoCommit(false) *> execute(connection) <* connection
.rollback() <* connection.setAutoCommit(true)
def rollback(connection: Connection[F])(using LogHandler[F]): F[T]

/**
* Functions to manage the processing of connections in a transaction.
*/
def transaction(connection: Connection[F])(using LogHandler[F]): F[T] =
val acquire = connection.setReadOnly(false) *> connection.setAutoCommit(false) *> Temporal[F].pure(connection)

val release = (connection: Connection[F], exitCase: ExitCase) =>
(exitCase match
case ExitCase.Errored(_) | ExitCase.Canceled => connection.rollback()
case _ => connection.commit()
)
*> connection.setAutoCommit(true)

Resource
.makeCase(acquire)(release)
.use(execute)
def transaction(connection: Connection[F])(using LogHandler[F]): F[T]

object DBIO:

private[ldbc] case class Impl[F[_]: Temporal, T](
statement: String,
params: List[Parameter],
run: Connection[F] => F[T]
func: Connection[F] => F[T]
) extends DBIO[F, T]:

private[ldbc] def execute(connection: Connection[F])(using logHandler: LogHandler[F]): F[T] =
run(connection)
override def run(connection: Connection[F])(using logHandler: LogHandler[F]): F[T] =
func(connection)
.onError(ex => logHandler.run(LogEvent.ProcessingFailure(statement, params.map(_.value), ex)))
<* logHandler.run(LogEvent.Success(statement, params.map(_.value)))

def pure[F[_]: Temporal, T](value: T): DBIO[F, T] =
new DBIO[F, T]:
override private[ldbc] def execute(connection: Connection[F])(using LogHandler[F]): F[T] = Monad[F].pure(value)
override def readOnly(connection: Connection[F])(using LogHandler[F]): F[T] = Monad[F].pure(value)
override def commit(connection: Connection[F])(using LogHandler[F]): F[T] = Monad[F].pure(value)
override def rollback(connection: Connection[F])(using LogHandler[F]): F[T] = Monad[F].pure(value)
override def transaction(connection: Connection[F])(using LogHandler[F]): F[T] = Monad[F].pure(value)
override def readOnly(connection: Connection[F])(using LogHandler[F]): F[T] =
connection.setReadOnly(true) *> run(connection) <* connection.setReadOnly(false)

override def commit(connection: Connection[F])(using LogHandler[F]): F[T] =
connection.setReadOnly(false) *> connection.setAutoCommit(true) *> run(connection)

override def rollback(connection: Connection[F])(using LogHandler[F]): F[T] =
connection.setReadOnly(false) *> connection.setAutoCommit(false) *> run(connection) <* connection
.rollback() <* connection.setAutoCommit(true)

override def transaction(connection: Connection[F])(using LogHandler[F]): F[T] =
val acquire = connection.setReadOnly(false) *> connection.setAutoCommit(false) *> Temporal[F].pure(connection)

val release = (connection: Connection[F], exitCase: ExitCase) =>
(exitCase match
case ExitCase.Errored(_) | ExitCase.Canceled => connection.rollback()
case _ => connection.commit()
)
*> connection.setAutoCommit(true)

def raiseError[F[_]: Temporal, A](e: Throwable): DBIO[F, A] =
Resource
.makeCase(acquire)(release)
.use(run)

def pure[F[_]: Monad, A](value: A): DBIO[F, A] =
new DBIO[F, A]:
override private[ldbc] def execute(connection: Connection[F])(using LogHandler[F]): F[A] =
MonadError[F, Throwable].raiseError(e)
override def run(connection: Connection[F])(using logHandler: LogHandler[F]): F[A] = Monad[F].pure(value)
override def readOnly(connection: Connection[F])(using LogHandler[F]): F[A] = Monad[F].pure(value)
override def commit(connection: Connection[F])(using LogHandler[F]): F[A] = Monad[F].pure(value)
override def rollback(connection: Connection[F])(using LogHandler[F]): F[A] = Monad[F].pure(value)
override def transaction(connection: Connection[F])(using LogHandler[F]): F[A] = Monad[F].pure(value)

given [F[_]: Temporal]: Functor[[T] =>> DBIO[F, T]] with
override def map[A, B](fa: DBIO[F, A])(f: A => B): DBIO[F, B] =
new DBIO[F, B]:
override private[ldbc] def execute(connection: Connection[F])(using LogHandler[F]): F[B] =
fa.execute(connection).map(f)
def raiseError[F[_], A](e: Throwable)(using ev: MonadThrow[F]): DBIO[F, A] =
new DBIO[F, A]:
override def run(connection: Connection[F])(using logHandler: LogHandler[F]): F[A] = ev.raiseError(e)
override def readOnly(connection: Connection[F])(using LogHandler[F]): F[A] = ev.raiseError(e)
override def commit(connection: Connection[F])(using LogHandler[F]): F[A] = ev.raiseError(e)
override def rollback(connection: Connection[F])(using LogHandler[F]): F[A] = ev.raiseError(e)
override def transaction(connection: Connection[F])(using LogHandler[F]): F[A] = ev.raiseError(e)

given [F[_]: Temporal]: MonadError[[T] =>> DBIO[F, T], Throwable] with
override def pure[A](x: A): DBIO[F, A] = DBIO.pure(x)

override def flatMap[A, B](fa: DBIO[F, A])(f: A => DBIO[F, B]): DBIO[F, B] =
new DBIO[F, B]:
override private[ldbc] def execute(connection: Connection[F])(using LogHandler[F]): F[B] =
fa.execute(connection).flatMap(a => f(a).execute(connection))
override def run(connection: Connection[F])(using logHandler: LogHandler[F]): F[B] =
fa.run(connection).flatMap(a => f(a).run(connection))
override def readOnly(connection: Connection[F])(using LogHandler[F]): F[B] =
connection.setReadOnly(true) *> run(connection) <* connection.setReadOnly(false)
override def commit(connection: Connection[F])(using LogHandler[F]): F[B] =
connection.setReadOnly(false) *> connection.setAutoCommit(true) *> run(connection)
override def rollback(connection: Connection[F])(using LogHandler[F]): F[B] =
connection.setReadOnly(false) *> connection.setAutoCommit(false) *> run(connection) <* connection
.rollback() <* connection.setAutoCommit(true)
override def transaction(connection: Connection[F])(using LogHandler[F]): F[B] =
val acquire = connection.setReadOnly(false) *> connection.setAutoCommit(false) *> Temporal[F].pure(connection)
val release = (connection: Connection[F], exitCase: ExitCase) =>
(exitCase match
case ExitCase.Errored(_) | ExitCase.Canceled => connection.rollback()
case _ => connection.commit()
)
*> connection.setAutoCommit(true)
Resource
.makeCase(acquire)(release)
.use(run)

override def tailRecM[A, B](a: A)(f: A => DBIO[F, Either[A, B]]): DBIO[F, B] =
new DBIO[F, B]:
override private[ldbc] def execute(connection: Connection[F])(using logHandler: LogHandler[F]): F[B] =
MonadError[F, Throwable].tailRecM(a)(a => f(a).execute(connection))
override def run(connection: Connection[F])(using logHandler: LogHandler[F]): F[B] =
Temporal[F].tailRecM(a)(a => f(a).run(connection))

override def readOnly(connection: Connection[F])(using LogHandler[F]): F[B] =
connection.setReadOnly(true) *> run(connection) <* connection.setReadOnly(false)

override def commit(connection: Connection[F])(using LogHandler[F]): F[B] =
connection.setReadOnly(false) *> connection.setAutoCommit(true) *> run(connection)

override def rollback(connection: Connection[F])(using LogHandler[F]): F[B] =
connection.setReadOnly(false) *> connection.setAutoCommit(false) *> run(connection) <* connection
.rollback() <* connection.setAutoCommit(true)

override def transaction(connection: Connection[F])(using LogHandler[F]): F[B] =
val acquire = connection.setReadOnly(false) *> connection.setAutoCommit(false) *> Temporal[F].pure(connection)

val release = (connection: Connection[F], exitCase: ExitCase) =>
(exitCase match
case ExitCase.Errored(_) | ExitCase.Canceled => connection.rollback()
case _ => connection.commit()
)
*> connection.setAutoCommit(true)

Resource
.makeCase(acquire)(release)
.use(run)

override def ap[A, B](ff: DBIO[F, A => B])(fa: DBIO[F, A]): DBIO[F, B] =
new DBIO[F, B]:
override private[ldbc] def execute(connection: Connection[F])(using logHandler: LogHandler[F]): F[B] =
(ff.execute(connection), fa.execute(connection)).mapN(_(_))
override def run(connection: Connection[F])(using logHandler: LogHandler[F]): F[B] =
(ff.run(connection), fa.run(connection)).mapN(_(_))

override def readOnly(connection: Connection[F])(using LogHandler[F]): F[B] =
connection.setReadOnly(true) *> run(connection) <* connection.setReadOnly(false)

override def commit(connection: Connection[F])(using LogHandler[F]): F[B] =
connection.setReadOnly(false) *> connection.setAutoCommit(true) *> run(connection)

override def rollback(connection: Connection[F])(using LogHandler[F]): F[B] =
connection.setReadOnly(false) *> connection.setAutoCommit(false) *> run(connection) <* connection
.rollback() <* connection.setAutoCommit(true)

override def transaction(connection: Connection[F])(using LogHandler[F]): F[B] =
val acquire = connection.setReadOnly(false) *> connection.setAutoCommit(false) *> Temporal[F].pure(connection)

val release = (connection: Connection[F], exitCase: ExitCase) =>
(exitCase match
case ExitCase.Errored(_) | ExitCase.Canceled => connection.rollback()
case _ => connection.commit()
)
*> connection.setAutoCommit(true)

Resource
.makeCase(acquire)(release)
.use(run)

override def raiseError[A](e: Throwable): DBIO[F, A] =
DBIO.raiseError(e)

override def handleErrorWith[A](fa: DBIO[F, A])(f: Throwable => DBIO[F, A]): DBIO[F, A] =
new DBIO[F, A]:
override private[ldbc] def execute(connection: Connection[F])(using LogHandler[F]): F[A] =
fa.execute(connection).handleErrorWith(e => f(e).execute(connection))
override def run(connection: Connection[F])(using LogHandler[F]): F[A] =
fa.run(connection).handleErrorWith(e => f(e).run(connection))

override def readOnly(connection: Connection[F])(using LogHandler[F]): F[A] =
connection.setReadOnly(true) *> run(connection) <* connection.setReadOnly(false)

override def commit(connection: Connection[F])(using LogHandler[F]): F[A] =
connection.setReadOnly(false) *> connection.setAutoCommit(true) *> run(connection)

override def rollback(connection: Connection[F])(using LogHandler[F]): F[A] =
connection.setReadOnly(false) *> connection.setAutoCommit(false) *> run(connection) <* connection
.rollback() <* connection.setAutoCommit(true)

override def transaction(connection: Connection[F])(using LogHandler[F]): F[A] =
val acquire = connection.setReadOnly(false) *> connection.setAutoCommit(false) *> Temporal[F].pure(connection)

val release = (connection: Connection[F], exitCase: ExitCase) =>
(exitCase match
case ExitCase.Errored(_) | ExitCase.Canceled => connection.rollback()
case _ => connection.commit()
)
*> connection.setAutoCommit(true)

Resource
.makeCase(acquire)(release)
.use(run)
18 changes: 9 additions & 9 deletions tests/src/test/scala/ldbc/tests/DBIOTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class DBIOTest extends CatsEffectSuite:
val program = DBIO.pure[IO, Int](1)
assertIO(
connection.use { conn =>
program.execute(conn)
program.run(conn)
},
1
)
Expand All @@ -45,7 +45,7 @@ class DBIOTest extends CatsEffectSuite:
val program3 = program2.ap(program1)
assertIO(
connection.use { conn =>
program3.execute(conn)
program3.run(conn)
},
2
)
Expand All @@ -56,7 +56,7 @@ class DBIOTest extends CatsEffectSuite:
val program2 = program1.map(_ + 1)
assertIO(
connection.use { conn =>
program2.execute(conn)
program2.run(conn)
},
2
)
Expand All @@ -67,7 +67,7 @@ class DBIOTest extends CatsEffectSuite:
val program2 = program1.flatMap(n => DBIO.pure[IO, Int](n + 1))
assertIO(
connection.use { conn =>
program2.execute(conn)
program2.run(conn)
},
2
)
Expand All @@ -78,7 +78,7 @@ class DBIOTest extends CatsEffectSuite:
val program2 = program1.tailRecM[DBIO, String](_.map(n => Right(n.toString)))
assertIO(
connection.use { conn =>
program2.execute(conn)
program2.run(conn)
},
"1"
)
Expand All @@ -88,7 +88,7 @@ class DBIOTest extends CatsEffectSuite:
val program = DBIO.raiseError[IO, Int](new Exception("error"))
interceptMessageIO[Exception]("error")(
connection.use { conn =>
program.execute(conn)
program.run(conn)
}
)
}
Expand All @@ -98,7 +98,7 @@ class DBIOTest extends CatsEffectSuite:
val program2 = program1.handleErrorWith(e => DBIO.pure[IO, Int](0))
assertIO(
connection.use { conn =>
program2.execute(conn)
program2.run(conn)
},
0
)
Expand All @@ -108,7 +108,7 @@ class DBIOTest extends CatsEffectSuite:
val program = DBIO.pure[IO, Int](1)
assertIO(
connection.use { conn =>
program.attempt.execute(conn)
program.attempt.run(conn)
},
Right(1)
)
Expand All @@ -118,7 +118,7 @@ class DBIOTest extends CatsEffectSuite:
val program: DBIO[Int] = DBIO.raiseError[IO, Int](new Exception("error"))
assertIOBoolean(
connection.use { conn =>
program.attempt.execute(conn).map(_.isLeft)
program.attempt.run(conn).map(_.isLeft)
}
)
}

0 comments on commit 12f3b56

Please sign in to comment.