Skip to content

Commit

Permalink
Merge pull request #347 from takapi327/enhancement/2024-12-Encoder-Ex…
Browse files Browse the repository at this point in the history
…tensions

Enhancement/2024 12 encoder extensions
  • Loading branch information
takapi327 authored Dec 29, 2024
2 parents 51b18b9 + 9cef538 commit 81a1c62
Show file tree
Hide file tree
Showing 27 changed files with 523 additions and 508 deletions.
8 changes: 4 additions & 4 deletions docs/src/main/scala/05-Program.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ import ldbc.dsl.codec.*
enum Status:
case Active, InActive

given Encoder[Status] with
override def encode(value: Status): Boolean = value match
case Status.Active => true
case Status.InActive => false
given Encoder[Status] = Encoder[Boolean].contramap {
case Status.Active => true
case Status.InActive => false
}

val program1: DBIO[Int] =
sql"INSERT INTO user (name, email, status) VALUES (${ "user 1" }, ${ "[email protected]" }, ${ Status.Active })".update
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,5 +151,6 @@ private[ldbc] object TableModelGenerator:
| given ldbc.dsl.codec.Decoder.Elem[$enumName] = new ldbc.dsl.codec.Decoder.Elem[$enumName]:
| override def decode(resultSet: ldbc.sql.ResultSet, columnLabel: String): $enumName = $enumName.fromOrdinal(resultSet.getInt(columnLabel))
| override def decode(resultSet: ldbc.sql.ResultSet, index: Int): $enumName = $enumName.fromOrdinal(resultSet.getInt(index))
| given ldbc.dsl.codec.Encoder[$enumName] = ldbc.dsl.codec.Encoder[Int].contramap(_.ordinal)
|""".stripMargin)
case _ => None
11 changes: 4 additions & 7 deletions module/ldbc-dsl/src/main/scala/ldbc/dsl/Mysql.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import ldbc.dsl.codec.Decoder
* @tparam F
* The effect type
*/
case class Mysql[F[_]: Temporal](statement: String, params: List[Parameter.Dynamic]) extends SQL:
case class Mysql[F[_]: Temporal](statement: String, params: List[Parameter.Dynamic]) extends SQL, ParamBinder[F]:

@targetName("combine")
override def ++(sql: SQL): Mysql[F] =
Expand Down Expand Up @@ -79,9 +79,7 @@ case class Mysql[F[_]: Temporal](statement: String, params: List[Parameter.Dynam
connection =>
for
prepareStatement <- connection.prepareStatement(statement)
result <- params.zipWithIndex.traverse {
case (param, index) => param.bind(prepareStatement, index + 1)
} >> prepareStatement.executeUpdate() <* prepareStatement.close()
result <- paramBind(prepareStatement, params) >> prepareStatement.executeUpdate() <* prepareStatement.close()
yield result
)

Expand All @@ -106,9 +104,8 @@ case class Mysql[F[_]: Temporal](statement: String, params: List[Parameter.Dynam
connection =>
for
prepareStatement <- connection.prepareStatement(statement, Statement.RETURN_GENERATED_KEYS)
resultSet <- params.zipWithIndex.traverse {
case (param, index) => param.bind(prepareStatement, index + 1)
} >> prepareStatement.executeUpdate() >> prepareStatement.getGeneratedKeys()
resultSet <- paramBind(prepareStatement, params) >> prepareStatement.executeUpdate() >> prepareStatement
.getGeneratedKeys()
result <- summon[ResultSetConsumer[F, T]].consume(resultSet) <* prepareStatement.close()
yield result
)
56 changes: 19 additions & 37 deletions module/ldbc-dsl/src/main/scala/ldbc/dsl/Parameter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@

package ldbc.dsl

import java.time.*
import cats.syntax.all.*

import ldbc.sql.PreparedStatement
import ldbc.dsl.codec.Encoder

/**
Expand All @@ -24,39 +23,22 @@ object Parameter:
case class Static(value: String) extends Parameter:
override def toString: String = value

trait Dynamic extends Parameter:

/**
* Methods for setting Scala and Java values to the specified position in PreparedStatement.
*
* @param statement
* An object that represents a precompiled SQL statement.
* @param index
* the parameter value
*/
def bind[F[_]](statement: PreparedStatement[F], index: Int): F[Unit]

trait Dynamic extends Parameter
object Dynamic:

def apply[A](_value: A)(using encoder: Encoder[A]): Dynamic =
new Dynamic:
override def value: String = _value.toString
override def bind[F[_]](statement: PreparedStatement[F], index: Int): F[Unit] =
encoder.encode(_value) match
case value: Boolean => statement.setBoolean(index, value)
case value: Byte => statement.setByte(index, value)
case value: Short => statement.setShort(index, value)
case value: Int => statement.setInt(index, value)
case value: Long => statement.setLong(index, value)
case value: Float => statement.setFloat(index, value)
case value: Double => statement.setDouble(index, value)
case value: BigDecimal => statement.setBigDecimal(index, value)
case value: String => statement.setString(index, value)
case value: Array[Byte] => statement.setBytes(index, value)
case value: LocalTime => statement.setTime(index, value)
case value: LocalDate => statement.setDate(index, value)
case value: LocalDateTime => statement.setTimestamp(index, value)
case None => statement.setNull(index, ldbc.sql.Types.NULL)

given [A](using Encoder[A]): Conversion[A, Dynamic] with
override def apply(value: A): Dynamic = Dynamic(value)
case class Success(encoded: Encoder.Supported) extends Dynamic:
override def value: String = encoded.toString
case class Failure(errors: List[String]) extends Dynamic:
override def value: String = errors.mkString(", ")

def many[A](encoded: Encoder.Encoded): List[Dynamic] =
encoded match
case Encoder.Encoded.Success(list) => list.map(value => Success(value))
case Encoder.Encoded.Failure(errors) => List(Failure(errors.toList))

given [A](using encoder: Encoder[A]): Conversion[A, Dynamic] with
override def apply(value: A): Dynamic = encoder.encode(value) match
case Encoder.Encoded.Success(list) =>
list match
case head :: Nil => Dynamic.Success(head)
case _ => Dynamic.Failure(List("Multiple values are not allowed"))
case Encoder.Encoded.Failure(errors) => Dynamic.Failure(errors.toList)
15 changes: 6 additions & 9 deletions module/ldbc-dsl/src/main/scala/ldbc/dsl/Query.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ object Query:
statement: String,
params: List[Parameter.Dynamic],
decoder: Decoder[T]
) extends Query[F, T]:
) extends Query[F, T],
ParamBinder[F]:

given Decoder[T] = decoder

Expand All @@ -52,10 +53,8 @@ object Query:
connection =>
for
prepareStatement <- connection.prepareStatement(statement)
resultSet <- params.zipWithIndex.traverse {
case (param, index) => param.bind(prepareStatement, index + 1)
} >> prepareStatement.executeQuery()
result <- summon[ResultSetConsumer[F, G[T]]].consume(resultSet) <* prepareStatement.close()
resultSet <- paramBind(prepareStatement, params) >> prepareStatement.executeQuery()
result <- summon[ResultSetConsumer[F, G[T]]].consume(resultSet) <* prepareStatement.close()
yield result
)

Expand All @@ -66,9 +65,7 @@ object Query:
connection =>
for
prepareStatement <- connection.prepareStatement(statement)
resultSet <- params.zipWithIndex.traverse {
case (param, index) => param.bind(prepareStatement, index + 1)
} >> prepareStatement.executeQuery()
result <- summon[ResultSetConsumer[F, T]].consume(resultSet) <* prepareStatement.close()
resultSet <- paramBind(prepareStatement, params) >> prepareStatement.executeQuery()
result <- summon[ResultSetConsumer[F, T]].consume(resultSet) <* prepareStatement.close()
yield result
)
11 changes: 11 additions & 0 deletions module/ldbc-dsl/src/main/scala/ldbc/dsl/codec/Decoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,17 @@ class Decoder[A](f: (resultSet: ResultSet, prefix: Option[String]) => A):
*/
def decode(resultSet: ResultSet, prefix: Option[String]): A = f(resultSet, prefix)

/** `Decoder` is semigroupal: a pair of decoders make a decoder for a pair. */
def product[B](fb: Decoder[B]): Decoder[(A, B)] =
new Decoder((resultSet, prefix) => (this.decode(resultSet, None), fb.decode(resultSet, None)))

/** Lift this `Decoder` into `Option`. */
def opt: Decoder[Option[A]] =
new Decoder((resultSet, prefix) =>
val value = this.decode(resultSet, prefix)
if resultSet.wasNull() then None else Some(value)
)

object Decoder:

given Functor[[A] =>> Decoder[A]] with
Expand Down
88 changes: 67 additions & 21 deletions module/ldbc-dsl/src/main/scala/ldbc/dsl/codec/Encoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import java.time.*

import scala.compiletime.*

import cats.data.NonEmptyList

/**
* Trait for converting Scala types to types that can be handled by PreparedStatement.
*
Expand All @@ -26,71 +28,115 @@ trait Encoder[A]:
* @return
* Types that can be handled by PreparedStatement
*/
def encode(value: A): Encoder.Supported
def encode(value: A): Encoder.Encoded

/** Contramap inputs from a new type `B`, yielding an `Encoder[B]`. */
def contramap[B](func: B => A): Encoder[B] = (value: B) => encode(func(value))

/** Map outputs to a new type `B`, yielding an `Encoder[B]`. */
def either[B](func: B => Either[String, A]): Encoder[B] = (value: B) =>
func(value) match
case Right(value) => encode(value)
case Left(error) => Encoder.Encoded.failure(error)

/** `Encoder` is semigroupal: a pair of encoders make a encoder for a pair. */
def product[B](that: Encoder[B]): Encoder[(A, B)] = (value: (A, B)) => encode(value._1) product that.encode(value._2)

object Encoder:

/** Types that can be handled by PreparedStatement. */
type Supported = Boolean | Byte | Short | Int | Long | Float | Double | BigDecimal | String | Array[Byte] |
LocalTime | LocalDate | LocalDateTime | None.type

def apply[A](using encoder: Encoder[A]): Encoder[A] = encoder

given Encoder[Boolean] with
override def encode(value: Boolean): Boolean = value
override def encode(value: Boolean): Encoded =
Encoded.success(List(value))

given Encoder[Byte] with
override def encode(value: Byte): Byte = value
override def encode(value: Byte): Encoded =
Encoded.success(List(value))

given Encoder[Short] with
override def encode(value: Short): Short = value
override def encode(value: Short): Encoded =
Encoded.success(List(value))

given Encoder[Int] with
override def encode(value: Int): Int = value
override def encode(value: Int): Encoded =
Encoded.success(List(value))

given Encoder[Long] with
override def encode(value: Long): Long = value
override def encode(value: Long): Encoded =
Encoded.success(List(value))

given Encoder[Float] with
override def encode(value: Float): Float = value
override def encode(value: Float): Encoded =
Encoded.success(List(value))

given Encoder[Double] with
override def encode(value: Double): Double = value
override def encode(value: Double): Encoded =
Encoded.success(List(value))

given Encoder[BigDecimal] with
override def encode(value: BigDecimal): BigDecimal = value
override def encode(value: BigDecimal): Encoded =
Encoded.success(List(value))

given Encoder[String] with
override def encode(value: String): String = value
override def encode(value: String): Encoded =
Encoded.success(List(value))

given Encoder[Array[Byte]] with
override def encode(value: Array[Byte]): Array[Byte] = value
override def encode(value: Array[Byte]): Encoded =
Encoded.success(List(value))

given Encoder[LocalTime] with
override def encode(value: LocalTime): LocalTime = value
override def encode(value: LocalTime): Encoded =
Encoded.success(List(value))

given Encoder[LocalDate] with
override def encode(value: LocalDate): LocalDate = value
override def encode(value: LocalDate): Encoded =
Encoded.success(List(value))

given Encoder[LocalDateTime] with
override def encode(value: LocalDateTime): LocalDateTime = value
override def encode(value: LocalDateTime): Encoded =
Encoded.success(List(value))

given (using encoder: Encoder[String]): Encoder[Year] = encoder.contramap(_.toString)

given Encoder[Year] with
override def encode(value: Year): String = value.toString
given (using encoder: Encoder[String]): Encoder[YearMonth] = encoder.contramap(_.toString)

given Encoder[YearMonth] with
override def encode(value: YearMonth): String = value.toString
given (using encoder: Encoder[String]): Encoder[BigInt] = encoder.contramap(_.toString)

given Encoder[None.type] with
override def encode(value: None.type): None.type = value
override def encode(value: None.type): Encoded =
Encoded.success(List(None))

given [A](using encoder: Encoder[A]): Encoder[Option[A]] with
override def encode(value: Option[A]): Encoder.Supported =
override def encode(value: Option[A]): Encoded =
value match
case Some(value) => encoder.encode(value)
case None => None
case None => Encoded.success(List(None))

type MapToTuple[T] <: Tuple = T match
case EmptyTuple => EmptyTuple
case h *: EmptyTuple => Encoder[h] *: EmptyTuple
case h *: t => Encoder[h] *: MapToTuple[t]

inline def fold[T]: MapToTuple[T] = summonAll[MapToTuple[T]]

sealed trait Encoded:
def product(that: Encoded): Encoded
object Encoded:
case class Success(value: List[Encoder.Supported]) extends Encoded:
override def product(that: Encoded): Encoded = that match
case Success(value) => Success(this.value ::: value)
case Failure(errors) => Failure(errors)
case class Failure(errors: NonEmptyList[String]) extends Encoded:
override def product(that: Encoded): Encoded = that match
case Success(value) => Failure(errors)
case Failure(errors) => Failure(errors ::: this.errors)

def success(value: List[Encoder.Supported]): Encoded = Success(value)
def failure(error: String, errors: String*): Encoded =
Failure(NonEmptyList(error, errors.toList))
42 changes: 40 additions & 2 deletions module/ldbc-dsl/src/main/scala/ldbc/dsl/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

package ldbc

import java.time.*

import scala.deriving.Mirror

import cats.{ Foldable, Functor, Reducible }
Expand All @@ -14,11 +16,47 @@ import cats.syntax.all.*

import cats.effect.*

import ldbc.sql.PreparedStatement
import ldbc.dsl.syntax.*
import ldbc.dsl.codec.Encoder

package object dsl:

private[ldbc] trait ParamBinder[F[_]: Temporal]:
protected def paramBind(
prepareStatement: PreparedStatement[F],
params: List[Parameter.Dynamic]
): F[Unit] =
val encoded = params.foldLeft(Temporal[F].pure(List.empty[Encoder.Supported])) {
case (acc, param) =>
for
acc$ <- acc
value <- param match
case Parameter.Dynamic.Success(value) => Temporal[F].pure(value)
case Parameter.Dynamic.Failure(errors) =>
Temporal[F].raiseError(new IllegalArgumentException(errors.mkString(", ")))
yield acc$ :+ value
}
encoded.flatMap(_.zipWithIndex.foldLeft(Temporal[F].unit) {
case (acc, (value, index)) =>
acc *> (value match
case value: Boolean => prepareStatement.setBoolean(index + 1, value)
case value: Byte => prepareStatement.setByte(index + 1, value)
case value: Short => prepareStatement.setShort(index + 1, value)
case value: Int => prepareStatement.setInt(index + 1, value)
case value: Long => prepareStatement.setLong(index + 1, value)
case value: Float => prepareStatement.setFloat(index + 1, value)
case value: Double => prepareStatement.setDouble(index + 1, value)
case value: BigDecimal => prepareStatement.setBigDecimal(index + 1, value)
case value: String => prepareStatement.setString(index + 1, value)
case value: Array[Byte] => prepareStatement.setBytes(index + 1, value)
case value: LocalDate => prepareStatement.setDate(index + 1, value)
case value: LocalTime => prepareStatement.setTime(index + 1, value)
case value: LocalDateTime => prepareStatement.setTimestamp(index + 1, value)
case None => prepareStatement.setNull(index + 1, ldbc.sql.Types.NULL)
)
})

private[ldbc] trait SyncSyntax[F[_]: Temporal] extends StringContextSyntax[F]:

/**
Expand Down Expand Up @@ -47,9 +85,9 @@ package object dsl:
val params = tuple.toList
Mysql[F](
List.fill(params.size)("?").mkString(","),
(Tuple.fromProduct(v).toList zip params).map {
(Tuple.fromProduct(v).toList zip params).flatMap {
case (value, param) =>
Parameter.Dynamic(value.asInstanceOf[Any])(using param.asInstanceOf[Encoder[Any]])
Parameter.Dynamic.many(param.asInstanceOf[Encoder[Any]].encode(value.asInstanceOf[Any]))
}
)

Expand Down
Loading

0 comments on commit 81a1c62

Please sign in to comment.