diff --git a/docs/src/main/scala/05-Program.scala b/docs/src/main/scala/05-Program.scala index 9b85beab6..42bcc0df6 100644 --- a/docs/src/main/scala/05-Program.scala +++ b/docs/src/main/scala/05-Program.scala @@ -23,10 +23,10 @@ import ldbc.dsl.codec.* enum Status: case Active, InActive - given Encoder[Status] with - override def encode(value: Status): Boolean = value match - case Status.Active => true - case Status.InActive => false + given Encoder[Status] = Encoder[Boolean].contramap { + case Status.Active => true + case Status.InActive => false + } val program1: DBIO[Int] = sql"INSERT INTO user (name, email, status) VALUES (${ "user 1" }, ${ "user@example.com" }, ${ Status.Active })".update diff --git a/module/ldbc-codegen/shared/src/main/scala/ldbc/codegen/TableModelGenerator.scala b/module/ldbc-codegen/shared/src/main/scala/ldbc/codegen/TableModelGenerator.scala index 97fc295d2..a3a79c8d3 100644 --- a/module/ldbc-codegen/shared/src/main/scala/ldbc/codegen/TableModelGenerator.scala +++ b/module/ldbc-codegen/shared/src/main/scala/ldbc/codegen/TableModelGenerator.scala @@ -151,5 +151,6 @@ private[ldbc] object TableModelGenerator: | given ldbc.dsl.codec.Decoder.Elem[$enumName] = new ldbc.dsl.codec.Decoder.Elem[$enumName]: | override def decode(resultSet: ldbc.sql.ResultSet, columnLabel: String): $enumName = $enumName.fromOrdinal(resultSet.getInt(columnLabel)) | override def decode(resultSet: ldbc.sql.ResultSet, index: Int): $enumName = $enumName.fromOrdinal(resultSet.getInt(index)) + | given ldbc.dsl.codec.Encoder[$enumName] = ldbc.dsl.codec.Encoder[Int].contramap(_.ordinal) |""".stripMargin) case _ => None diff --git a/module/ldbc-dsl/src/main/scala/ldbc/dsl/Mysql.scala b/module/ldbc-dsl/src/main/scala/ldbc/dsl/Mysql.scala index ea824cf50..91fa1f2cd 100644 --- a/module/ldbc-dsl/src/main/scala/ldbc/dsl/Mysql.scala +++ b/module/ldbc-dsl/src/main/scala/ldbc/dsl/Mysql.scala @@ -28,7 +28,7 @@ import ldbc.dsl.codec.Decoder * @tparam F * The effect type */ -case class Mysql[F[_]: Temporal](statement: String, params: List[Parameter.Dynamic]) extends SQL: +case class Mysql[F[_]: Temporal](statement: String, params: List[Parameter.Dynamic]) extends SQL, ParamBinder[F]: @targetName("combine") override def ++(sql: SQL): Mysql[F] = @@ -79,9 +79,7 @@ case class Mysql[F[_]: Temporal](statement: String, params: List[Parameter.Dynam connection => for prepareStatement <- connection.prepareStatement(statement) - result <- params.zipWithIndex.traverse { - case (param, index) => param.bind(prepareStatement, index + 1) - } >> prepareStatement.executeUpdate() <* prepareStatement.close() + result <- paramBind(prepareStatement, params) >> prepareStatement.executeUpdate() <* prepareStatement.close() yield result ) @@ -106,9 +104,8 @@ case class Mysql[F[_]: Temporal](statement: String, params: List[Parameter.Dynam connection => for prepareStatement <- connection.prepareStatement(statement, Statement.RETURN_GENERATED_KEYS) - resultSet <- params.zipWithIndex.traverse { - case (param, index) => param.bind(prepareStatement, index + 1) - } >> prepareStatement.executeUpdate() >> prepareStatement.getGeneratedKeys() + resultSet <- paramBind(prepareStatement, params) >> prepareStatement.executeUpdate() >> prepareStatement + .getGeneratedKeys() result <- summon[ResultSetConsumer[F, T]].consume(resultSet) <* prepareStatement.close() yield result ) diff --git a/module/ldbc-dsl/src/main/scala/ldbc/dsl/Parameter.scala b/module/ldbc-dsl/src/main/scala/ldbc/dsl/Parameter.scala index 855685e1c..d115f3c10 100644 --- a/module/ldbc-dsl/src/main/scala/ldbc/dsl/Parameter.scala +++ b/module/ldbc-dsl/src/main/scala/ldbc/dsl/Parameter.scala @@ -6,9 +6,8 @@ package ldbc.dsl -import java.time.* +import cats.syntax.all.* -import ldbc.sql.PreparedStatement import ldbc.dsl.codec.Encoder /** @@ -24,39 +23,22 @@ object Parameter: case class Static(value: String) extends Parameter: override def toString: String = value - trait Dynamic extends Parameter: - - /** - * Methods for setting Scala and Java values to the specified position in PreparedStatement. - * - * @param statement - * An object that represents a precompiled SQL statement. - * @param index - * the parameter value - */ - def bind[F[_]](statement: PreparedStatement[F], index: Int): F[Unit] - + trait Dynamic extends Parameter object Dynamic: - - def apply[A](_value: A)(using encoder: Encoder[A]): Dynamic = - new Dynamic: - override def value: String = _value.toString - override def bind[F[_]](statement: PreparedStatement[F], index: Int): F[Unit] = - encoder.encode(_value) match - case value: Boolean => statement.setBoolean(index, value) - case value: Byte => statement.setByte(index, value) - case value: Short => statement.setShort(index, value) - case value: Int => statement.setInt(index, value) - case value: Long => statement.setLong(index, value) - case value: Float => statement.setFloat(index, value) - case value: Double => statement.setDouble(index, value) - case value: BigDecimal => statement.setBigDecimal(index, value) - case value: String => statement.setString(index, value) - case value: Array[Byte] => statement.setBytes(index, value) - case value: LocalTime => statement.setTime(index, value) - case value: LocalDate => statement.setDate(index, value) - case value: LocalDateTime => statement.setTimestamp(index, value) - case None => statement.setNull(index, ldbc.sql.Types.NULL) - - given [A](using Encoder[A]): Conversion[A, Dynamic] with - override def apply(value: A): Dynamic = Dynamic(value) + case class Success(encoded: Encoder.Supported) extends Dynamic: + override def value: String = encoded.toString + case class Failure(errors: List[String]) extends Dynamic: + override def value: String = errors.mkString(", ") + + def many[A](encoded: Encoder.Encoded): List[Dynamic] = + encoded match + case Encoder.Encoded.Success(list) => list.map(value => Success(value)) + case Encoder.Encoded.Failure(errors) => List(Failure(errors.toList)) + + given [A](using encoder: Encoder[A]): Conversion[A, Dynamic] with + override def apply(value: A): Dynamic = encoder.encode(value) match + case Encoder.Encoded.Success(list) => + list match + case head :: Nil => Dynamic.Success(head) + case _ => Dynamic.Failure(List("Multiple values are not allowed")) + case Encoder.Encoded.Failure(errors) => Dynamic.Failure(errors.toList) diff --git a/module/ldbc-dsl/src/main/scala/ldbc/dsl/Query.scala b/module/ldbc-dsl/src/main/scala/ldbc/dsl/Query.scala index 981441b16..ad9007c69 100644 --- a/module/ldbc-dsl/src/main/scala/ldbc/dsl/Query.scala +++ b/module/ldbc-dsl/src/main/scala/ldbc/dsl/Query.scala @@ -41,7 +41,8 @@ object Query: statement: String, params: List[Parameter.Dynamic], decoder: Decoder[T] - ) extends Query[F, T]: + ) extends Query[F, T], + ParamBinder[F]: given Decoder[T] = decoder @@ -52,10 +53,8 @@ object Query: connection => for prepareStatement <- connection.prepareStatement(statement) - resultSet <- params.zipWithIndex.traverse { - case (param, index) => param.bind(prepareStatement, index + 1) - } >> prepareStatement.executeQuery() - result <- summon[ResultSetConsumer[F, G[T]]].consume(resultSet) <* prepareStatement.close() + resultSet <- paramBind(prepareStatement, params) >> prepareStatement.executeQuery() + result <- summon[ResultSetConsumer[F, G[T]]].consume(resultSet) <* prepareStatement.close() yield result ) @@ -66,9 +65,7 @@ object Query: connection => for prepareStatement <- connection.prepareStatement(statement) - resultSet <- params.zipWithIndex.traverse { - case (param, index) => param.bind(prepareStatement, index + 1) - } >> prepareStatement.executeQuery() - result <- summon[ResultSetConsumer[F, T]].consume(resultSet) <* prepareStatement.close() + resultSet <- paramBind(prepareStatement, params) >> prepareStatement.executeQuery() + result <- summon[ResultSetConsumer[F, T]].consume(resultSet) <* prepareStatement.close() yield result ) diff --git a/module/ldbc-dsl/src/main/scala/ldbc/dsl/codec/Decoder.scala b/module/ldbc-dsl/src/main/scala/ldbc/dsl/codec/Decoder.scala index 11b622b97..625f2ff26 100644 --- a/module/ldbc-dsl/src/main/scala/ldbc/dsl/codec/Decoder.scala +++ b/module/ldbc-dsl/src/main/scala/ldbc/dsl/codec/Decoder.scala @@ -35,6 +35,17 @@ class Decoder[A](f: (resultSet: ResultSet, prefix: Option[String]) => A): */ def decode(resultSet: ResultSet, prefix: Option[String]): A = f(resultSet, prefix) + /** `Decoder` is semigroupal: a pair of decoders make a decoder for a pair. */ + def product[B](fb: Decoder[B]): Decoder[(A, B)] = + new Decoder((resultSet, prefix) => (this.decode(resultSet, None), fb.decode(resultSet, None))) + + /** Lift this `Decoder` into `Option`. */ + def opt: Decoder[Option[A]] = + new Decoder((resultSet, prefix) => + val value = this.decode(resultSet, prefix) + if resultSet.wasNull() then None else Some(value) + ) + object Decoder: given Functor[[A] =>> Decoder[A]] with diff --git a/module/ldbc-dsl/src/main/scala/ldbc/dsl/codec/Encoder.scala b/module/ldbc-dsl/src/main/scala/ldbc/dsl/codec/Encoder.scala index aa33dd3c7..e2992b47d 100644 --- a/module/ldbc-dsl/src/main/scala/ldbc/dsl/codec/Encoder.scala +++ b/module/ldbc-dsl/src/main/scala/ldbc/dsl/codec/Encoder.scala @@ -10,6 +10,8 @@ import java.time.* import scala.compiletime.* +import cats.data.NonEmptyList + /** * Trait for converting Scala types to types that can be handled by PreparedStatement. * @@ -26,7 +28,19 @@ trait Encoder[A]: * @return * Types that can be handled by PreparedStatement */ - def encode(value: A): Encoder.Supported + def encode(value: A): Encoder.Encoded + + /** Contramap inputs from a new type `B`, yielding an `Encoder[B]`. */ + def contramap[B](func: B => A): Encoder[B] = (value: B) => encode(func(value)) + + /** Map outputs to a new type `B`, yielding an `Encoder[B]`. */ + def either[B](func: B => Either[String, A]): Encoder[B] = (value: B) => + func(value) match + case Right(value) => encode(value) + case Left(error) => Encoder.Encoded.failure(error) + + /** `Encoder` is semigroupal: a pair of encoders make a encoder for a pair. */ + def product[B](that: Encoder[B]): Encoder[(A, B)] = (value: (A, B)) => encode(value._1) product that.encode(value._2) object Encoder: @@ -34,59 +48,75 @@ object Encoder: type Supported = Boolean | Byte | Short | Int | Long | Float | Double | BigDecimal | String | Array[Byte] | LocalTime | LocalDate | LocalDateTime | None.type + def apply[A](using encoder: Encoder[A]): Encoder[A] = encoder + given Encoder[Boolean] with - override def encode(value: Boolean): Boolean = value + override def encode(value: Boolean): Encoded = + Encoded.success(List(value)) given Encoder[Byte] with - override def encode(value: Byte): Byte = value + override def encode(value: Byte): Encoded = + Encoded.success(List(value)) given Encoder[Short] with - override def encode(value: Short): Short = value + override def encode(value: Short): Encoded = + Encoded.success(List(value)) given Encoder[Int] with - override def encode(value: Int): Int = value + override def encode(value: Int): Encoded = + Encoded.success(List(value)) given Encoder[Long] with - override def encode(value: Long): Long = value + override def encode(value: Long): Encoded = + Encoded.success(List(value)) given Encoder[Float] with - override def encode(value: Float): Float = value + override def encode(value: Float): Encoded = + Encoded.success(List(value)) given Encoder[Double] with - override def encode(value: Double): Double = value + override def encode(value: Double): Encoded = + Encoded.success(List(value)) given Encoder[BigDecimal] with - override def encode(value: BigDecimal): BigDecimal = value + override def encode(value: BigDecimal): Encoded = + Encoded.success(List(value)) given Encoder[String] with - override def encode(value: String): String = value + override def encode(value: String): Encoded = + Encoded.success(List(value)) given Encoder[Array[Byte]] with - override def encode(value: Array[Byte]): Array[Byte] = value + override def encode(value: Array[Byte]): Encoded = + Encoded.success(List(value)) given Encoder[LocalTime] with - override def encode(value: LocalTime): LocalTime = value + override def encode(value: LocalTime): Encoded = + Encoded.success(List(value)) given Encoder[LocalDate] with - override def encode(value: LocalDate): LocalDate = value + override def encode(value: LocalDate): Encoded = + Encoded.success(List(value)) given Encoder[LocalDateTime] with - override def encode(value: LocalDateTime): LocalDateTime = value + override def encode(value: LocalDateTime): Encoded = + Encoded.success(List(value)) + + given (using encoder: Encoder[String]): Encoder[Year] = encoder.contramap(_.toString) - given Encoder[Year] with - override def encode(value: Year): String = value.toString + given (using encoder: Encoder[String]): Encoder[YearMonth] = encoder.contramap(_.toString) - given Encoder[YearMonth] with - override def encode(value: YearMonth): String = value.toString + given (using encoder: Encoder[String]): Encoder[BigInt] = encoder.contramap(_.toString) given Encoder[None.type] with - override def encode(value: None.type): None.type = value + override def encode(value: None.type): Encoded = + Encoded.success(List(None)) given [A](using encoder: Encoder[A]): Encoder[Option[A]] with - override def encode(value: Option[A]): Encoder.Supported = + override def encode(value: Option[A]): Encoded = value match case Some(value) => encoder.encode(value) - case None => None + case None => Encoded.success(List(None)) type MapToTuple[T] <: Tuple = T match case EmptyTuple => EmptyTuple @@ -94,3 +124,19 @@ object Encoder: case h *: t => Encoder[h] *: MapToTuple[t] inline def fold[T]: MapToTuple[T] = summonAll[MapToTuple[T]] + + sealed trait Encoded: + def product(that: Encoded): Encoded + object Encoded: + case class Success(value: List[Encoder.Supported]) extends Encoded: + override def product(that: Encoded): Encoded = that match + case Success(value) => Success(this.value ::: value) + case Failure(errors) => Failure(errors) + case class Failure(errors: NonEmptyList[String]) extends Encoded: + override def product(that: Encoded): Encoded = that match + case Success(value) => Failure(errors) + case Failure(errors) => Failure(errors ::: this.errors) + + def success(value: List[Encoder.Supported]): Encoded = Success(value) + def failure(error: String, errors: String*): Encoded = + Failure(NonEmptyList(error, errors.toList)) diff --git a/module/ldbc-dsl/src/main/scala/ldbc/dsl/package.scala b/module/ldbc-dsl/src/main/scala/ldbc/dsl/package.scala index 7295b27d2..4acee0e7f 100644 --- a/module/ldbc-dsl/src/main/scala/ldbc/dsl/package.scala +++ b/module/ldbc-dsl/src/main/scala/ldbc/dsl/package.scala @@ -6,6 +6,8 @@ package ldbc +import java.time.* + import scala.deriving.Mirror import cats.{ Foldable, Functor, Reducible } @@ -14,11 +16,47 @@ import cats.syntax.all.* import cats.effect.* +import ldbc.sql.PreparedStatement import ldbc.dsl.syntax.* import ldbc.dsl.codec.Encoder package object dsl: + private[ldbc] trait ParamBinder[F[_]: Temporal]: + protected def paramBind( + prepareStatement: PreparedStatement[F], + params: List[Parameter.Dynamic] + ): F[Unit] = + val encoded = params.foldLeft(Temporal[F].pure(List.empty[Encoder.Supported])) { + case (acc, param) => + for + acc$ <- acc + value <- param match + case Parameter.Dynamic.Success(value) => Temporal[F].pure(value) + case Parameter.Dynamic.Failure(errors) => + Temporal[F].raiseError(new IllegalArgumentException(errors.mkString(", "))) + yield acc$ :+ value + } + encoded.flatMap(_.zipWithIndex.foldLeft(Temporal[F].unit) { + case (acc, (value, index)) => + acc *> (value match + case value: Boolean => prepareStatement.setBoolean(index + 1, value) + case value: Byte => prepareStatement.setByte(index + 1, value) + case value: Short => prepareStatement.setShort(index + 1, value) + case value: Int => prepareStatement.setInt(index + 1, value) + case value: Long => prepareStatement.setLong(index + 1, value) + case value: Float => prepareStatement.setFloat(index + 1, value) + case value: Double => prepareStatement.setDouble(index + 1, value) + case value: BigDecimal => prepareStatement.setBigDecimal(index + 1, value) + case value: String => prepareStatement.setString(index + 1, value) + case value: Array[Byte] => prepareStatement.setBytes(index + 1, value) + case value: LocalDate => prepareStatement.setDate(index + 1, value) + case value: LocalTime => prepareStatement.setTime(index + 1, value) + case value: LocalDateTime => prepareStatement.setTimestamp(index + 1, value) + case None => prepareStatement.setNull(index + 1, ldbc.sql.Types.NULL) + ) + }) + private[ldbc] trait SyncSyntax[F[_]: Temporal] extends StringContextSyntax[F]: /** @@ -47,9 +85,9 @@ package object dsl: val params = tuple.toList Mysql[F]( List.fill(params.size)("?").mkString(","), - (Tuple.fromProduct(v).toList zip params).map { + (Tuple.fromProduct(v).toList zip params).flatMap { case (value, param) => - Parameter.Dynamic(value.asInstanceOf[Any])(using param.asInstanceOf[Encoder[Any]]) + Parameter.Dynamic.many(param.asInstanceOf[Encoder[Any]].encode(value.asInstanceOf[Any])) } ) diff --git a/module/ldbc-query-builder/src/main/scala/ldbc/query/builder/Table.scala b/module/ldbc-query-builder/src/main/scala/ldbc/query/builder/Table.scala index 444dbc337..4d11de303 100644 --- a/module/ldbc-query-builder/src/main/scala/ldbc/query/builder/Table.scala +++ b/module/ldbc-query-builder/src/main/scala/ldbc/query/builder/Table.scala @@ -10,7 +10,7 @@ import scala.language.dynamics import scala.deriving.Mirror import scala.quoted.* -import ldbc.dsl.codec.Decoder +import ldbc.dsl.codec.* import ldbc.statement.{ AbstractTable, Column } import ldbc.query.builder.formatter.Naming import ldbc.query.builder.interpreter.* @@ -66,11 +66,28 @@ object Table: .asInstanceOf[mirror.MirroredElemTypes] ) ) + + val encoder: Encoder[P] = (value: P) => + val list: List[(Any, Column[?])] = Tuple.fromProduct(value).toList.zip(columns) + list + .map { case (value, column) => column.encoder.encode(value.asInstanceOf) } + .foldLeft(Encoder.Encoded.success(List.empty[Encoder.Supported])) { + case (Encoder.Encoded.Success(fs1), Encoder.Encoded.Success(fs2)) => + Encoder.Encoded.success(fs1 ::: fs2) + case (Encoder.Encoded.Failure(e1), Encoder.Encoded.Failure(e2)) => + Encoder.Encoded.failure(e1.head, (e1.tail ++ e2.toList)*) + case (Encoder.Encoded.Failure(e), _) => + Encoder.Encoded.failure(e.head, e.tail*) + case (_, Encoder.Encoded.Failure(e)) => + Encoder.Encoded.failure(e.head, e.tail*) + } + val alias = columns.flatMap(_.alias).mkString(", ") Column.Impl[P]( columns.map(_.name).mkString(", "), if alias.isEmpty then None else Some(alias), decoder, + encoder, Some(columns.length), Some(columns.map(column => s"${ column.name } = ?").mkString(", ")) ) @@ -107,7 +124,7 @@ object Table: '{ $naming.format($name) } } - val decodes = Expr.ofSeq( + val codecs = Expr.ofSeq( symbol.caseFields .map { field => field.tree match @@ -117,7 +134,10 @@ object Table: val decoder = Expr.summon[Decoder.Elem[tpe]].getOrElse { report.errorAndAbort(s"Decoder for type $tpe not found") } - decoder.asExprOf[Decoder.Elem[?]] + val encoder = Expr.summon[Encoder[tpe]].getOrElse { + report.errorAndAbort(s"Encoder for type $tpe not found") + } + '{ ($decoder, $encoder) } case _ => report.errorAndAbort(s"Type $tpt is not a type") } @@ -127,9 +147,10 @@ object Table: val columns = '{ ${ Expr.ofSeq(labels) } - .zip($decodes) + .zip($codecs) .map { - case (label: String, decoder: Decoder.Elem[t]) => Column[t](label, $naming.format($name))(using decoder) + case (label: String, codec: (Decoder.Elem[t], Encoder[?])) => + Column[t](label, $naming.format($name))(using codec._1, codec._2.asInstanceOf[Encoder[t]]) } .toList } @@ -170,7 +191,7 @@ object Table: '{ $naming.format($name) } } - val decodes = Expr.ofSeq( + val codecs = Expr.ofSeq( symbol.caseFields .map { field => field.tree match @@ -180,7 +201,10 @@ object Table: val decoder = Expr.summon[Decoder.Elem[tpe]].getOrElse { report.errorAndAbort(s"Decoder for type $tpe not found") } - decoder.asExprOf[Decoder.Elem[?]] + val encoder = Expr.summon[Encoder[tpe]].getOrElse { + report.errorAndAbort(s"Encoder for type $tpe not found") + } + '{ ($decoder, $encoder) } case _ => report.errorAndAbort(s"Type $tpt is not a type") } @@ -188,9 +212,10 @@ object Table: val columns = '{ ${ Expr.ofSeq(labels) } - .zip($decodes) + .zip($codecs) .map { - case (label: String, decoder: Decoder.Elem[t]) => Column[t](label, $name)(using decoder) + case (label: String, codec: (Decoder.Elem[t], Encoder[?])) => + Column[t](label, $name)(using codec._1, codec._2.asInstanceOf[Encoder[t]]) } .toList } diff --git a/module/ldbc-query-builder/src/main/scala/ldbc/query/builder/TableQuery.scala b/module/ldbc-query-builder/src/main/scala/ldbc/query/builder/TableQuery.scala index 051f17cef..73dcb6ee8 100644 --- a/module/ldbc-query-builder/src/main/scala/ldbc/query/builder/TableQuery.scala +++ b/module/ldbc-query-builder/src/main/scala/ldbc/query/builder/TableQuery.scala @@ -9,7 +9,7 @@ package ldbc.query.builder import scala.deriving.Mirror import ldbc.dsl.Parameter -import ldbc.dsl.codec.Decoder +import ldbc.dsl.codec.* import ldbc.statement.{ TableQuery as AbstractTableQuery, * } private[ldbc] case class TableQueryImpl[A <: SharedTable & AbstractTable[?], B <: Product]( @@ -39,6 +39,7 @@ private[ldbc] case class TableQueryImpl[A <: SharedTable & AbstractTable[?], B < table.columns.map(_.name).mkString(", "), if alias.isEmpty then None else Some(alias), decoder, + column.opt.encoder.asInstanceOf[Encoder[Option[B]]], Some(table.columns.length), Some(table.columns.map(column => s"${ column.name } = ?").mkString(", ")) ) diff --git a/module/ldbc-query-builder/src/main/scala/ldbc/query/builder/syntax/package.scala b/module/ldbc-query-builder/src/main/scala/ldbc/query/builder/syntax/package.scala index fea69638e..eb2a4cc16 100644 --- a/module/ldbc-query-builder/src/main/scala/ldbc/query/builder/syntax/package.scala +++ b/module/ldbc-query-builder/src/main/scala/ldbc/query/builder/syntax/package.scala @@ -23,7 +23,7 @@ import ldbc.statement.syntax.* package object syntax: - private trait SyncSyntax[F[_]: Temporal] extends QuerySyntax[F], CommandSyntax[F], DslSyntax[F]: + private trait SyncSyntax[F[_]: Temporal] extends QuerySyntax[F], CommandSyntax[F], DslSyntax[F], ParamBinder[F]: type TableQuery[T] = ldbc.statement.TableQuery[Table[T], Table.Opt[T]] val TableQuery = ldbc.query.builder.TableQuery @@ -49,9 +49,9 @@ package object syntax: connection => for prepareStatement <- connection.prepareStatement(command.statement) - result <- command.params.zipWithIndex.traverse { - case (param, index) => param.bind[F](prepareStatement, index + 1) - } >> prepareStatement.executeUpdate() <* prepareStatement.close() + result <- + paramBind(prepareStatement, command.params) >> prepareStatement.executeUpdate() <* prepareStatement + .close() yield result ) @@ -64,9 +64,9 @@ package object syntax: connection => for prepareStatement <- connection.prepareStatement(command.statement, Statement.RETURN_GENERATED_KEYS) - resultSet <- command.params.zipWithIndex.traverse { - case (param, index) => param.bind[F](prepareStatement, index + 1) - } >> prepareStatement.executeUpdate() >> prepareStatement.getGeneratedKeys() + resultSet <- + paramBind(prepareStatement, command.params) >> prepareStatement.executeUpdate() >> prepareStatement + .getGeneratedKeys() result <- summon[ResultSetConsumer[F, T]].consume(resultSet) <* prepareStatement.close() yield result ) diff --git a/module/ldbc-schema/src/main/scala/ldbc/schema/ColumnImpl.scala b/module/ldbc-schema/src/main/scala/ldbc/schema/ColumnImpl.scala index f18aa169e..55b23777e 100644 --- a/module/ldbc-schema/src/main/scala/ldbc/schema/ColumnImpl.scala +++ b/module/ldbc-schema/src/main/scala/ldbc/schema/ColumnImpl.scala @@ -6,7 +6,7 @@ package ldbc.schema -import ldbc.dsl.codec.Decoder +import ldbc.dsl.codec.{ Decoder, Encoder } import ldbc.statement.Column import ldbc.schema.attribute.Attribute @@ -14,6 +14,7 @@ private[ldbc] case class ColumnImpl[T]( name: String, alias: Option[String], decoder: Decoder[T], + encoder: Encoder[T], dataType: Option[DataType[T]], attributes: List[Attribute[T]] ) extends Column[T]: diff --git a/module/ldbc-schema/src/main/scala/ldbc/schema/Table.scala b/module/ldbc-schema/src/main/scala/ldbc/schema/Table.scala index 84f2cd5b7..6a2a8dd13 100644 --- a/module/ldbc-schema/src/main/scala/ldbc/schema/Table.scala +++ b/module/ldbc-schema/src/main/scala/ldbc/schema/Table.scala @@ -9,7 +9,7 @@ package ldbc.schema import scala.language.dynamics import scala.deriving.Mirror -import ldbc.dsl.codec.Decoder +import ldbc.dsl.codec.{ Decoder, Encoder } import ldbc.statement.{ AbstractTable, Column } import ldbc.schema.interpreter.* import ldbc.schema.attribute.Attribute @@ -18,19 +18,23 @@ trait Table[T](val $name: String) extends AbstractTable[T]: type Column[A] = ldbc.statement.Column[A] - protected final def column[A](name: String)(using elem: Decoder.Elem[A]): Column[A] = + protected final def column[A](name: String)(using elem: Decoder.Elem[A], encoder: Encoder[A]): Column[A] = val decoder = new Decoder[A]((resultSet, prefix) => elem.decode(resultSet, prefix.getOrElse(s"${ $name }.$name"))) - ColumnImpl[A](name, Some(s"${ $name }.$name"), decoder, None, List.empty) + ColumnImpl[A](name, Some(s"${ $name }.$name"), decoder, encoder, None, List.empty) - protected final def column[A](name: String, dataType: DataType[A])(using elem: Decoder.Elem[A]): Column[A] = + protected final def column[A](name: String, dataType: DataType[A])(using + elem: Decoder.Elem[A], + encoder: Encoder[A] + ): Column[A] = val decoder = new Decoder[A]((resultSet, prefix) => elem.decode(resultSet, prefix.getOrElse(s"${ $name }.$name"))) - ColumnImpl[A](name, Some(s"${ $name }.$name"), decoder, Some(dataType), List.empty) + ColumnImpl[A](name, Some(s"${ $name }.$name"), decoder, encoder, Some(dataType), List.empty) protected final def column[A](name: String, dataType: DataType[A], attributes: Attribute[A]*)(using - elem: Decoder.Elem[A] + elem: Decoder.Elem[A], + encoder: Encoder[A] ): Column[A] = val decoder = new Decoder[A]((resultSet, prefix) => elem.decode(resultSet, prefix.getOrElse(s"${ $name }.$name"))) - ColumnImpl[A](name, Some(s"${ $name }.$name"), decoder, Some(dataType), attributes.toList) + ColumnImpl[A](name, Some(s"${ $name }.$name"), decoder, encoder, Some(dataType), attributes.toList) /** * Methods for setting key information for tables. diff --git a/module/ldbc-schema/src/main/scala/ldbc/schema/TableQuery.scala b/module/ldbc-schema/src/main/scala/ldbc/schema/TableQuery.scala index 5bb333a93..b2efc8984 100644 --- a/module/ldbc-schema/src/main/scala/ldbc/schema/TableQuery.scala +++ b/module/ldbc-schema/src/main/scala/ldbc/schema/TableQuery.scala @@ -23,6 +23,7 @@ private[ldbc] case class TableQueryImpl[A <: AbstractTable[?]]( column.name, column.alias, column.opt.decoder, + column.opt.encoder, Some(column.values), Some(column.updateStatement) ) diff --git a/module/ldbc-schema/src/main/scala/ldbc/schema/syntax/package.scala b/module/ldbc-schema/src/main/scala/ldbc/schema/syntax/package.scala index dd80b737c..b0863fc28 100644 --- a/module/ldbc-schema/src/main/scala/ldbc/schema/syntax/package.scala +++ b/module/ldbc-schema/src/main/scala/ldbc/schema/syntax/package.scala @@ -23,7 +23,7 @@ import ldbc.statement.syntax.* package object syntax: - private trait SyncSyntax[F[_]: Temporal] extends QuerySyntax[F], CommandSyntax[F], DslSyntax[F]: + private trait SyncSyntax[F[_]: Temporal] extends QuerySyntax[F], CommandSyntax[F], DslSyntax[F], ParamBinder[F]: extension [A, B](query: Query[A, B]) @@ -46,9 +46,9 @@ package object syntax: connection => for prepareStatement <- connection.prepareStatement(command.statement) - result <- command.params.zipWithIndex.traverse { - case (param, index) => param.bind[F](prepareStatement, index + 1) - } >> prepareStatement.executeUpdate() <* prepareStatement.close() + result <- + paramBind(prepareStatement, command.params) >> prepareStatement.executeUpdate() <* prepareStatement + .close() yield result ) @@ -61,9 +61,9 @@ package object syntax: connection => for prepareStatement <- connection.prepareStatement(command.statement, Statement.RETURN_GENERATED_KEYS) - resultSet <- command.params.zipWithIndex.traverse { - case (param, index) => param.bind[F](prepareStatement, index + 1) - } >> prepareStatement.executeUpdate() >> prepareStatement.getGeneratedKeys() + resultSet <- + paramBind(prepareStatement, command.params) >> prepareStatement.executeUpdate() >> prepareStatement + .getGeneratedKeys() result <- summon[ResultSetConsumer[F, T]].consume(resultSet) <* prepareStatement.close() yield result ) diff --git a/module/ldbc-schema/src/test/scala/ldbc/schema/ColumnImplTest.scala b/module/ldbc-schema/src/test/scala/ldbc/schema/ColumnImplTest.scala index 180c796ff..a45914ac1 100644 --- a/module/ldbc-schema/src/test/scala/ldbc/schema/ColumnImplTest.scala +++ b/module/ldbc-schema/src/test/scala/ldbc/schema/ColumnImplTest.scala @@ -8,17 +8,18 @@ package ldbc.schema import org.scalatest.flatspec.AnyFlatSpec -import ldbc.dsl.codec.Decoder +import ldbc.dsl.codec.{ Decoder, Encoder } import ldbc.schema.DataType.* import ldbc.schema.attribute.* class ColumnImplTest extends AnyFlatSpec: private def column[A](name: String, dataType: DataType[A], attributes: Attribute[A]*)(using - elem: Decoder.Elem[A] + elem: Decoder.Elem[A], + encoder: Encoder[A] ): Column[A] = val decoder = new Decoder[A]((resultSet, prefix) => elem.decode(resultSet, prefix.getOrElse(name))) - ColumnImpl[A](name, None, decoder, Some(dataType), attributes.toList) + ColumnImpl[A](name, None, decoder, encoder, Some(dataType), attributes.toList) it should "The query string of the Column model generated with only label and DataType matches the specified string." in { assert(column[Long]("id", BIGINT).statement === "`id` BIGINT NOT NULL") diff --git a/module/ldbc-statement/src/main/scala/ldbc/statement/Column.scala b/module/ldbc-statement/src/main/scala/ldbc/statement/Column.scala index 3f709badd..8f0b84a2a 100644 --- a/module/ldbc-statement/src/main/scala/ldbc/statement/Column.scala +++ b/module/ldbc-statement/src/main/scala/ldbc/statement/Column.scala @@ -8,7 +8,7 @@ package ldbc.statement import scala.annotation.targetName -import cats.Applicative +import cats.InvariantSemigroupal import org.typelevel.twiddles.TwiddleSyntax @@ -21,10 +21,11 @@ import ldbc.statement.Expression.* /** * Trait for representing SQL Column * - * @tparam T + * @tparam A * Scala types that match SQL DataType */ -trait Column[T]: +trait Column[A]: + self => /** Column Field Name */ def name: String @@ -33,10 +34,13 @@ trait Column[T]: def alias: Option[String] /** Functions for setting aliases on columns */ - def as(name: String): Column[T] + def as(name: String): Column[A] /** Function to get a value of type T from a ResultSet */ - def decoder: Decoder[T] + def decoder: Decoder[A] + + /** Function to set a value of type T to a PreparedStatement */ + def encoder: Encoder[A] /** Indicator of how many columns are specified */ def values: Int = 1 @@ -53,19 +57,69 @@ trait Column[T]: /** Used in Update statement `Column = VALUES(Column), Column = VALUES(Column)` used in the Duplicate Key Update statement */ def duplicateKeyUpdateStatement: String - def opt: Column[Option[T]] = Column.Opt[T](name, alias, decoder) + def opt: Column[Option[A]] = Column.Opt[A](name, alias, decoder, encoder) def count(using decoder: Decoder.Elem[Int]): Column.Count = Column.Count(name, alias) - def asc: OrderBy.Order[T] = OrderBy.Order.asc(this) - def desc: OrderBy.Order[T] = OrderBy.Order.desc(this) + def asc: OrderBy.Order[A] = OrderBy.Order.asc(this) + def desc: OrderBy.Order[A] = OrderBy.Order.desc(this) private lazy val noBagQuotLabel: String = alias.getOrElse(name) private[ldbc] def list: List[Column[?]] = List(this) - def _equals(value: Extract[T])(using Encoder[Extract[T]]): MatchCondition[T] = - MatchCondition[T](noBagQuotLabel, false, value) + def imap[B](f: A => B)(g: B => A): Column[B] = + new Column[B]: + override def name: String = self.name + override def alias: Option[String] = self.alias + override def as(name: String): Column[B] = this + override def decoder: Decoder[B] = + new Decoder[B]((resultSet: ResultSet, prefix: Option[String]) => f(self.decoder.decode(resultSet, prefix))) + override def encoder: Encoder[B] = + (value: B) => self.encoder.encode(g(value)) + override def updateStatement: String = self.updateStatement + override def duplicateKeyUpdateStatement: String = self.duplicateKeyUpdateStatement + override def values: Int = self.values + override private[ldbc] def list: List[Column[?]] = self.list + + def product[B](fb: Column[B]): Column[(A, B)] = + new Column[(A, B)]: + override def name: String = s"${ self.name }, ${ fb.name }" + override def alias: Option[String] = (self.alias, fb.alias) match + case (Some(a), Some(b)) => Some(s"$a, $b") + case (Some(a), None) => Some(a) + case (None, Some(b)) => Some(b) + case (None, None) => None + override def as(name: String): Column[(A, B)] = this + override def decoder: Decoder[(A, B)] = self.decoder.product(fb.decoder) + override def encoder: Encoder[(A, B)] = self.encoder.product(fb.encoder) + override def updateStatement: String = s"${ self.updateStatement }, ${ fb.updateStatement }" + override def duplicateKeyUpdateStatement: String = + s"${ self.duplicateKeyUpdateStatement }, ${ fb.duplicateKeyUpdateStatement }" + override def values: Int = self.values + fb.values + override def opt: Column[Option[(A, B)]] = + val decoder = new Decoder[Option[(A, B)]]((resultSet: ResultSet, prefix: Option[String]) => + for + v1 <- self.opt.decoder.decode(resultSet, prefix) + v2 <- fb.opt.decoder.decode(resultSet, prefix) + yield (v1, v2) + ) + val encoder: Encoder[Option[(A, B)]] = { + case Some((v1, v2)) => self.opt.encoder.encode(Some(v1)).product(fb.opt.encoder.encode(Some(v2))) + case None => self.opt.encoder.encode(None).product(fb.opt.encoder.encode(None)) + } + Column.Impl[Option[(A, B)]]( + name, + alias, + decoder, + encoder, + Some(values), + Some(updateStatement) + ) + override private[ldbc] def list: List[Column[?]] = self.list ++ fb.list + + def _equals(value: Extract[A])(using Encoder[Extract[A]]): MatchCondition[A] = + MatchCondition[A](noBagQuotLabel, false, value) /** * A function that sets a WHERE condition to check whether the values match in a SELECT statement. @@ -81,10 +135,10 @@ trait Column[T]: * A query to check whether the values match in a Where statement */ @targetName("matchCondition") - def ===(value: Extract[T])(using Encoder[Extract[T]]): MatchCondition[T] = _equals(value) + def ===(value: Extract[A])(using Encoder[Extract[A]]): MatchCondition[A] = _equals(value) - def orMore(value: Extract[T])(using Encoder[Extract[T]]): OrMore[T] = - OrMore[T](noBagQuotLabel, false, value) + def orMore(value: Extract[A])(using Encoder[Extract[A]]): OrMore[A] = + OrMore[A](noBagQuotLabel, false, value) /** * A function that sets a WHERE condition to check whether the values are greater than or equal to in a SELECT statement. @@ -100,10 +154,10 @@ trait Column[T]: * A query to check whether the values are greater than or equal to in a Where statement */ @targetName("_orMore") - def >=(value: Extract[T])(using Encoder[Extract[T]]): OrMore[T] = orMore(value) + def >=(value: Extract[A])(using Encoder[Extract[A]]): OrMore[A] = orMore(value) - def over(value: Extract[T])(using Encoder[Extract[T]]): Over[T] = - Over[T](noBagQuotLabel, false, value) + def over(value: Extract[A])(using Encoder[Extract[A]]): Over[A] = + Over[A](noBagQuotLabel, false, value) /** * A function that sets a WHERE condition to check whether the values are greater than in a SELECT statement. @@ -119,10 +173,10 @@ trait Column[T]: * A query to check whether the values are greater than in a Where statement */ @targetName("_over") - def >(value: Extract[T])(using Encoder[Extract[T]]): Over[T] = over(value) + def >(value: Extract[A])(using Encoder[Extract[A]]): Over[A] = over(value) - def lessThanOrEqual(value: Extract[T])(using Encoder[Extract[T]]): LessThanOrEqualTo[T] = - LessThanOrEqualTo[T](noBagQuotLabel, false, value) + def lessThanOrEqual(value: Extract[A])(using Encoder[Extract[A]]): LessThanOrEqualTo[A] = + LessThanOrEqualTo[A](noBagQuotLabel, false, value) /** * A function that sets a WHERE condition to check whether the values are less than or equal to in a SELECT statement. @@ -138,10 +192,10 @@ trait Column[T]: * A query to check whether the values are less than or equal to in a Where statement */ @targetName("_lessThanOrEqual") - def <=(value: Extract[T])(using Encoder[Extract[T]]): LessThanOrEqualTo[T] = lessThanOrEqual(value) + def <=(value: Extract[A])(using Encoder[Extract[A]]): LessThanOrEqualTo[A] = lessThanOrEqual(value) - def lessThan(value: Extract[T])(using Encoder[Extract[T]]): LessThan[T] = - LessThan[T](noBagQuotLabel, false, value) + def lessThan(value: Extract[A])(using Encoder[Extract[A]]): LessThan[A] = + LessThan[A](noBagQuotLabel, false, value) /** * A function that sets a WHERE condition to check whether the values are less than in a SELECT statement. @@ -157,10 +211,10 @@ trait Column[T]: * A query to check whether the values are less than in a Where statement */ @targetName("_lessThan") - def <(value: Extract[T])(using Encoder[Extract[T]]): LessThan[T] = lessThan(value) + def <(value: Extract[A])(using Encoder[Extract[A]]): LessThan[A] = lessThan(value) - def notEqual(value: Extract[T])(using Encoder[Extract[T]]): NotEqual[T] = - NotEqual[T]("<>", noBagQuotLabel, false, value) + def notEqual(value: Extract[A])(using Encoder[Extract[A]]): NotEqual[A] = + NotEqual[A]("<>", noBagQuotLabel, false, value) /** * A function that sets a WHERE condition to check whether the values are not equal in a SELECT statement. @@ -176,7 +230,7 @@ trait Column[T]: * A query to check whether the values are not equal in a Where statement */ @targetName("_notEqual") - def <>(value: Extract[T])(using Encoder[Extract[T]]): NotEqual[T] = notEqual(value) + def <>(value: Extract[A])(using Encoder[Extract[A]]): NotEqual[A] = notEqual(value) /** * A function that sets a WHERE condition to check whether the values are not equal in a SELECT statement. @@ -192,8 +246,8 @@ trait Column[T]: * A query to check whether the values are not equal in a Where statement */ @targetName("_!equal") - def !==(value: Extract[T])(using Encoder[Extract[T]]): NotEqual[T] = - NotEqual[T]("!=", noBagQuotLabel, false, value) + def !==(value: Extract[A])(using Encoder[Extract[A]]): NotEqual[A] = + NotEqual[A]("!=", noBagQuotLabel, false, value) /** * A function that sets a WHERE condition to check whether the values are equal in a SELECT statement. @@ -213,8 +267,8 @@ trait Column[T]: def IS[A <: "TRUE" | "FALSE" | "UNKNOWN" | "NULL"](value: A): Is[A] = Is[A](noBagQuotLabel, false, value) - def nullSafeEqual(value: Extract[T])(using Encoder[Extract[T]]): NullSafeEqual[T] = - NullSafeEqual[T](noBagQuotLabel, false, value) + def nullSafeEqual(value: Extract[A])(using Encoder[Extract[A]]): NullSafeEqual[A] = + NullSafeEqual[A](noBagQuotLabel, false, value) /** * A function that sets a WHERE condition to check whether the values are equal in a SELECT statement. @@ -230,7 +284,7 @@ trait Column[T]: * A query to check whether the values are equal in a Where statement */ @targetName("_nullSafeEqual") - def <=>(value: Extract[T])(using Encoder[Extract[T]]): NullSafeEqual[T] = nullSafeEqual(value) + def <=>(value: Extract[A])(using Encoder[Extract[A]]): NullSafeEqual[A] = nullSafeEqual(value) /** * A function that sets a WHERE condition to check whether the values are equal in a SELECT statement. @@ -245,8 +299,8 @@ trait Column[T]: * @return * A query to check whether the values are equal in a Where statement */ - def IN(value: Extract[T]*)(using Encoder[Extract[T]]): In[T] = - In[T](noBagQuotLabel, false, value*) + def IN(value: Extract[A]*)(using Encoder[Extract[A]]): In[A] = + In[A](noBagQuotLabel, false, value*) /** * A function that sets a WHERE condition to check whether a value is included in a specified range in a SELECT statement. @@ -263,8 +317,8 @@ trait Column[T]: * @return * A query to check whether the value is included in a specified range in a Where statement */ - def BETWEEN(start: Extract[T], end: Extract[T])(using Encoder[Extract[T]]): Between[T] = - Between[T](noBagQuotLabel, false, start, end) + def BETWEEN(start: Extract[A], end: Extract[A])(using Encoder[Extract[A]]): Between[A] = + Between[A](noBagQuotLabel, false, start, end) /** * A function to set a WHERE condition to check a value with an ambiguous search in a SELECT statement. @@ -279,8 +333,8 @@ trait Column[T]: * @return * A query to check a value with an ambiguous search in a Where statement */ - def LIKE(value: Extract[T])(using Encoder[Extract[T]]): Like[T] = - Like[T](noBagQuotLabel, false, value) + def LIKE(value: Extract[A])(using Encoder[Extract[A]]): Like[A] = + Like[A](noBagQuotLabel, false, value) /** * A function to set a WHERE condition to check a value with an ambiguous search in a SELECT statement. @@ -297,8 +351,8 @@ trait Column[T]: * @return * A query to check a value with an ambiguous search in a Where statement */ - def LIKE_ESCAPE(like: Extract[T], escape: Extract[T])(using Encoder[Extract[T]]): LikeEscape[T] = - LikeEscape[T](noBagQuotLabel, false, like, escape) + def LIKE_ESCAPE(like: Extract[A], escape: Extract[A])(using Encoder[Extract[A]]): LikeEscape[A] = + LikeEscape[A](noBagQuotLabel, false, like, escape) /** * A function to set a WHERE condition to check values in a regular expression in a SELECT statement. @@ -313,11 +367,11 @@ trait Column[T]: * @return * A query to check values in a regular expression in a Where statement */ - def REGEXP(value: Extract[T])(using Encoder[Extract[T]]): Regexp[T] = - Regexp[T](noBagQuotLabel, false, value) + def REGEXP(value: Extract[A])(using Encoder[Extract[A]]): Regexp[A] = + Regexp[A](noBagQuotLabel, false, value) - def leftShift(value: Extract[T])(using Encoder[Extract[T]]): LeftShift[T] = - LeftShift[T](noBagQuotLabel, false, value) + def leftShift(value: Extract[A])(using Encoder[Extract[A]]): LeftShift[A] = + LeftShift[A](noBagQuotLabel, false, value) /** * A function to set a WHERE condition to check whether the values are shifted to the left in a SELECT statement. @@ -333,10 +387,10 @@ trait Column[T]: * A query to check whether the values are shifted to the left in a Where statement */ @targetName("_leftShift") - def <<(value: Extract[T])(using Encoder[Extract[T]]): LeftShift[T] = leftShift(value) + def <<(value: Extract[A])(using Encoder[Extract[A]]): LeftShift[A] = leftShift(value) - def rightShift(value: Extract[T])(using Encoder[Extract[T]]): RightShift[T] = - RightShift[T](noBagQuotLabel, false, value) + def rightShift(value: Extract[A])(using Encoder[Extract[A]]): RightShift[A] = + RightShift[A](noBagQuotLabel, false, value) /** * A function to set a WHERE condition to check whether the values are shifted to the right in a SELECT statement. @@ -352,7 +406,7 @@ trait Column[T]: * A query to check whether the values are shifted to the right in a Where statement */ @targetName("_rightShift") - def >>(value: Extract[T])(using Encoder[Extract[T]]): RightShift[T] = rightShift(value) + def >>(value: Extract[A])(using Encoder[Extract[A]]): RightShift[A] = rightShift(value) /** * A function to set a WHERE condition to check whether the values are added in a SELECT statement. @@ -369,8 +423,8 @@ trait Column[T]: * @return * A query to check whether the values are added in a Where statement */ - def DIV(cond: Extract[T], result: Extract[T])(using Encoder[Extract[T]]): Div[T] = - Div[T](noBagQuotLabel, false, cond, result) + def DIV(cond: Extract[A], result: Extract[A])(using Encoder[Extract[A]]): Div[A] = + Div[A](noBagQuotLabel, false, cond, result) /** * A function to set the WHERE condition for modulo operations in a SELECT statement. @@ -387,11 +441,11 @@ trait Column[T]: * @return * A query to check the modulo operation in a Where statement */ - def MOD(cond: Extract[T], result: Extract[T])(using Encoder[Extract[T]]): Mod[T] = - Mod[T]("MOD", noBagQuotLabel, false, cond, result) + def MOD(cond: Extract[A], result: Extract[A])(using Encoder[Extract[A]]): Mod[A] = + Mod[A]("MOD", noBagQuotLabel, false, cond, result) - def mod(cond: Extract[T], result: Extract[T])(using Encoder[Extract[T]]): Mod[T] = - Mod[T]("%", noBagQuotLabel, false, cond, result) + def mod(cond: Extract[A], result: Extract[A])(using Encoder[Extract[A]]): Mod[A] = + Mod[A]("%", noBagQuotLabel, false, cond, result) /** * A function to set the WHERE condition for modulo operations in a SELECT statement. @@ -409,10 +463,10 @@ trait Column[T]: * A query to check the modulo operation in a Where statement */ @targetName("_mod") - def %(cond: Extract[T], result: Extract[T])(using Encoder[Extract[T]]): Mod[T] = mod(cond, result) + def %(cond: Extract[A], result: Extract[A])(using Encoder[Extract[A]]): Mod[A] = mod(cond, result) - def bitXOR(value: Extract[T])(using Encoder[Extract[T]]): BitXOR[T] = - BitXOR[T](noBagQuotLabel, false, value) + def bitXOR(value: Extract[A])(using Encoder[Extract[A]]): BitXOR[A] = + BitXOR[A](noBagQuotLabel, false, value) /** * A function to set the WHERE condition for bitwise XOR operations in a SELECT statement. @@ -428,10 +482,10 @@ trait Column[T]: * A query to check the bitwise XOR operation in a Where statement */ @targetName("_bitXOR") - def ^(value: Extract[T])(using Encoder[Extract[T]]): BitXOR[T] = bitXOR(value) + def ^(value: Extract[A])(using Encoder[Extract[A]]): BitXOR[A] = bitXOR(value) - def bitFlip(value: Extract[T])(using Encoder[Extract[T]]): BitFlip[T] = - BitFlip[T](noBagQuotLabel, false, value) + def bitFlip(value: Extract[A])(using Encoder[Extract[A]]): BitFlip[A] = + BitFlip[A](noBagQuotLabel, false, value) /** * A function to set the WHERE condition for bitwise NOT operations in a SELECT statement. @@ -447,9 +501,10 @@ trait Column[T]: * A query to check the bitwise NOT operation in a Where statement */ @targetName("_bitFlip") - def ~(value: Extract[T])(using Encoder[Extract[T]]): BitFlip[T] = bitFlip(value) + def ~(value: Extract[A])(using Encoder[Extract[A]]): BitFlip[A] = bitFlip(value) - def combine(other: Column[T])(using Decoder.Elem[T]): Column.MultiColumn[T] = Column.MultiColumn[T]("+", this, other) + def combine(other: Column[A])(using Decoder.Elem[A], Encoder[A]): Column.MultiColumn[A] = + Column.MultiColumn[A]("+", this, other) /** * A function to combine columns in a SELECT statement. @@ -465,9 +520,10 @@ trait Column[T]: * A query to combine columns in a SELECT statement */ @targetName("_combine") - def ++(other: Column[T])(using Decoder.Elem[T]): Column.MultiColumn[T] = combine(other) + def ++(other: Column[A])(using Decoder.Elem[A], Encoder[A]): Column.MultiColumn[A] = combine(other) - def deduct(other: Column[T])(using Decoder.Elem[T]): Column.MultiColumn[T] = Column.MultiColumn[T]("-", this, other) + def deduct(other: Column[A])(using Decoder.Elem[A], Encoder[A]): Column.MultiColumn[A] = + Column.MultiColumn[A]("-", this, other) /** * A function to subtract columns in a SELECT statement. @@ -483,10 +539,10 @@ trait Column[T]: * A query to subtract columns in a SELECT statement */ @targetName("_deduct") - def --(other: Column[T])(using Decoder.Elem[T]): Column.MultiColumn[T] = deduct(other) + def --(other: Column[A])(using Decoder.Elem[A], Encoder[A]): Column.MultiColumn[A] = deduct(other) - def multiply(other: Column[T])(using Decoder.Elem[T]): Column.MultiColumn[T] = - Column.MultiColumn[T]("*", this, other) + def multiply(other: Column[A])(using Decoder.Elem[A], Encoder[A]): Column.MultiColumn[A] = + Column.MultiColumn[A]("*", this, other) /** * A function to multiply columns in a SELECT statement. @@ -502,9 +558,10 @@ trait Column[T]: * A query to multiply columns in a SELECT statement */ @targetName("_multiply") - def *(other: Column[T])(using Decoder.Elem[T]): Column.MultiColumn[T] = multiply(other) + def *(other: Column[A])(using Decoder.Elem[A], Encoder[A]): Column.MultiColumn[A] = multiply(other) - def smash(other: Column[T])(using Decoder.Elem[T]): Column.MultiColumn[T] = Column.MultiColumn[T]("/", this, other) + def smash(other: Column[A])(using Decoder.Elem[A], Encoder[A]): Column.MultiColumn[A] = + Column.MultiColumn[A]("/", this, other) /** * A function to divide columns in a SELECT statement. @@ -520,11 +577,11 @@ trait Column[T]: * A query to divide columns in a SELECT statement */ @targetName("_smash") - def /(other: Column[T])(using Decoder.Elem[T]): Column.MultiColumn[T] = smash(other) + def /(other: Column[A])(using Decoder.Elem[A], Encoder[A]): Column.MultiColumn[A] = smash(other) /** List of sub query methods */ - def _equals(value: SQL): SubQuery[T] = - SubQuery[T]("=", noBagQuotLabel, value) + def _equals(value: SQL): SubQuery[A] = + SubQuery[A]("=", noBagQuotLabel, value) /** * A function to perform a comparison with the column of interest using a subquery. @@ -541,40 +598,40 @@ trait Column[T]: * A query to compare with the column of interest using a subquery */ @targetName("subQueryEquals") - def ===(value: SQL): SubQuery[T] = _equals(value) + def ===(value: SQL): SubQuery[A] = _equals(value) - def orMore(value: SQL): SubQuery[T] = - SubQuery[T](">=", noBagQuotLabel, value) + def orMore(value: SQL): SubQuery[A] = + SubQuery[A](">=", noBagQuotLabel, value) @targetName("subQueryOrMore") - def >=(value: SQL): SubQuery[T] = orMore(value) + def >=(value: SQL): SubQuery[A] = orMore(value) - def over(value: SQL): SubQuery[T] = - SubQuery[T](">", noBagQuotLabel, value) + def over(value: SQL): SubQuery[A] = + SubQuery[A](">", noBagQuotLabel, value) @targetName("subQueryOver") - def >(value: SQL): SubQuery[T] = over(value) + def >(value: SQL): SubQuery[A] = over(value) - def lessThanOrEqual(value: SQL): SubQuery[T] = - SubQuery[T]("<=", noBagQuotLabel, value) + def lessThanOrEqual(value: SQL): SubQuery[A] = + SubQuery[A]("<=", noBagQuotLabel, value) @targetName("subQueryLessThanOrEqual") - def <=(value: SQL): SubQuery[T] = lessThanOrEqual(value) + def <=(value: SQL): SubQuery[A] = lessThanOrEqual(value) - def lessThan(value: SQL): SubQuery[T] = - SubQuery[T]("<", noBagQuotLabel, value) + def lessThan(value: SQL): SubQuery[A] = + SubQuery[A]("<", noBagQuotLabel, value) @targetName("subQueryLessThan") - def <(value: SQL): SubQuery[T] = lessThan(value) + def <(value: SQL): SubQuery[A] = lessThan(value) - def notEqual(value: SQL): SubQuery[T] = - SubQuery[T]("<>", noBagQuotLabel, value) + def notEqual(value: SQL): SubQuery[A] = + SubQuery[A]("<>", noBagQuotLabel, value) @targetName("subQueryNotEqual") - def <>(value: SQL): SubQuery[T] = notEqual(value) + def <>(value: SQL): SubQuery[A] = notEqual(value) - def IN(value: SQL): SubQuery[T] = - SubQuery[T]("IN", noBagQuotLabel, value) + def IN(value: SQL): SubQuery[A] = + SubQuery[A]("IN", noBagQuotLabel, value) /** List of join query methods */ def _equals(other: Column[?]): Expression = JoinQuery("=", this, other) @@ -614,133 +671,113 @@ trait Column[T]: object Column extends TwiddleSyntax[Column]: - type Extract[T] <: Tuple = T match + type Extract[A] <: Tuple = A match case Column[t] => t *: EmptyTuple case Column[t] *: EmptyTuple => t *: EmptyTuple case Column[t] *: ts => t *: Extract[ts] - given Applicative[Column] with - override def pure[A](x: A): Column[A] = Pure(x) - override def ap[A, B](ff: Column[A => B])(fa: Column[A]): Column[B] = - new Column[B]: - override def name: String = if ff.name.isEmpty then fa.name else s"${ ff.name }, ${ fa.name }" - override def alias: Option[String] = - (ff.alias, fa.alias) match - case (Some(ff), Some(fa)) => Some(s"$ff, $fa") - case (Some(ff), None) => Some(s"$ff") - case (None, Some(fa)) => Some(s"$fa") - case (None, None) => None - override def as(name: String): Column[B] = this - override def decoder: Decoder[B] = new Decoder[B]((resultSet: ResultSet, prefix: Option[String]) => - ff.decoder.decode(resultSet, prefix)(fa.decoder.decode(resultSet, prefix)) - ) - override def updateStatement: String = - if ff.name.isEmpty then fa.updateStatement else s"${ ff.updateStatement }, ${ fa.updateStatement }" - override def duplicateKeyUpdateStatement: String = - if ff.name.isEmpty then fa.duplicateKeyUpdateStatement - else s"${ ff.duplicateKeyUpdateStatement }, ${ fa.duplicateKeyUpdateStatement }" - override def values: Int = ff.values + fa.values - override def opt: Column[Option[B]] = - val decoder = new Decoder[Option[B]]((resultSet: ResultSet, prefix: Option[String]) => - for - v1 <- ff.opt.decoder.decode(resultSet, prefix) - v2 <- fa.opt.decoder.decode(resultSet, prefix) - yield v1(v2) - ) - Impl[Option[B]](name, alias, decoder, Some(values), Some(updateStatement)) - override private[ldbc] def list: List[Column[?]] = - if ff.name.isEmpty then fa.list else ff.list ++ fa.list - - case class Pure[T](value: T) extends Column[T]: + given InvariantSemigroupal[Column] with + override def imap[A, B](fa: Column[A])(f: A => B)(g: B => A): Column[B] = fa.imap(f)(g) + override def product[A, B](fa: Column[A], fb: Column[B]): Column[(A, B)] = fa product fb + + case class Pure[A](value: A) extends Column[A]: override def name: String = "" override def alias: Option[String] = None - override def as(name: String): Column[T] = this - override def decoder: Decoder[T] = - new Decoder[T]((resultSet: ResultSet, prefix: Option[String]) => value) + override def as(name: String): Column[A] = this + override def decoder: Decoder[A] = + new Decoder[A]((resultSet: ResultSet, prefix: Option[String]) => value) + override def encoder: Encoder[A] = (value: A) => Encoder.Encoded.success(List.empty) override def insertStatement: String = "" override def updateStatement: String = "" override def duplicateKeyUpdateStatement: String = "" override def values: Int = 0 override private[ldbc] def list: List[Column[?]] = List.empty - def apply[T](name: String)(using elem: Decoder.Elem[T]): Column[T] = - Impl[T](name) + def apply[A](name: String)(using elem: Decoder.Elem[A], encoder: Encoder[A]): Column[A] = + Impl[A](name) - def apply[T](name: String, alias: String)(using elem: Decoder.Elem[T]): Column[T] = - Impl[T](name, s"$alias.$name") + def apply[A](name: String, alias: String)(using elem: Decoder.Elem[A], encoder: Encoder[A]): Column[A] = + Impl[A](name, s"$alias.$name") - private[ldbc] case class Impl[T]( + private[ldbc] case class Impl[A]( name: String, alias: Option[String], - decoder: Decoder[T], + decoder: Decoder[A], + encoder: Encoder[A], length: Option[Int] = None, update: Option[String] = None - ) extends Column[T]: - override def as(name: String): Column[T] = + ) extends Column[A]: + override def as(name: String): Column[A] = this.copy( alias = Some(name), decoder = - new Decoder[T]((resultSet: ResultSet, prefix: Option[String]) => decoder.decode(resultSet, Some(name))) + new Decoder[A]((resultSet: ResultSet, prefix: Option[String]) => decoder.decode(resultSet, Some(name))) ) override def values: Int = length.getOrElse(1) override def updateStatement: String = update.getOrElse(s"$name = ?") override def duplicateKeyUpdateStatement: String = s"$name = VALUES(${ alias.getOrElse(name) })" object Impl: - def apply[T](name: String)(using elem: Decoder.Elem[T]): Column[T] = - val decoder = new Decoder[T]((resultSet: ResultSet, prefix: Option[String]) => + def apply[A](name: String)(using elem: Decoder.Elem[A], encoder: Encoder[A]): Column[A] = + val decoder = new Decoder[A]((resultSet: ResultSet, prefix: Option[String]) => val column = prefix.map(_ + ".").getOrElse("") + name elem.decode(resultSet, column) ) - Impl[T](name, None, decoder) + Impl[A](name, None, decoder, encoder) - def apply[T](name: String, alias: String)(using elem: Decoder.Elem[T]): Column[T] = - val decoder = new Decoder[T]((resultSet: ResultSet, prefix: Option[String]) => + def apply[A](name: String, alias: String)(using elem: Decoder.Elem[A], encoder: Encoder[A]): Column[A] = + val decoder = new Decoder[A]((resultSet: ResultSet, prefix: Option[String]) => elem.decode(resultSet, prefix.getOrElse(alias)) ) - Impl[T](name, Some(alias), decoder) + Impl[A](name, Some(alias), decoder, encoder) - private[ldbc] case class Opt[T]( + private[ldbc] case class Opt[A]( name: String, alias: Option[String], - _decoder: Decoder[T] - ) extends Column[Option[T]]: - override def as(name: String): Column[Option[T]] = Opt[T](this.name, Some(s"$name.${ this.name }"), _decoder) - override def decoder: Decoder[Option[T]] = - new Decoder[Option[T]]((resultSet: ResultSet, prefix: Option[String]) => - Option(_decoder.decode(resultSet, prefix.orElse(alias))) - ) + _decoder: Decoder[A], + _encoder: Encoder[A] + ) extends Column[Option[A]]: + override def as(name: String): Column[Option[A]] = this.copy(alias = Some(s"$name.${ this.name }")) + override def decoder: Decoder[Option[A]] = _decoder.opt + override def encoder: Encoder[Option[A]] = { + case Some(v) => _encoder.encode(v) + case None => Encoder.Encoded.success(List(None)) + } override def updateStatement: String = s"$name = ?" override def duplicateKeyUpdateStatement: String = s"$name = VALUES(${ alias.getOrElse(name) })" - private[ldbc] case class MultiColumn[T]( + private[ldbc] case class MultiColumn[A]( flag: String, - left: Column[T], - right: Column[T] - )(using elem: Decoder.Elem[T]) - extends Column[T]: + left: Column[A], + right: Column[A] + )(using elem: Decoder.Elem[A], _encoder: Encoder[A]) + extends Column[A]: override def name: String = s"${ left.noBagQuotLabel } $flag ${ right.noBagQuotLabel }" override def alias: Option[String] = Some( s"${ left.alias.getOrElse(left.name) } $flag ${ right.alias.getOrElse(right.name) }" ) - override def as(name: String): Column[T] = this - override def decoder: Decoder[T] = - new Decoder[T]((resultSet: ResultSet, prefix: Option[String]) => + override def as(name: String): Column[A] = this + override def decoder: Decoder[A] = + new Decoder[A]((resultSet: ResultSet, prefix: Option[String]) => elem.decode(resultSet, prefix.map(_ + ".").getOrElse("") + name) ) - override def insertStatement: String = "" - override def updateStatement: String = "" - override def duplicateKeyUpdateStatement: String = "" - - private[ldbc] case class Count(_name: String, _alias: Option[String])(using elem: Decoder.Elem[Int]) - extends Column[Int]: + override def encoder: Encoder[A] = _encoder + override def insertStatement: String = "" + override def updateStatement: String = "" + override def duplicateKeyUpdateStatement: String = "" + + private[ldbc] case class Count(_name: String, _alias: Option[String])(using + elem: Decoder.Elem[Int], + _encoder: Encoder[Int] + ) extends Column[Int]: override def name: String = s"COUNT($_name)" override def alias: Option[String] = _alias.map(a => s"COUNT($a)") override def as(name: String): Column[Int] = this.copy(s"$name.${ _name }") override def decoder: Decoder[Int] = new Decoder[Int]((resultSet: ResultSet, prefix: Option[String]) => elem.decode(resultSet, alias.getOrElse(name)) ) - override def toString: String = name - override def insertStatement: String = "" - override def updateStatement: String = "" - override def duplicateKeyUpdateStatement: String = "" + override def encoder: Encoder[Int] = _encoder + override def toString: String = name + override def insertStatement: String = "" + override def updateStatement: String = "" + override def duplicateKeyUpdateStatement: String = "" diff --git a/module/ldbc-statement/src/main/scala/ldbc/statement/Expression.scala b/module/ldbc-statement/src/main/scala/ldbc/statement/Expression.scala index 1c8d10ddb..a90baa9e8 100644 --- a/module/ldbc-statement/src/main/scala/ldbc/statement/Expression.scala +++ b/module/ldbc-statement/src/main/scala/ldbc/statement/Expression.scala @@ -158,10 +158,10 @@ object Expression: /** comparison operator */ private[ldbc] case class MatchCondition[T](column: String, isNot: Boolean, value: Extract[T])(using - Encoder[Extract[T]] + encoder: Encoder[Extract[T]] ) extends SingleValue[T]: override def flag: String = "=" - override def parameter: List[Parameter.Dynamic] = List(Parameter.Dynamic(value)) + override def parameter: List[Parameter.Dynamic] = Parameter.Dynamic.many(encoder.encode(value)) override def statement: String = val not = if isNot then "NOT " else "" s"$not$column $flag ?" @@ -169,11 +169,11 @@ object Expression: def NOT: MatchCondition[T] = MatchCondition[T](this.column, true, this.value) private[ldbc] case class OrMore[T](column: String, isNot: Boolean, value: Extract[T])(using - Encoder[Extract[T]] + encoder: Encoder[Extract[T]] ) extends SingleValue[T]: override def flag: String = ">=" - override def parameter: List[Parameter.Dynamic] = List(Parameter.Dynamic(value)) + override def parameter: List[Parameter.Dynamic] = Parameter.Dynamic.many(encoder.encode(value)) override def statement: String = val not = if isNot then "NOT " else "" @@ -182,11 +182,11 @@ object Expression: def NOT: OrMore[T] = OrMore[T](this.column, true, this.value) private[ldbc] case class Over[T](column: String, isNot: Boolean, value: Extract[T])(using - Encoder[Extract[T]] + encoder: Encoder[Extract[T]] ) extends SingleValue[T]: override def flag: String = ">" - override def parameter: List[Parameter.Dynamic] = List(Parameter.Dynamic(value)) + override def parameter: List[Parameter.Dynamic] = Parameter.Dynamic.many(encoder.encode(value)) override def statement: String = val not = if isNot then "NOT " else "" @@ -195,11 +195,11 @@ object Expression: def NOT: Over[T] = Over[T](this.column, true, this.value) private[ldbc] case class LessThanOrEqualTo[T](column: String, isNot: Boolean, value: Extract[T])(using - Encoder[Extract[T]] + encoder: Encoder[Extract[T]] ) extends SingleValue[T]: override def flag: String = "<=" - override def parameter: List[Parameter.Dynamic] = List(Parameter.Dynamic(value)) + override def parameter: List[Parameter.Dynamic] = Parameter.Dynamic.many(encoder.encode(value)) override def statement: String = val not = if isNot then "NOT " else "" @@ -208,11 +208,11 @@ object Expression: def NOT: LessThanOrEqualTo[T] = LessThanOrEqualTo[T](this.column, true, this.value) private[ldbc] case class LessThan[T](column: String, isNot: Boolean, value: Extract[T])(using - Encoder[Extract[T]] + encoder: Encoder[Extract[T]] ) extends SingleValue[T]: override def flag: String = "<" - override def parameter: List[Parameter.Dynamic] = List(Parameter.Dynamic(value)) + override def parameter: List[Parameter.Dynamic] = Parameter.Dynamic.many(encoder.encode(value)) override def statement: String = val not = if isNot then "NOT " else "" @@ -221,9 +221,9 @@ object Expression: def NOT: LessThan[T] = LessThan[T](this.column, true, this.value) private[ldbc] case class NotEqual[T](flag: String, column: String, isNot: Boolean, value: Extract[T])(using - Encoder[Extract[T]] + encoder: Encoder[Extract[T]] ) extends SingleValue[T]: - override def parameter: List[Parameter.Dynamic] = List(Parameter.Dynamic(value)) + override def parameter: List[Parameter.Dynamic] = Parameter.Dynamic.many(encoder.encode(value)) override def statement: String = val not = if isNot then "NOT " else "" @@ -247,11 +247,11 @@ object Expression: def NOT: Is[T] = Is[T](this.column, true, this.value) private[ldbc] case class NullSafeEqual[T](column: String, isNot: Boolean, value: Extract[T])(using - Encoder[Extract[T]] + encoder: Encoder[Extract[T]] ) extends SingleValue[T]: override def flag: String = "<=>" - override def parameter: List[Parameter.Dynamic] = List(Parameter.Dynamic(value)) + override def parameter: List[Parameter.Dynamic] = Parameter.Dynamic.many(encoder.encode(value)) override def statement: String = val not = if isNot then "NOT " else "" @@ -260,11 +260,12 @@ object Expression: def NOT: NullSafeEqual[T] = NullSafeEqual[T](this.column, true, this.value) private[ldbc] case class In[T](column: String, isNot: Boolean, values: Extract[T]*)(using - Encoder[Extract[T]] + encoder: Encoder[Extract[T]] ) extends MultiValue[T]: override def flag: String = "IN" - override def parameter: List[Parameter.Dynamic] = values.map(Parameter.Dynamic(_)).toList + override def parameter: List[Parameter.Dynamic] = + values.flatMap(value => Parameter.Dynamic.many(encoder.encode(value))).toList override def statement: String = val not = if isNot then "NOT " else "" @@ -273,11 +274,12 @@ object Expression: def NOT: In[T] = In[T](this.column, true, this.values*) private[ldbc] case class Between[T](column: String, isNot: Boolean, values: Extract[T]*)(using - Encoder[Extract[T]] + encoder: Encoder[Extract[T]] ) extends MultiValue[T]: override def flag: String = "BETWEEN" - override def parameter: List[Parameter.Dynamic] = values.map(Parameter.Dynamic(_)).toList + override def parameter: List[Parameter.Dynamic] = + values.flatMap(value => Parameter.Dynamic.many(encoder.encode(value))).toList override def statement: String = val not = if isNot then "NOT " else "" @@ -286,11 +288,11 @@ object Expression: def NOT: Between[T] = Between[T](this.column, true, this.values*) private[ldbc] case class Like[T](column: String, isNot: Boolean, value: Extract[T])(using - Encoder[Extract[T]] + encoder: Encoder[Extract[T]] ) extends SingleValue[T]: override def flag: String = "LIKE" - override def parameter: List[Parameter.Dynamic] = List(Parameter.Dynamic(value)) + override def parameter: List[Parameter.Dynamic] = Parameter.Dynamic.many(encoder.encode(value)) override def statement: String = val not = if isNot then "NOT " else "" @@ -299,11 +301,12 @@ object Expression: def NOT: Like[T] = Like[T](this.column, true, this.value) private[ldbc] case class LikeEscape[T](column: String, isNot: Boolean, values: Extract[T]*)(using - Encoder[Extract[T]] + encoder: Encoder[Extract[T]] ) extends MultiValue[T]: override def flag: String = "LIKE" - override def parameter: List[Parameter.Dynamic] = values.map(Parameter.Dynamic(_)).toList + override def parameter: List[Parameter.Dynamic] = + values.flatMap(value => Parameter.Dynamic.many(encoder.encode(value))).toList override def statement: String = val not = if isNot then "NOT " else "" @@ -312,11 +315,11 @@ object Expression: def NOT: LikeEscape[T] = LikeEscape[T](this.column, true, this.values*) private[ldbc] case class Regexp[T](column: String, isNot: Boolean, value: Extract[T])(using - Encoder[Extract[T]] + encoder: Encoder[Extract[T]] ) extends SingleValue[T]: override def flag: String = "REGEXP" - override def parameter: List[Parameter.Dynamic] = List(Parameter.Dynamic(value)) + override def parameter: List[Parameter.Dynamic] = Parameter.Dynamic.many(encoder.encode(value)) override def statement: String = val not = if isNot then "NOT " else "" @@ -325,11 +328,11 @@ object Expression: def NOT: Regexp[T] = Regexp[T](this.column, true, this.value) private[ldbc] case class LeftShift[T](column: String, isNot: Boolean, value: Extract[T])(using - Encoder[Extract[T]] + encoder: Encoder[Extract[T]] ) extends SingleValue[T]: override def flag: String = "<<" - override def parameter: List[Parameter.Dynamic] = List(Parameter.Dynamic(value)) + override def parameter: List[Parameter.Dynamic] = Parameter.Dynamic.many(encoder.encode(value)) override def statement: String = val not = if isNot then "NOT " else "" @@ -338,11 +341,11 @@ object Expression: def NOT: LeftShift[T] = LeftShift[T](this.column, true, this.value) private[ldbc] case class RightShift[T](column: String, isNot: Boolean, value: Extract[T])(using - Encoder[Extract[T]] + encoder: Encoder[Extract[T]] ) extends SingleValue[T]: override def flag: String = ">>" - override def parameter: List[Parameter.Dynamic] = List(Parameter.Dynamic(value)) + override def parameter: List[Parameter.Dynamic] = Parameter.Dynamic.many(encoder.encode(value)) override def statement: String = val not = if isNot then "NOT " else "" @@ -351,11 +354,12 @@ object Expression: def NOT: RightShift[T] = RightShift[T](this.column, true, this.value) private[ldbc] case class Div[T](column: String, isNot: Boolean, values: Extract[T]*)(using - Encoder[Extract[T]] + encoder: Encoder[Extract[T]] ) extends MultiValue[T]: override def flag: String = "DIV" - override def parameter: List[Parameter.Dynamic] = values.map(Parameter.Dynamic(_)).toList + override def parameter: List[Parameter.Dynamic] = + values.flatMap(value => Parameter.Dynamic.many(encoder.encode(value))).toList override def statement: String = val not = if isNot then "NOT " else "" @@ -364,9 +368,10 @@ object Expression: def NOT: Div[T] = Div[T](this.column, true, this.values*) private[ldbc] case class Mod[T](flag: String, column: String, isNot: Boolean, values: Extract[T]*)(using - Encoder[Extract[T]] + encoder: Encoder[Extract[T]] ) extends MultiValue[T]: - override def parameter: List[Parameter.Dynamic] = values.map(Parameter.Dynamic(_)).toList + override def parameter: List[Parameter.Dynamic] = + values.flatMap(value => Parameter.Dynamic.many(encoder.encode(value))).toList override def statement: String = val not = if isNot then "NOT " else "" @@ -375,11 +380,11 @@ object Expression: def NOT: Mod[T] = Mod[T](this.flag, this.column, true, this.values*) private[ldbc] case class BitXOR[T](column: String, isNot: Boolean, value: Extract[T])(using - Encoder[Extract[T]] + encoder: Encoder[Extract[T]] ) extends SingleValue[T]: override def flag: String = "^" - override def parameter: List[Parameter.Dynamic] = List(Parameter.Dynamic(value)) + override def parameter: List[Parameter.Dynamic] = Parameter.Dynamic.many(encoder.encode(value)) override def statement: String = val not = if isNot then "NOT " else "" @@ -388,11 +393,11 @@ object Expression: def NOT: BitXOR[T] = BitXOR[T](this.column, true, this.value) private[ldbc] case class BitFlip[T](column: String, isNot: Boolean, value: Extract[T])(using - Encoder[Extract[T]] + encoder: Encoder[Extract[T]] ) extends SingleValue[T]: override def flag: String = "~" - override def parameter: List[Parameter.Dynamic] = List(Parameter.Dynamic(value)) + override def parameter: List[Parameter.Dynamic] = Parameter.Dynamic.many(encoder.encode(value)) override def statement: String = val not = if isNot then "NOT " else "" diff --git a/module/ldbc-statement/src/main/scala/ldbc/statement/Insert.scala b/module/ldbc-statement/src/main/scala/ldbc/statement/Insert.scala index 3be384f72..53253084c 100644 --- a/module/ldbc-statement/src/main/scala/ldbc/statement/Insert.scala +++ b/module/ldbc-statement/src/main/scala/ldbc/statement/Insert.scala @@ -9,8 +9,6 @@ package ldbc.statement import scala.annotation.targetName import ldbc.dsl.{ Parameter, SQL } -import ldbc.dsl.codec.Encoder -import ldbc.statement.interpreter.ToTuple /** * Trait for building Statements to be added. @@ -32,7 +30,7 @@ sealed trait Insert[A] extends Command: * .onDuplicateKeyUpdate(_.name) * }}} */ - def onDuplicateKeyUpdate[B](columns: A => Column[B]): Insert.DuplicateKeyUpdate[A] + def onDuplicateKeyUpdate[B](func: A => Column[B]): Insert.DuplicateKeyUpdate[A] /** * Methods for constructing INSERT ... ON DUPLICATE KEY UPDATE statements. @@ -43,7 +41,7 @@ sealed trait Insert[A] extends Command: * .onDuplicateKeyUpdate(_.name, "Osaka") * }}} */ - def onDuplicateKeyUpdate[B](columns: A => Column[B], value: B)(using Encoder[B]): Insert.DuplicateKeyUpdate[A] + def onDuplicateKeyUpdate[B](func: A => Column[B], value: B): Insert.DuplicateKeyUpdate[A] object Insert: @@ -52,20 +50,19 @@ object Insert: @targetName("combine") override def ++(sql: SQL): SQL = this.copy(statement = statement ++ sql.statement, params = params ++ sql.params) - override def onDuplicateKeyUpdate[B](columns: A => Column[B]): Insert.DuplicateKeyUpdate[A] = + override def onDuplicateKeyUpdate[B](func: A => Column[B]): Insert.DuplicateKeyUpdate[A] = Insert.DuplicateKeyUpdate( table, - s"$statement ON DUPLICATE KEY UPDATE ${ columns(table).duplicateKeyUpdateStatement }", + s"$statement ON DUPLICATE KEY UPDATE ${ func(table).duplicateKeyUpdateStatement }", params ) - override def onDuplicateKeyUpdate[B](columns: A => Column[B], value: B)(using - Encoder[B] - ): Insert.DuplicateKeyUpdate[A] = + override def onDuplicateKeyUpdate[B](func: A => Column[B], value: B): Insert.DuplicateKeyUpdate[A] = + val columns = func(table) Insert.DuplicateKeyUpdate( table, - s"$statement ON DUPLICATE KEY UPDATE ${ columns(table).name } = ?", - params :+ Parameter.Dynamic(value) + s"$statement ON DUPLICATE KEY UPDATE ${ columns.name } = ?", + params ++ Parameter.Dynamic.many(columns.encoder.encode(value)) ) case class Into[A, B](table: A, statement: String, columns: Column[B]): @@ -82,24 +79,14 @@ object Insert: * @param values * The values to be inserted. */ - inline def values(values: B*): Values[A] = - val parameterBinders = values - .map { - case h *: EmptyTuple => h *: EmptyTuple - case h *: t => h *: t - case h => h *: EmptyTuple - } - .flatMap( - _.zip(Encoder.fold[ToTuple[B]]).toList - .map { - case (value, encoder) => Parameter.Dynamic(value)(using encoder.asInstanceOf[Encoder[Any]]) - } - .toList - ) + def values(values: B*): Values[A] = + val parameterBinders: List[Parameter.Dynamic] = values.flatMap { value => + Parameter.Dynamic.many(columns.encoder.encode(value)) + }.toList Values( table, s"$statement (${ columns.name }) VALUES ${ List.fill(values.length)(s"(${ List.fill(columns.values)("?").mkString(",") })").mkString(",") }", - parameterBinders.toList + parameterBinders ) /** @@ -160,20 +147,19 @@ object Insert: @targetName("combine") override def ++(sql: SQL): SQL = this.copy(statement = statement ++ sql.statement, params = params ++ sql.params) - override def onDuplicateKeyUpdate[C](columns: A => Column[C]): Insert.DuplicateKeyUpdate[A] = + override def onDuplicateKeyUpdate[C](func: A => Column[C]): Insert.DuplicateKeyUpdate[A] = Insert.DuplicateKeyUpdate( table, - s"$statement ON DUPLICATE KEY UPDATE ${ columns(table).duplicateKeyUpdateStatement }", + s"$statement ON DUPLICATE KEY UPDATE ${ func(table).duplicateKeyUpdateStatement }", params ) - override def onDuplicateKeyUpdate[B](columns: A => Column[B], value: B)(using - Encoder[B] - ): Insert.DuplicateKeyUpdate[A] = + override def onDuplicateKeyUpdate[B](func: A => Column[B], value: B): Insert.DuplicateKeyUpdate[A] = + val columns = func(table) Insert.DuplicateKeyUpdate( table, - s"$statement ON DUPLICATE KEY UPDATE ${ columns(table).name } = ?", - params :+ Parameter.Dynamic(value) + s"$statement ON DUPLICATE KEY UPDATE ${ columns.name } = ?", + params ++ Parameter.Dynamic.many(columns.encoder.encode(value)) ) case class DuplicateKeyUpdate[A](table: A, statement: String, params: List[Parameter.Dynamic]) extends Command: diff --git a/module/ldbc-statement/src/main/scala/ldbc/statement/Limit.scala b/module/ldbc-statement/src/main/scala/ldbc/statement/Limit.scala index 72d6f2cb3..577029ada 100644 --- a/module/ldbc-statement/src/main/scala/ldbc/statement/Limit.scala +++ b/module/ldbc-statement/src/main/scala/ldbc/statement/Limit.scala @@ -56,7 +56,7 @@ object Limit: table = table, columns = columns, statement = statement ++ " OFFSET ?", - params = params :+ Parameter.Dynamic(length) + params = params ++ Parameter.Dynamic.many(summon[Encoder[Long]].encode(length)) ) transparent trait QueryProvider[A, B]: @@ -79,7 +79,7 @@ object Limit: table = self.table, columns = self.columns, statement = self.statement ++ " LIMIT ?", - params = self.params :+ Parameter.Dynamic(length) + params = self.params ++ Parameter.Dynamic.many(summon[Encoder[Long]].encode(length)) ) /** @@ -118,5 +118,5 @@ object Limit: def limit(length: Long): Encoder[Long] ?=> Limit.C = Limit.C( statement = statement ++ " LIMIT ?", - params = params :+ Parameter.Dynamic(length) + params = params ++ Parameter.Dynamic.many(summon[Encoder[Long]].encode(length)) ) diff --git a/module/ldbc-statement/src/main/scala/ldbc/statement/TableQuery.scala b/module/ldbc-statement/src/main/scala/ldbc/statement/TableQuery.scala index a2bb1647b..f62fab461 100644 --- a/module/ldbc-statement/src/main/scala/ldbc/statement/TableQuery.scala +++ b/module/ldbc-statement/src/main/scala/ldbc/statement/TableQuery.scala @@ -13,7 +13,6 @@ import scala.annotation.targetName import ldbc.dsl.Parameter import ldbc.dsl.codec.Encoder import ldbc.statement.internal.QueryConcat -import ldbc.statement.interpreter.ToTuple /** * Trait for constructing SQL Statement from Table information. @@ -116,12 +115,9 @@ trait TableQuery[A, O]: inline this match case Join.On(_, _, _, _, _) => error("Join Query does not yet support Insert processing.") case _ => - val parameterBinders = values - .flatMap(_.zip(Encoder.fold[mirror.MirroredElemTypes]).toList) - .map { - case (value, encoder) => Parameter.Dynamic(value)(using encoder.asInstanceOf[Encoder[Any]]) - } - .toList + val parameterBinders: List[Parameter.Dynamic] = values.flatMap { value => + Parameter.Dynamic.many(column.encoder.asInstanceOf[Encoder[mirror.MirroredElemTypes]].encode(value)) + }.toList Insert.Impl( table = table, statement = @@ -138,31 +134,17 @@ trait TableQuery[A, O]: * * @param value * Value to be inserted into the table - * @param mirror - * Mirror of Entity */ @targetName("insertProduct") - inline def +=(value: Entity)(using mirror: Mirror.Of[Entity]): Insert[A] = + inline def +=(value: Entity): Insert[A] = inline this match case Join.On(_, _, _, _, _) => error("Join Query does not yet support Insert processing.") case _ => - inline mirror match - case s: Mirror.SumOf[Entity] => error("Sum type is not supported.") - case p: Mirror.ProductOf[Entity] => derivedProduct(value, p) - - private inline def derivedProduct[P](value: P, mirror: Mirror.ProductOf[P]): Insert[A] = - val tuples = Tuple.fromProduct(value.asInstanceOf[Product]).asInstanceOf[mirror.MirroredElemTypes] - val parameterBinders = tuples - .zip(Encoder.fold[mirror.MirroredElemTypes]) - .toList - .map { - case (value, encoder) => Parameter.Dynamic(value)(using encoder.asInstanceOf[Encoder[Any]]) - } - Insert.Impl( - table = table, - statement = s"INSERT INTO $name ${ column.insertStatement }", - params = params ++ parameterBinders - ) + Insert.Impl( + table = table, + statement = s"INSERT INTO $name ${ column.insertStatement }", + params = params ++ Parameter.Dynamic.many(column.encoder.encode(value)) + ) /** * Method to construct a query to insert a table. @@ -173,16 +155,20 @@ trait TableQuery[A, O]: * * @param values * Value to be inserted into the table - * @param check - * Check if the type of the value is the same as the Entity - * @tparam P - * Scala types to be converted by Encoder */ @targetName("insertProducts") - inline def ++=[P <: Product](values: List[P])(using check: P =:= Entity): Insert[A] = + inline def ++=(values: List[Entity]): Insert[A] = inline this match case Join.On(_, _, _, _, _) => error("Join Query does not yet support Insert processing.") - case _ => TableQueryMacro.++=[A, P](table, name, column.asInstanceOf[Column[P]], params, values) + case _ => + Insert.Impl( + table = table, + statement = + s"INSERT INTO $name (${ column.name }) VALUES ${ values.map(_ => s"(${ List.fill(column.values)("?").mkString(",") })").mkString(",") }", + params = params ++ values.flatMap { value => + Parameter.Dynamic.many(column.encoder.encode(value)) + } + ) /** * Method to construct a query to update a table. @@ -203,17 +189,8 @@ trait TableQuery[A, O]: inline this match case Join.On(_, _, _, _, _) => error("Join Query does not yet support Update processing.") case _ => - val columns = func(table) - val parameterBinders = (values match - case h *: EmptyTuple => h *: EmptyTuple - case h *: t => h *: t - case h => h *: EmptyTuple - ) - .zip(Encoder.fold[ToTuple[C]]) - .toList - .map { - case (value, encoder) => Parameter.Dynamic(value)(using encoder.asInstanceOf[Encoder[Any]]) - } + val columns = func(table) + val parameterBinders = Parameter.Dynamic.many(columns.encoder.encode(values)) Update.Impl[A](table, s"UPDATE $name SET ${ columns.updateStatement }", params ++ parameterBinders) /** @@ -226,21 +203,9 @@ trait TableQuery[A, O]: * * @param value * Value to be updated in the table - * @param mirror - * Mirror of Entity - * @param check - * Check if the type of the value is the same as the Entity - * @tparam P - * Scala types to be converted by Encoder */ - inline def update[P <: Product](value: P)(using mirror: Mirror.ProductOf[P], check: P =:= Entity): Update[A] = - val parameterBinders = Tuple - .fromProductTyped(value) - .zip(Encoder.fold[mirror.MirroredElemTypes]) - .toList - .map { - case (value, encoder) => Parameter.Dynamic(value)(using encoder.asInstanceOf[Encoder[Any]]) - } + def update(value: Entity): Update[A] = + val parameterBinders = Parameter.Dynamic.many(column.encoder.encode(value)) val statement = s"UPDATE $name SET ${ column.updateStatement }" Update.Impl[A](table, statement, params ++ parameterBinders) @@ -330,68 +295,3 @@ object TableQuery: type Extract[T] = T match case AbstractTable[t] => t case AbstractTable[t] *: tn => t *: Extract[tn] - -private[ldbc] object TableQueryMacro: - - import scala.quoted.* - - @targetName("insertProducts") - private[ldbc] inline def ++=[A, B <: Product]( - table: A, - name: String, - column: Column[B], - params: List[Parameter.Dynamic], - values: List[B] - ): Insert[A] = - ${ derivedProducts('table, 'name, 'column, 'params, 'values) } - - private[ldbc] def derivedProducts[A: Type, B <: Product]( - table: Expr[A], - name: Expr[String], - column: Expr[Column[B]], - params: Expr[List[Parameter.Dynamic]], - values: Expr[List[B]] - )(using quotes: Quotes, tpe: Type[B]): Expr[Insert[A]] = - import quotes.reflect.* - - val symbol = TypeRepr.of[B].typeSymbol - - val encodes = Expr.ofSeq( - symbol.caseFields - .map { field => - field.tree match - case ValDef(name, tpt, _) => - tpt.tpe.asType match - case '[tpe] => - val encoder = Expr.summon[Encoder[tpe]].getOrElse { - report.errorAndAbort(s"Encoder for type $tpe not found") - } - encoder.asExprOf[Encoder[tpe]] - case _ => - report.errorAndAbort(s"Type $tpt is not a type") - } - ) - - val lists: Expr[List[Tuple]] = '{ - $values - .map(value => Tuple.fromProduct(value)) - } - - val parameterBinders = '{ - $lists.flatMap(list => - list.toList - .zip($encodes) - .map { - case (value, encoder) => Parameter.Dynamic(value)(using encoder.asInstanceOf[Encoder[Any]]) - } - ) - } - - '{ - Insert.Impl( - table = $table, - statement = - s"INSERT INTO ${ $name } (${ $column.name }) VALUES ${ $lists.map(list => s"(${ list.toList.map(_ => "?").mkString(",") })").mkString(",") }", - params = $params ++ $parameterBinders - ) - } diff --git a/module/ldbc-statement/src/main/scala/ldbc/statement/Update.scala b/module/ldbc-statement/src/main/scala/ldbc/statement/Update.scala index 2c4d620e7..4723c7098 100644 --- a/module/ldbc-statement/src/main/scala/ldbc/statement/Update.scala +++ b/module/ldbc-statement/src/main/scala/ldbc/statement/Update.scala @@ -9,7 +9,6 @@ package ldbc.statement import scala.annotation.targetName import ldbc.dsl.{ Parameter, SQL } -import ldbc.dsl.codec.Encoder /** * Trait for building Statements to be updated. @@ -36,7 +35,7 @@ sealed trait Update[A] extends Command: * @param value * The value to be set */ - def set[B](column: A => Column[B], value: B)(using Encoder[B]): Update[A] + def set[B](column: A => Column[B], value: B): Update[A] /** * Methods for setting the value of a column in a table. @@ -55,7 +54,7 @@ sealed trait Update[A] extends Command: * @param bool * A boolean value that determines whether to update */ - def set[B](column: A => Column[B], value: B, bool: Boolean)(using Encoder[B]): Update[A] + def set[B](column: A => Column[B], value: B, bool: Boolean): Update[A] /** * A method for setting the WHERE condition in a Update statement. @@ -83,13 +82,14 @@ object Update: @targetName("combine") override def ++(sql: SQL): SQL = this.copy(statement = statement ++ sql.statement, params = params ++ sql.params) - override def set[B](column: A => Column[B], value: B)(using Encoder[B]): Update[A] = + override def set[B](column: A => Column[B], value: B): Update[A] = + val columns = column(table) this.copy( - statement = statement ++ s", ${ column(table).updateStatement }", - params = params :+ Parameter.Dynamic(value) + statement = statement ++ s", ${ columns.updateStatement }", + params = params ++ Parameter.Dynamic.many(columns.encoder.encode(value)) ) - override def set[B](column: A => Column[B], value: B, bool: Boolean)(using Encoder[B]): Update[A] = + override def set[B](column: A => Column[B], value: B, bool: Boolean): Update[A] = if bool then set(column, value) else this override def where(func: A => Expression): Where.C[A] = diff --git a/module/ldbc-statement/src/main/scala/ldbc/statement/interpreter/ToTuple.scala b/module/ldbc-statement/src/main/scala/ldbc/statement/interpreter/ToTuple.scala deleted file mode 100644 index 51966c61d..000000000 --- a/module/ldbc-statement/src/main/scala/ldbc/statement/interpreter/ToTuple.scala +++ /dev/null @@ -1,12 +0,0 @@ -/** - * Copyright (c) 2023-2024 by Takahiko Tominaga - * This software is licensed under the MIT License (MIT). - * For more information see LICENSE or https://opensource.org/licenses/MIT - */ - -package ldbc.statement.interpreter - -type ToTuple[T] <: Tuple = T match - case h *: EmptyTuple => Tuple1[h] - case h *: t => h *: ToTuple[t] - case _ => Tuple1[T] diff --git a/tests/src/main/scala/ldbc/tests/model/Country.scala b/tests/src/main/scala/ldbc/tests/model/Country.scala index 4cbcfe2bf..5b8150e48 100644 --- a/tests/src/main/scala/ldbc/tests/model/Country.scala +++ b/tests/src/main/scala/ldbc/tests/model/Country.scala @@ -42,8 +42,7 @@ object Country: override def toString: String = value - given Encoder[Continent] with - override def encode(continent: Continent): String = continent.value + given Encoder[Continent] = Encoder[String].contramap(_.value) given Decoder.Elem[Continent] = Decoder.Elem.mapping[String, Continent](str => Continent.valueOf(str.replace(" ", "_"))) diff --git a/tests/src/main/scala/ldbc/tests/model/CountryLanguage.scala b/tests/src/main/scala/ldbc/tests/model/CountryLanguage.scala index b5e657a0c..25809d74d 100644 --- a/tests/src/main/scala/ldbc/tests/model/CountryLanguage.scala +++ b/tests/src/main/scala/ldbc/tests/model/CountryLanguage.scala @@ -25,8 +25,7 @@ object CountryLanguage: object IsOfficial - given Encoder[IsOfficial] with - override def encode(isOfficial: IsOfficial): String = isOfficial.toString + given Encoder[IsOfficial] = Encoder[String].contramap(_.toString) given Decoder.Elem[IsOfficial] = Decoder.Elem.mapping[String, IsOfficial](str => IsOfficial.valueOf(str)) diff --git a/tests/src/test/scala/ldbc/tests/TableQueryUpdateConnectionTest.scala b/tests/src/test/scala/ldbc/tests/TableQueryUpdateConnectionTest.scala index 1e6db46b3..feb5d0bb1 100644 --- a/tests/src/test/scala/ldbc/tests/TableQueryUpdateConnectionTest.scala +++ b/tests/src/test/scala/ldbc/tests/TableQueryUpdateConnectionTest.scala @@ -6,8 +6,6 @@ package ldbc.tests -import scala.concurrent.duration.DurationInt - import com.mysql.cj.jdbc.MysqlDataSource import cats.syntax.all.* @@ -387,7 +385,7 @@ trait TableQueryUpdateConnectionTest extends CatsEffectSuite: "The value of AutoIncrement obtained during insert matches the specified value." ) { assertIOBoolean( - IO.sleep(5.seconds) >> connection.use { conn => + connection.use { conn => (for length <- city.select(_.id.count).query.unsafe.map(_ + 1) result <- @@ -401,7 +399,7 @@ trait TableQueryUpdateConnectionTest extends CatsEffectSuite: ) } - test("") { + test("The value of AutoIncrement obtained during insert matches the specified value") { assertIO( connection.use { conn => city diff --git a/tests/src/test/scala/ldbc/tests/TableSchemaUpdateConnectionTest.scala b/tests/src/test/scala/ldbc/tests/TableSchemaUpdateConnectionTest.scala index 673116ac3..9b6ec4261 100644 --- a/tests/src/test/scala/ldbc/tests/TableSchemaUpdateConnectionTest.scala +++ b/tests/src/test/scala/ldbc/tests/TableSchemaUpdateConnectionTest.scala @@ -6,8 +6,6 @@ package ldbc.tests -import scala.concurrent.duration.DurationInt - import com.mysql.cj.jdbc.MysqlDataSource import cats.syntax.all.* @@ -387,7 +385,7 @@ trait TableSchemaUpdateConnectionTest extends CatsEffectSuite: "The value of AutoIncrement obtained during insert matches the specified value." ) { assertIOBoolean( - IO.sleep(5.seconds) >> connection.use { conn => + connection.use { conn => (for length <- city.select(_.id.count).query.unsafe.map(_ + 1) result <- @@ -401,7 +399,7 @@ trait TableSchemaUpdateConnectionTest extends CatsEffectSuite: ) } - test("") { + test("The value of AutoIncrement obtained during insert matches the specified value.") { assertIO( connection.use { conn => city