diff --git a/build.sbt b/build.sbt index aded597..cf8be3d 100644 --- a/build.sbt +++ b/build.sbt @@ -41,8 +41,8 @@ lazy val testDependencies = Seq( scalacOptions in Test ++= Seq("-Yrangepos") lazy val baseSettings = Seq( - scalacOptions ++= compilerOptions, - scalacOptions in (Compile, console) := compilerOptions, + scalacOptions ++= compilerOptions, + scalacOptions in (Compile, console) := compilerOptions, scalacOptions in (Compile, doc) ++= Seq( "-doc-title", "roc", "-doc-version", version.value, @@ -59,7 +59,7 @@ lazy val allSettings = buildSettings ++ baseSettings ++ Defaults.itSettings lazy val coreVersion = "0.0.5" -lazy val catsVersion = "0.6.0" +lazy val catsVersion = "0.8.1" lazy val finagleVersion = "6.38.0" diff --git a/core/src/main/scala/roc/postgresql/ClientDispatcher.scala b/core/src/main/scala/roc/postgresql/ClientDispatcher.scala index 2eed5aa..892dbe1 100644 --- a/core/src/main/scala/roc/postgresql/ClientDispatcher.scala +++ b/core/src/main/scala/roc/postgresql/ClientDispatcher.scala @@ -1,8 +1,7 @@ package roc package postgresql -import cats.data.Xor -import cats.std.all._ +import cats.instances.all._ import cats.syntax.eq._ import com.twitter.finagle.dispatch.GenSerialClientDispatcher import com.twitter.finagle.transport.Transport @@ -26,14 +25,14 @@ private[roc] final class ClientDispatcher(trans: Transport[Packet, Packet], private[roc] lazy val paramStatuses: Map[String, String] = mutableParamStatuses.map(x => (x.parameter, x.value)).toMap - override def apply(req: Request): Future[Result] = + override def apply(req: Request): Future[Result] = startupPhase.flatMap(_ => super.apply(req)) /** Performs the Startup phase of a Postgresql Connection. * * The startup phase is performed once per connection prior to any exchanges * between the client and server. Failure to startup renders the service unsuable. - * The startup phase consists of two separate but sequential phases + * The startup phase consists of two separate but sequential phases * 1. Authentication 2. Server Process setting run time parameters * @see [[http://www.postgresql.org/docs/current/static/protocol-flow.html#AEN108589]] */ @@ -55,8 +54,8 @@ private[roc] final class ClientDispatcher(trans: Transport[Packet, Packet], for { packet <- trans.read() message <- Message.decode(packet) match { - case Xor.Left(l) => Future.exception(l) - case Xor.Right(m) => Future.value(m) + case Left(l) => Future.exception(l) + case Right(m) => Future.value(m) } } yield message } @@ -73,7 +72,7 @@ private[roc] final class ClientDispatcher(trans: Transport[Packet, Packet], for { _ <- trans.write(encodePacket(query)).rescue(wrapWriteException) signal = rep.become(readTransport(query, new Promise[Unit])) - } yield signal + } yield signal } private[this] def readTransport(req: Transmission, signal: Promise[Unit]): Future[Result] = @@ -89,17 +88,17 @@ private[roc] final class ClientDispatcher(trans: Transport[Packet, Packet], type Collection = (Descriptions, Rows, CommandCompleteString) def go(xs: Descriptions, ys: Rows, ccStr: CommandCompleteString): Future[Collection] = trans.read().flatMap(packet => Message.decode(packet) match { - case Xor.Right(RowDescription(a,b)) => go(RowDescription(a,b) :: xs, ys, ccStr) - case Xor.Right(DataRow(a,b)) => go(xs, DataRow(a,b) :: ys, ccStr) - case Xor.Right(EmptyQueryResponse) => go(xs, ys, "EmptyQueryResponse") - case Xor.Right(CommandComplete(x)) => go(xs, ys, x) - case Xor.Right(ErrorResponse(e)) => + case Right(RowDescription(a,b)) => go(RowDescription(a,b) :: xs, ys, ccStr) + case Right(DataRow(a,b)) => go(xs, DataRow(a,b) :: ys, ccStr) + case Right(EmptyQueryResponse) => go(xs, ys, "EmptyQueryResponse") + case Right(CommandComplete(x)) => go(xs, ys, x) + case Right(ErrorResponse(e)) => Future.exception(new PostgresqlServerFailure(e)) - case Xor.Right(NoticeResponse(_)) => go(xs, ys, ccStr) // throw Notice Responses away - case Xor.Right(Idle) => Future.value((xs.reverse, ys.reverse, ccStr)) - case Xor.Right(u) => + case Right(NoticeResponse(_)) => go(xs, ys, ccStr) // throw Notice Responses away + case Right(Idle) => Future.value((xs.reverse, ys.reverse, ccStr)) + case Right(u) => Future.exception(new PostgresqlStateMachineFailure("Query", u.toString)) - case Xor.Left(l) => Future.exception(l) + case Left(l) => Future.exception(l) } ) @@ -147,16 +146,16 @@ private[roc] final class ClientDispatcher(trans: Transport[Packet, Packet], type ParamStatuses = List[ParameterStatus] type BKDs = List[BackendKeyData] - def go(safetyCheck: Int, xs: ParamStatuses, ys: BKDs): Future[(ParamStatuses, BKDs)] = + def go(safetyCheck: Int, xs: ParamStatuses, ys: BKDs): Future[(ParamStatuses, BKDs)] = safetyCheck match { // TODO - create an Error type for this case x if x > 1000 => Future.exception(new Exception()) case x if x < 1000 => trans.read().flatMap(packet => Message.decode(packet) match { - case Xor.Left(l) => Future.exception(l) - case Xor.Right(ParameterStatus(i, j)) => go(safetyCheck + 1, ParameterStatus(i,j) :: xs, ys) - case Xor.Right(BackendKeyData(i, j)) => go(safetyCheck + 1, xs, BackendKeyData(i, j) :: ys) - case Xor.Right(Idle) => Future.value((xs, ys)) - case Xor.Right(message) => Future.exception( + case Left(l) => Future.exception(l) + case Right(ParameterStatus(i, j)) => go(safetyCheck + 1, ParameterStatus(i,j) :: xs, ys) + case Right(BackendKeyData(i, j)) => go(safetyCheck + 1, xs, BackendKeyData(i, j) :: ys) + case Right(Idle) => Future.value((xs, ys)) + case Right(message) => Future.exception( new PostgresqlStateMachineFailure("StartupMessage", message.toString) ) }) diff --git a/core/src/main/scala/roc/postgresql/Messages.scala b/core/src/main/scala/roc/postgresql/Messages.scala index 03ad228..e612cbb 100644 --- a/core/src/main/scala/roc/postgresql/Messages.scala +++ b/core/src/main/scala/roc/postgresql/Messages.scala @@ -2,8 +2,7 @@ package roc package postgresql import cats.Eq -import cats.data.Xor -import cats.std.all._ +import cats.instances.all._ import cats.syntax.eq._ import com.twitter.util.Future import java.nio.charset.StandardCharsets @@ -30,7 +29,7 @@ private[postgresql] object Message { val TerminateByte: Char = 'X' val NoticeResponseByte: Char = 'N' - private[postgresql] def decode(packet: Packet): Xor[Failure, Message] = packet.messageType match { + private[postgresql] def decode(packet: Packet): Either[Failure, Message] = packet.messageType match { case Some(mt) if mt === AuthenticationMessageByte => decodePacket[AuthenticationMessage](packet) case Some(mt) if mt === ErrorByte => decodePacket[ErrorResponse](packet) case Some(mt) if mt === NoticeResponseByte => decodePacket[NoticeResponse](packet) @@ -40,12 +39,12 @@ private[postgresql] object Message { case Some(mt) if mt === RowDescriptionByte => decodePacket[RowDescription](packet) case Some(mt) if mt === DataRowByte => decodePacket[DataRow](packet) case Some(mt) if mt === CommandCompleteByte => decodePacket[CommandComplete](packet) - case Some(mt) if mt === EmptyQueryResponseByte => Xor.Right(EmptyQueryResponse) + case Some(mt) if mt === EmptyQueryResponseByte => Right(EmptyQueryResponse) case Some(mt) => { println(s"Inside Some($mt)") - Xor.Left(new UnknownPostgresqlMessageTypeFailure(mt)) + Left(new UnknownPostgresqlMessageTypeFailure(mt)) } - case None => Xor.Left(new UnexpectedNoneFailure("")) + case None => Left(new UnexpectedNoneFailure("")) } implicit val messageEq: Eq[Message] = new Eq[Message] { @@ -61,7 +60,7 @@ private[postgresql] case class Query(queryString: String) extends FrontendMessag private[postgresql] case class PasswordMessage(password: String) extends FrontendMessage private[postgresql] object PasswordMessage { - def encryptMD5Passwd(user: String, passwd: String, + def encryptMD5Passwd(user: String, passwd: String, salt: Array[Byte]): String = { val md = MessageDigest.getInstance("MD5") md.update((passwd + user).getBytes) @@ -81,21 +80,21 @@ private[postgresql] case class ErrorResponse(error: PostgresqlMessage) extends B private[postgresql] sealed abstract class AuthenticationMessage extends BackendMessage private[postgresql] object AuthenticationMessage { - def apply(tuple: (Int, Option[Array[Byte]])): Failure Xor AuthenticationMessage = tuple match { - case (0, None) => Xor.Right(AuthenticationOk) - case (2, None) => Xor.Right(AuthenticationKerberosV5) - case (3, None) => Xor.Right(AuthenticationClearTxtPasswd) - case (5, Some(bytes)) => Xor.Right(new AuthenticationMD5Passwd(bytes)) - case (6, None) => Xor.Right(AuthenticationSCMCredential) - case (7, None) => Xor.Right(AuthenticationGSS) - case (8, Some(bytes)) => Xor.Right(new AuthenticationGSSContinue(bytes)) - case (9, None) => Xor.Right(AuthenticationSSPI) - case (x, _) => Xor.Left(new UnknownAuthenticationRequestFailure(x)) + def apply(tuple: (Int, Option[Array[Byte]])): Failure Either AuthenticationMessage = tuple match { + case (0, None) => Right(AuthenticationOk) + case (2, None) => Right(AuthenticationKerberosV5) + case (3, None) => Right(AuthenticationClearTxtPasswd) + case (5, Some(bytes)) => Right(new AuthenticationMD5Passwd(bytes)) + case (6, None) => Right(AuthenticationSCMCredential) + case (7, None) => Right(AuthenticationGSS) + case (8, Some(bytes)) => Right(new AuthenticationGSSContinue(bytes)) + case (9, None) => Right(AuthenticationSSPI) + case (x, _) => Left(new UnknownAuthenticationRequestFailure(x)) } } private[postgresql] case object AuthenticationOk extends AuthenticationMessage private[postgresql] case object AuthenticationClearTxtPasswd extends AuthenticationMessage -private[postgresql] case class AuthenticationMD5Passwd(salt: Array[Byte]) +private[postgresql] case class AuthenticationMD5Passwd(salt: Array[Byte]) extends AuthenticationMessage { def canEqual(a: Any) = a.isInstanceOf[AuthenticationMD5Passwd] @@ -110,30 +109,30 @@ private[postgresql] case object AuthenticationKerberosV5 extends AuthenticationM private[postgresql] case object AuthenticationSCMCredential extends AuthenticationMessage private[postgresql] case object AuthenticationGSS extends AuthenticationMessage private[postgresql] case object AuthenticationSSPI extends AuthenticationMessage -private[postgresql] case class AuthenticationGSSContinue(authBytes: Array[Byte]) +private[postgresql] case class AuthenticationGSSContinue(authBytes: Array[Byte]) extends AuthenticationMessage { def canEqual(a: Any) = a.isInstanceOf[AuthenticationGSSContinue] final override def equals(that: Any): Boolean = that match { - case x: AuthenticationGSSContinue => x.canEqual(this) && + case x: AuthenticationGSSContinue => x.canEqual(this) && authBytes.length == x.authBytes.length && (authBytes sameElements x.authBytes) case _ => false } } -private[postgresql] case class ParameterStatus(parameter: String, value: String) +private[postgresql] case class ParameterStatus(parameter: String, value: String) extends BackendMessage private[postgresql] case class BackendKeyData(processId: Int, secretKey: Int) extends BackendMessage private[postgresql] sealed abstract class ReadyForQuery extends BackendMessage private[postgresql] object ReadyForQuery { - def apply(transactionStatus: Char): ReadyForQueryDecodingFailure Xor ReadyForQuery = + def apply(transactionStatus: Char): ReadyForQueryDecodingFailure Either ReadyForQuery = transactionStatus match { - case 'I' => Xor.Right(Idle) - case 'T' => Xor.Right(TransactionBlock) - case 'E' => Xor.Right(FailedTransactionBlock) - case c => Xor.Left(new ReadyForQueryDecodingFailure(c)) + case 'I' => Right(Idle) + case 'T' => Right(TransactionBlock) + case 'E' => Right(FailedTransactionBlock) + case c => Left(new ReadyForQueryDecodingFailure(c)) } } @@ -144,8 +143,8 @@ private[postgresql] case object FailedTransactionBlock extends ReadyForQuery private[postgresql] case object EmptyQueryResponse extends BackendMessage private[postgresql] case class RowDescription(numFields: Short, fields: List[RowDescriptionField]) - extends BackendMessage -private[postgresql] case class RowDescriptionField(name: String, tableObjectId: Int, + extends BackendMessage +private[postgresql] case class RowDescriptionField(name: String, tableObjectId: Int, tableAttributeId: Short, dataTypeObjectId: Int, dataTypeSize: Short, typeModifier: Int, formatCode: FormatCode) diff --git a/core/src/main/scala/roc/postgresql/results.scala b/core/src/main/scala/roc/postgresql/results.scala index cab87e8..6fbffaf 100644 --- a/core/src/main/scala/roc/postgresql/results.scala +++ b/core/src/main/scala/roc/postgresql/results.scala @@ -1,7 +1,6 @@ package roc package postgresql -import cats.data.Xor import cats.Show import java.nio.charset.StandardCharsets import roc.postgresql.failures.{ElementNotFoundFailure, UnsupportedDecodingFailure} @@ -48,21 +47,21 @@ final class Result(rowDescription: List[RowDescription], data: List[DataRow], cc /** The command tag. This is usually a single word that identifies which SQL command was completed. * - * For an INSERT command, the tag is INSERT oid rows, where rows is the number of rows inserted. + * For an INSERT command, the tag is INSERT oid rows, where rows is the number of rows inserted. * oid is the object ID of the inserted row if rows is 1 and the target table has OIDs; * otherwise oid is 0. * * For a DELETE command, the tag is DELETE rows where rows is the number of rows deleted. - * + * * For an UPDATE command, the tag is UPDATE rows where rows is the number of rows updated. * - * For a SELECT or CREATE TABLE AS command, the tag is SELECT rows where rows is the number of + * For a SELECT or CREATE TABLE AS command, the tag is SELECT rows where rows is the number of * rows retrieved. * - * For a MOVE command, the tag is MOVE rows where rows is the number of rows the cursor's + * For a MOVE command, the tag is MOVE rows where rows is the number of rows the cursor's * position has been changed by. * - * For a FETCH command, the tag is FETCH rows where rows is the number of rows that have been + * For a FETCH command, the tag is FETCH rows where rows is the number of rows that have been * retrieved from the cursor. * @see [[http://www.postgresql.org/docs/current/static/protocol-message-formats.html * CommandComplete]] @@ -97,14 +96,14 @@ final case class Column private[roc](name: Symbol, columnType: Int, formatCode: } object Column { implicit val columnShow: Show[Column] = new Show[Column] { - def show(c: Column): String = + def show(c: Column): String = s"Column(name=${c.name}, columnType=${c.columnType}, formatCode=${c.formatCode})" } } /** A row returned from a Postgresql Server containing at least one * [[Element]] - * @param elements a collection of all [[row.postgresql.Element Elements]] returned from + * @param elements a collection of all [[row.postgresql.Element Elements]] returned from * Postgresql via a query. */ final class Row private[postgresql](private[postgresql] val elements: List[Element]) { @@ -112,7 +111,7 @@ final class Row private[postgresql](private[postgresql] val elements: List[Eleme /** Returns the [[roc.postgresql.Element Element]] found via the column name * * @param columnName the column name given the associated [[roc.postgresql.Element Element]] - * @return the element found via the column name + * @return the element found via the column name */ def get(columnName: Symbol): Element = elements.find(_.name == columnName) match { case Some(e) => e @@ -144,7 +143,7 @@ sealed abstract class Element(val name: Symbol, columnType: Int) { * @param f an implicit [[ElementDecoder]] typeclass * @return A */ - def as[A](implicit f: ElementDecoder[A]): A = + def as[A](implicit f: ElementDecoder[A]): A = fold(f.textDecoder, f.binaryDecoder, f.nullDecoder) /** Decodes this element as a String @@ -156,7 +155,7 @@ sealed abstract class Element(val name: Symbol, columnType: Int) { {(s: String) => s}, {(bs: Array[Byte]) => throw new UnsupportedDecodingFailure(s"Attempted String decoding of Binary column.")}, - {() => + {() => throw new UnsupportedDecodingFailure(s"Attempted String decoding of Null column.")} ) @@ -165,11 +164,11 @@ sealed abstract class Element(val name: Symbol, columnType: Int) { * @see [[http://www.postgresql.org/docs/current/static/protocol-overview.html * 50.1.3 Formats and Format Codes]] * @note Binary representations for integers use network byte order (most significant byte first). - * For other data types consult the documentation or source code to learn about the binary + * For other data types consult the documentation or source code to learn about the binary * representation. */ def asBytes(): Array[Byte] = fold( - {(s: String) => + {(s: String) => throw new UnsupportedDecodingFailure(s"Attempted Binary decoding of String column.")}, {(bs: Array[Byte]) => bs}, {() => @@ -178,7 +177,7 @@ sealed abstract class Element(val name: Symbol, columnType: Int) { } case class Null(override val name: Symbol, columnType: Int) extends Element(name, columnType) -case class Text(override val name: Symbol, columnType: Int, value: String) +case class Text(override val name: Symbol, columnType: Int, value: String) extends Element(name, columnType) case class Binary(override val name: Symbol, columnType: Int, value: Array[Byte]) extends Element(name, columnType) diff --git a/core/src/main/scala/roc/postgresql/server/PostgresqlMessages.scala b/core/src/main/scala/roc/postgresql/server/PostgresqlMessages.scala index 760d9cd..de2d2b0 100644 --- a/core/src/main/scala/roc/postgresql/server/PostgresqlMessages.scala +++ b/core/src/main/scala/roc/postgresql/server/PostgresqlMessages.scala @@ -3,9 +3,10 @@ package postgresql package server import cats.data.Validated._ -import cats.data.{NonEmptyList, Validated, ValidatedNel, Xor} +import cats.data.{NonEmptyList, Validated, ValidatedNel} import cats.Semigroup -import cats.std.all._ +import cats.implicits._ +import cats.instances.all._ import cats.syntax.eq._ import roc.postgresql.failures.{PostgresqlMessageDecodingFailure, Failure} @@ -20,7 +21,7 @@ import roc.postgresql.failures.{PostgresqlMessageDecodingFailure, Failure} /** The severity of the Error or Notice * - * The field contents are ERROR, FATAL, or PANIC (in an error message), or WARNING, NOTICE, + * The field contents are ERROR, FATAL, or PANIC (in an error message), or WARNING, NOTICE, * DEBUG, INFO, or LOG (in a notice message), or a localized translation of one of these. * @note Always present. * @see [[http://www.postgresql.org/docs/current/static/protocol-error-fields.html]] @@ -52,7 +53,7 @@ import roc.postgresql.failures.{PostgresqlMessageDecodingFailure, Failure} /** A optional suggestion what to do about the problem. * - * This is intended to differ from Detail in that it offers advice + * This is intended to differ from Detail in that it offers advice * (potentially inappropriate) rather than hard facts. Might run to multiple lines. * @see [[http://www.postgresql.org/docs/current/static/protocol-error-fields.html]] */ @@ -60,8 +61,8 @@ import roc.postgresql.failures.{PostgresqlMessageDecodingFailure, Failure} /** Indicates an error cursor position as an index into the original query string. * - * The field value is a decimal ASCII integer, indicating an error cursor position - * as an index into the original query string. The first character has index 1, and + * The field value is a decimal ASCII integer, indicating an error cursor position + * as an index into the original query string. The first character has index 1, and * positions are measured in characters not bytes. * @see [[http://www.postgresql.org/docs/current/static/protocol-error-fields.html]] */ @@ -69,7 +70,7 @@ import roc.postgresql.failures.{PostgresqlMessageDecodingFailure, Failure} /** Indicates an error cursor postion as an index of an internally generated command. * - * This is defined the same as [[position]], but it is used when the cursor position refers + * This is defined the same as [[position]], but it is used when the cursor position refers * to an internally generated command rather than the one submitted by the client. The * [[internalQuery]] field will always appear when this field appears. * @see [[http://www.postgresql.org/docs/current/static/protocol-error-fields.html]] @@ -85,7 +86,7 @@ import roc.postgresql.failures.{PostgresqlMessageDecodingFailure, Failure} /** An indication of the context in which the error occurred. * - * Presently this includes a call stack traceback of active procedural language functions + * Presently this includes a call stack traceback of active procedural language functions * and internally-generated queries. The trace is one entry per line, most recent first. * @see [[http://www.postgresql.org/docs/current/static/protocol-error-fields.html]] */ @@ -93,7 +94,7 @@ import roc.postgresql.failures.{PostgresqlMessageDecodingFailure, Failure} /** The name of the schema containing that object. * - * If the error was associated with a specific database object, the name of the schema + * If the error was associated with a specific database object, the name of the schema * containing that object, if any. * @see [[http://www.postgresql.org/docs/current/static/protocol-error-fields.html]] */ @@ -101,7 +102,7 @@ import roc.postgresql.failures.{PostgresqlMessageDecodingFailure, Failure} /** The name of the table. * - * If the error was associated with a specific table, the name of the table. + * If the error was associated with a specific table, the name of the table. * (Refer to the schema name field for the name of the table's schema.) * @see [[http://www.postgresql.org/docs/current/static/protocol-error-fields.html]] */ @@ -109,7 +110,7 @@ import roc.postgresql.failures.{PostgresqlMessageDecodingFailure, Failure} /** The name of the column. * - * If the error was associated with a specific table column, the name of the column. + * If the error was associated with a specific table column, the name of the column. * (Refer to the schema and table name fields to identify the table.) * @see [[http://www.postgresql.org/docs/current/static/protocol-error-fields.html]] */ @@ -117,7 +118,7 @@ import roc.postgresql.failures.{PostgresqlMessageDecodingFailure, Failure} /** The name of the data type. * - * If the error was associated with a specific data type, the name of the data type. + * If the error was associated with a specific data type, the name of the data type. * (Refer to the schema name field for the name of the data type's schema.) * @see [[http://www.postgresql.org/docs/current/static/protocol-error-fields.html]] */ @@ -125,8 +126,8 @@ import roc.postgresql.failures.{PostgresqlMessageDecodingFailure, Failure} /** The name of the constraint. * - * If the error was associated with a specific constraint, the name of the constraint. - * Refer to fields listed above for the associated table or domain. (For this purpose, indexes + * If the error was associated with a specific constraint, the name of the constraint. + * Refer to fields listed above for the associated table or domain. (For this purpose, indexes * are treated as constraints, even if they weren't created with constraint syntax.) * @see [[http://www.postgresql.org/docs/current/static/protocol-error-fields.html]] */ @@ -148,7 +149,7 @@ import roc.postgresql.failures.{PostgresqlMessageDecodingFailure, Failure} val routine: Option[String] = params.routine override def toString: String = { - val xs = List(("Detail: ", detail), ("Hint: ", hint), ("Position: ", position), + val xs = List(("Detail: ", detail), ("Hint: ", hint), ("Position: ", position), ("Internal Position: ", internalPosition), ("Internal Query: ", internalQuery), ("Where: ", where), ("Schema Name: ", schemaName), ("Table Name: ", tableName), ("Column Name: ", columnName), ("Data Type Name: ", dataTypeName), ("Constaint Name: ", @@ -163,11 +164,11 @@ import roc.postgresql.failures.{PostgresqlMessageDecodingFailure, Failure} } -private[postgresql] case class ErrorParams(severity: String, code: String, message: String, +private[postgresql] case class ErrorParams(severity: String, code: String, message: String, detail: Option[String], hint: Option[String], position: Option[String], - internalPosition: Option[String], internalQuery: Option[String], where: Option[String], + internalPosition: Option[String], internalQuery: Option[String], where: Option[String], schemaName: Option[String], tableName: Option[String], columnName: Option[String], - dataTypeName: Option[String], constraintName: Option[String], file: Option[String], + dataTypeName: Option[String], constraintName: Option[String], file: Option[String], line: Option[String], routine: Option[String]) private[postgresql] case class RequiredParams(severity: String, code: String, message: String) @@ -175,17 +176,17 @@ private[postgresql] case class RequiredParams(severity: String, code: String, me private[postgresql] object PostgresqlMessage { import ErrorNoticeMessageFields._ - def apply(xs: Fields): Xor[Failure, PostgresqlMessage] = + def apply(xs: Fields): Either[Failure, PostgresqlMessage] = buildParamsFromTuples(xs).flatMap(x => x.code.take(2) match { - case ErrorClassCodes.SuccessfulCompletion => Xor.Right(new SuccessMessage(x)) - case code if ErrorClassCodes.WarningCodes.contains(code) => Xor.Right(new WarningMessage(x)) - case code if ErrorClassCodes.ErrorCodes.contains(code) => Xor.Right(new ErrorMessage(x)) - case code => Xor.Right(new UnknownMessage(x)) + case ErrorClassCodes.SuccessfulCompletion => Right(new SuccessMessage(x)) + case code if ErrorClassCodes.WarningCodes.contains(code) => Right(new WarningMessage(x)) + case code if ErrorClassCodes.ErrorCodes.contains(code) => Right(new ErrorMessage(x)) + case code => Right(new UnknownMessage(x)) }) // private to server for testing - private[server] def buildParamsFromTuples(xs: List[Field]): - Xor[PostgresqlMessageDecodingFailure, ErrorParams] = { + private[server] def buildParamsFromTuples(xs: List[Field]): + Either[PostgresqlMessageDecodingFailure, ErrorParams] = { val detail = extractValueByCode(Detail, xs) val hint = extractValueByCode(Hint, xs) val position = extractValueByCode(Position, xs) @@ -214,14 +215,14 @@ private[postgresql] object PostgresqlMessage { case None => Invalid("Required Message was not present.") } - validatePacket(severity.toValidatedNel, code.toValidatedNel, + validatePacket(severity.toValidatedNel, code.toValidatedNel, message.toValidatedNel)(RequiredParams.apply) .fold( - {l => Xor.Left(new PostgresqlMessageDecodingFailure(l))}, - {r => Xor.Right(new ErrorParams(severity = r.severity, code = r.code, message = r.message, + {l => Left(new PostgresqlMessageDecodingFailure(l))}, + {r => Right(new ErrorParams(severity = r.severity, code = r.code, message = r.message, detail = detail, hint = hint, position = position, internalPosition = internalPosition, internalQuery = internalQuery, where = where, schemaName = schemaName, - tableName = tableName, columnName = columnName, dataTypeName = dataTypeName, + tableName = tableName, columnName = columnName, dataTypeName = dataTypeName, constraintName = constraintName, file = file, line = line, routine = routine))} ) } @@ -241,14 +242,14 @@ private[postgresql] object PostgresqlMessage { case (Valid(_), Invalid(e1), Invalid(e2)) => Invalid(Semigroup[E].combine(e1, e2)) case (Invalid(e1), Valid(_), Invalid(e2)) => Invalid(Semigroup[E].combine(e1, e2)) case (Invalid(e1), Invalid(e2), Valid(_)) => Invalid(Semigroup[E].combine(e1, e2)) - case (Invalid(e1), Invalid(e2), Invalid(e3)) => + case (Invalid(e1), Invalid(e2), Invalid(e3)) => Invalid(Semigroup[E].combine(e1, Semigroup[E].combine(e2, e3))) } } /** Represents an unknown or undefined message. * - * From Postgresql Documentation: "Since more field types might be added in future, + * From Postgresql Documentation: "Since more field types might be added in future, * frontends should silently ignore fields of unrecognized type." Therefore, if we decode * an Error we do not recognize, we do not create a Failed Decoding Result. */ diff --git a/core/src/main/scala/roc/postgresql/transport/PacketDecoders.scala b/core/src/main/scala/roc/postgresql/transport/PacketDecoders.scala index 3f12bbc..c4d57a8 100644 --- a/core/src/main/scala/roc/postgresql/transport/PacketDecoders.scala +++ b/core/src/main/scala/roc/postgresql/transport/PacketDecoders.scala @@ -2,7 +2,7 @@ package roc package postgresql package transport -import cats.data.Xor +import cats.implicits._ import roc.postgresql.failures.{Failure, PacketDecodingFailure} import roc.postgresql.server.PostgresqlMessage import scala.collection.mutable.ListBuffer @@ -11,7 +11,7 @@ private[postgresql] trait PacketDecoder[A <: BackendMessage] { def apply(p: Packet): PacketDecoder.Result[A] } private[postgresql] object PacketDecoder { - final type Result[A] = Xor[Failure, A] + final type Result[A] = Either[Failure, A] } private[postgresql] trait PacketDecoderImplicits { @@ -20,7 +20,7 @@ private[postgresql] trait PacketDecoderImplicits { private[this] type Field = (Char, String) private[this] type Fields = List[Field] - def readErrorNoticePacket(p: Packet): Xor[Throwable, Fields] = Xor.catchNonFatal({ + def readErrorNoticePacket(p: Packet): Either[Throwable, Fields] = Either.catchNonFatal({ val br = BufferReader(p.body) @annotation.tailrec @@ -34,29 +34,29 @@ private[postgresql] trait PacketDecoderImplicits { loop(List.empty[Field]) }) - implicit val noticeResponsePacketDecoder: PacketDecoder[NoticeResponse] = + implicit val noticeResponsePacketDecoder: PacketDecoder[NoticeResponse] = new PacketDecoder[NoticeResponse] { def apply(p: Packet): Result[NoticeResponse] = readErrorNoticePacket(p) .leftMap(t => new PacketDecodingFailure(t.getMessage)) .flatMap(xs => PostgresqlMessage(xs).fold( - {l => Xor.Left(l)}, - {r => Xor.Right(new NoticeResponse(r))} + {l => Left(l)}, + {r => Right(new NoticeResponse(r))} )) } - implicit val errorMessagePacketDecoder: PacketDecoder[ErrorResponse] = + implicit val errorMessagePacketDecoder: PacketDecoder[ErrorResponse] = new PacketDecoder[ErrorResponse] { def apply(p: Packet): Result[ErrorResponse] = readErrorNoticePacket(p) .leftMap(t => new PacketDecodingFailure(t.getMessage)) .flatMap(xs => PostgresqlMessage(xs).fold( - {l => Xor.Left(l)}, - {r => Xor.Right(new ErrorResponse(r))} + {l => Left(l)}, + {r => Right(new ErrorResponse(r))} )) } - implicit val commandCompletePacketDecoder: PacketDecoder[CommandComplete] = + implicit val commandCompletePacketDecoder: PacketDecoder[CommandComplete] = new PacketDecoder[CommandComplete] { - def apply(p: Packet): Result[CommandComplete] = Xor.catchNonFatal({ + def apply(p: Packet): Result[CommandComplete] = Either.catchNonFatal({ val br = BufferReader(p.body) val commandTag = br.readNullTerminatedString() new CommandComplete(commandTag) @@ -65,7 +65,7 @@ private[postgresql] trait PacketDecoderImplicits { implicit val parameterStatusPacketDecoder: PacketDecoder[ParameterStatus] = new PacketDecoder[ParameterStatus] { - def apply(p: Packet): Result[ParameterStatus] = Xor.catchNonFatal({ + def apply(p: Packet): Result[ParameterStatus] = Either.catchNonFatal({ val br = BufferReader(p.body) val param = br.readNullTerminatedString() val value = br.readNullTerminatedString() @@ -73,9 +73,9 @@ private[postgresql] trait PacketDecoderImplicits { }).leftMap(t => new PacketDecodingFailure(t.getMessage)) } - implicit val backendKeyDataPacketDecoder: PacketDecoder[BackendKeyData] = + implicit val backendKeyDataPacketDecoder: PacketDecoder[BackendKeyData] = new PacketDecoder[BackendKeyData] { - def apply(p: Packet): Result[BackendKeyData] = Xor.catchNonFatal({ + def apply(p: Packet): Result[BackendKeyData] = Either.catchNonFatal({ val br = BufferReader(p.body) val processId = br.readInt val secretKey = br.readInt @@ -85,7 +85,7 @@ private[postgresql] trait PacketDecoderImplicits { implicit val readyForQueryPacketDecoder: PacketDecoder[ReadyForQuery] = new PacketDecoder[ReadyForQuery] { - def apply(p: Packet): Result[ReadyForQuery] = Xor.catchNonFatal({ + def apply(p: Packet): Result[ReadyForQuery] = Either.catchNonFatal({ val br = BufferReader(p.body) val byte = br.readByte byte.toChar @@ -94,9 +94,9 @@ private[postgresql] trait PacketDecoderImplicits { .flatMap(ReadyForQuery(_)) } - implicit val rowDescriptionPacketDecoder: PacketDecoder[RowDescription] = + implicit val rowDescriptionPacketDecoder: PacketDecoder[RowDescription] = new PacketDecoder[RowDescription] { - def apply(p: Packet): Result[RowDescription] = Xor.catchNonFatal({ + def apply(p: Packet): Result[RowDescription] = Either.catchNonFatal({ val br = BufferReader(p.body) val numFields = br.readShort @@ -106,7 +106,7 @@ private[postgresql] trait PacketDecoderImplicits { case x if x < numFields => { val name = br.readNullTerminatedString() val tableObjectId = br.readInt - val tableAttributeId = br.readShort + val tableAttributeId = br.readShort val dataTypeObjectId = br.readInt val dataTypeSize = br.readShort val typeModifier = br.readInt @@ -129,12 +129,12 @@ private[postgresql] trait PacketDecoderImplicits { } implicit val dataRowPacketDecoder: PacketDecoder[DataRow] = new PacketDecoder[DataRow] { - def apply(p: Packet): Result[DataRow] = Xor.catchNonFatal({ + def apply(p: Packet): Result[DataRow] = Either.catchNonFatal({ val br = BufferReader(p.body) val columns = br.readShort @annotation.tailrec - def loop(idx: Short, cbs: List[Option[Array[Byte]]]): List[Option[Array[Byte]]] = + def loop(idx: Short, cbs: List[Option[Array[Byte]]]): List[Option[Array[Byte]]] = idx match { case x if x < columns => { val columnLength = br.readInt @@ -155,9 +155,9 @@ private[postgresql] trait PacketDecoderImplicits { }).leftMap(t => new PacketDecodingFailure(t.getMessage)) } - implicit val authenticationMessagePacketDecoder: PacketDecoder[AuthenticationMessage] = + implicit val authenticationMessagePacketDecoder: PacketDecoder[AuthenticationMessage] = new PacketDecoder[AuthenticationMessage] { - def apply(p: Packet): Result[AuthenticationMessage] = Xor.catchNonFatal({ + def apply(p: Packet): Result[AuthenticationMessage] = Either.catchNonFatal({ val br = BufferReader(p.body) br.readInt match { case 0 => (0, None) diff --git a/core/src/test/scala/roc/postgresql/server/ErrorNoticeGenerator.scala b/core/src/test/scala/roc/postgresql/server/ErrorNoticeGenerator.scala index d6b1d4b..a700748 100644 --- a/core/src/test/scala/roc/postgresql/server/ErrorNoticeGenerator.scala +++ b/core/src/test/scala/roc/postgresql/server/ErrorNoticeGenerator.scala @@ -2,7 +2,7 @@ package roc package postgresql package server -import cats.std.all._ +import cats.instances.all._ import cats.syntax.eq._ import org.scalacheck.Arbitrary.arbitrary import org.scalacheck.Prop.forAll diff --git a/core/src/test/scala/roc/postgresql/server/PostgresqlMessageSpec.scala b/core/src/test/scala/roc/postgresql/server/PostgresqlMessageSpec.scala index 78483fd..7398d02 100644 --- a/core/src/test/scala/roc/postgresql/server/PostgresqlMessageSpec.scala +++ b/core/src/test/scala/roc/postgresql/server/PostgresqlMessageSpec.scala @@ -3,9 +3,9 @@ package postgresql package server import cats.data.Validated._ -import cats.data.{NonEmptyList, Validated, Xor} +import cats.data.{NonEmptyList, Validated} import cats.Semigroup -import cats.std.all._ +import cats.instances.all._ import cats.syntax.eq._ import org.scalacheck.Arbitrary.arbitrary import org.scalacheck.Prop.forAll @@ -20,10 +20,10 @@ final class PostgresqlMessageSpec extends Specification with ScalaCheck { def is PostgresqlMessage must extract the value of a tuple by the Code ${PE().testExtractValueByCode} - must return Xor.Right(UnknownMessage(ErrorParams)) when given unknown SQLSTATE Code ${PE().testUnknownMessage} - must return Xor.Right(SuccesfulMessage) when given a valid Succes Code ${PE().testSuccessfulMessage} - must return Xor.Right(WarningMessage(ErrorParams)) when given a Warning Code ${PE().testWarningMessages} - must return Xor.Right(ErrorMessage(ErrorParams)) when given an Error Code ${PE().testErrorMessages} + must return Right(UnknownMessage(ErrorParams)) when given unknown SQLSTATE Code ${PE().testUnknownMessage} + must return Right(SuccesfulMessage) when given a valid Succes Code ${PE().testSuccessfulMessage} + must return Right(WarningMessage(ErrorParams)) when given a Warning Code ${PE().testWarningMessages} + must return Right(ErrorMessage(ErrorParams)) when given an Error Code ${PE().testErrorMessages} ValidatePacket must return RequiredParams when fields are valid ${VP().testAllValid} @@ -36,8 +36,8 @@ final class PostgresqlMessageSpec extends Specification with ScalaCheck { def is must return Invalid when Severity & SQLSTATE Code & Message are not present ${VP().testInvalidAll} BuildParamsFromTuples - must return Xor.Right(ErrorParams) when given valid Fields ${BPFT().testValidFields} - must return Xor.Left(PostgresqlMessageDecodingFailure) when given invalid Fields ${BPFT().testInvalidFields} + must return Right(ErrorParams) when given valid Fields ${BPFT().testValidFields} + must return Left(PostgresqlMessageDecodingFailure) when given invalid Fields ${BPFT().testInvalidFields} must have correct Error Message when Severity is invalid ${BPFT().testSeverityMessage} must have correct Error Message when SQLSTATECode is invalid ${BPFT().testSqlStateCodeMessage} must have correct Error Message when Message is invalid ${BPFT().testMessageMessage} @@ -78,19 +78,19 @@ final class PostgresqlMessageSpec extends Specification with ScalaCheck { def is } val testUnknownMessage = forAll(unknownErrorGen) { x: FieldsAndErrorParams => - PostgresqlMessage(x.fields) must_== Xor.Right(UnknownMessage(x.errorParams)) + PostgresqlMessage(x.fields) must_== Right(UnknownMessage(x.errorParams)) } val testSuccessfulMessage = forAll(successfulMessageGen) { x: FieldsAndErrorParams => - PostgresqlMessage(x.fields) must_== Xor.Right(SuccessMessage(x.errorParams)) + PostgresqlMessage(x.fields) must_== Right(SuccessMessage(x.errorParams)) } val testWarningMessages = forAll(warningMessageGen) { x: FieldsAndErrorParams => - PostgresqlMessage(x.fields) must_== Xor.Right(WarningMessage(x.errorParams)) + PostgresqlMessage(x.fields) must_== Right(WarningMessage(x.errorParams)) } val testErrorMessages = forAll(errorMessageGen) { x: FieldsAndErrorParams => - PostgresqlMessage(x.fields) must_== Xor.Right(ErrorMessage(x.errorParams)) + PostgresqlMessage(x.fields) must_== Right(ErrorMessage(x.errorParams)) } } @@ -135,7 +135,7 @@ final class PostgresqlMessageSpec extends Specification with ScalaCheck { def is actual must_== expected } - val testInvalidMessage = forAll(invalidMessageFieldsGen) { xs: Fields => + val testInvalidMessage = forAll(invalidMessageFieldsGen) { xs: Fields => val severity = extractSeverity(xs) val code = extractCode(xs) val message = extractMessage(xs) @@ -173,7 +173,7 @@ final class PostgresqlMessageSpec extends Specification with ScalaCheck { def is actual must_== expected } - + val testInvalidSeveritySqlStateCode = forAll(invalidSeveritySqlStateCodeFieldsGen) { xs: Fields => val severity = extractSeverity(xs) val code = extractCode(xs) @@ -200,17 +200,17 @@ final class PostgresqlMessageSpec extends Specification with ScalaCheck { def is actual must_== expected } - private def extractSeverity(xs: Fields): Validated[String, String] = + private def extractSeverity(xs: Fields): Validated[String, String] = xs.find(_._1 === Severity) match { case Some(x) => Valid(x._2) case None => Invalid("Required Severity Level was not present.") } - private def extractCode(xs: Fields): Validated[String, String] = + private def extractCode(xs: Fields): Validated[String, String] = xs.find(_._1 === ErrorNoticeMessageFields.Code) match { case Some(x) => Valid(x._2) case None => Invalid("Required SQLSTATE Code was not present.") } - private def extractMessage(xs: Fields): Validated[String, String] = + private def extractMessage(xs: Fields): Validated[String, String] = xs.find(_._1 === Message) match { case Some(x) => Valid(x._2) case None => Invalid("Required Message was not present.") @@ -224,7 +224,7 @@ final class PostgresqlMessageSpec extends Specification with ScalaCheck { def is case (Invalid(e1), Invalid(e2), Valid(_)) => Invalid(Semigroup[E].combine(e1, e2)) case (Invalid(e1), Valid(_), Invalid(e2)) => Invalid(Semigroup[E].combine(e1, e2)) case (Valid(_), Invalid(e1), Invalid(e2)) => Invalid(Semigroup[E].combine(e1, e2)) - case (Invalid(e1), Invalid(e2), Invalid(e3)) => + case (Invalid(e1), Invalid(e2), Invalid(e3)) => Invalid(Semigroup[E].combine(e1, Semigroup[E].combine(e2, e3))) } } @@ -243,21 +243,21 @@ final class PostgresqlMessageSpec extends Specification with ScalaCheck { def is val xs = List((ErrorNoticeMessageFields.Code, "Foo"), (Message, "Bar")) val actual = PostgresqlMessage.buildParamsFromTuples(xs) val nel = NonEmptyList("Required Severity Level was not present.") - actual must_== Xor.Left(new PostgresqlMessageDecodingFailure(nel)) + actual must_== Left(new PostgresqlMessageDecodingFailure(nel)) } val testSqlStateCodeMessage = { val xs = List((Severity, "Foo"), (Message, "Bar")) val actual = PostgresqlMessage.buildParamsFromTuples(xs) val nel = NonEmptyList("Required SQLSTATE Code was not present.") - actual must_== Xor.Left(new PostgresqlMessageDecodingFailure(nel)) + actual must_== Left(new PostgresqlMessageDecodingFailure(nel)) } val testMessageMessage = { val xs = List((Severity, "Foo"), (ErrorNoticeMessageFields.Code, "Bar")) val actual = PostgresqlMessage.buildParamsFromTuples(xs) val nel = NonEmptyList("Required Message was not present.") - actual must_== Xor.Left(new PostgresqlMessageDecodingFailure(nel)) + actual must_== Left(new PostgresqlMessageDecodingFailure(nel)) } val testSeveritySqlStateCodeMessage = { @@ -265,7 +265,7 @@ final class PostgresqlMessageSpec extends Specification with ScalaCheck { def is val actual = PostgresqlMessage.buildParamsFromTuples(xs) val nel = NonEmptyList("Required Severity Level was not present.", "Required SQLSTATE Code was not present.") - actual must_== Xor.Left(new PostgresqlMessageDecodingFailure(nel)) + actual must_== Left(new PostgresqlMessageDecodingFailure(nel)) } val testSeverityMessageMessage = { @@ -273,7 +273,7 @@ final class PostgresqlMessageSpec extends Specification with ScalaCheck { def is val actual = PostgresqlMessage.buildParamsFromTuples(xs) val nel = NonEmptyList("Required Severity Level was not present.", "Required Message was not present.") - actual must_== Xor.Left(new PostgresqlMessageDecodingFailure(nel)) + actual must_== Left(new PostgresqlMessageDecodingFailure(nel)) } val testSqlStateCodeMessageMessage = { @@ -281,7 +281,7 @@ final class PostgresqlMessageSpec extends Specification with ScalaCheck { def is val actual = PostgresqlMessage.buildParamsFromTuples(xs) val nel = NonEmptyList("Required SQLSTATE Code was not present.", "Required Message was not present.") - actual must_== Xor.Left(new PostgresqlMessageDecodingFailure(nel)) + actual must_== Left(new PostgresqlMessageDecodingFailure(nel)) } val testNoRequiredFieldsFoundMessage = { @@ -290,7 +290,7 @@ final class PostgresqlMessageSpec extends Specification with ScalaCheck { def is val nel = NonEmptyList("Required Severity Level was not present.", "Required SQLSTATE Code was not present.", "Required Message was not present.") - actual must_== Xor.Left(new PostgresqlMessageDecodingFailure(nel)) + actual must_== Left(new PostgresqlMessageDecodingFailure(nel)) } } @@ -384,7 +384,7 @@ final class PostgresqlMessageSpec extends Specification with ScalaCheck { def is case class EM() extends ErrorNoticeGen { val testMessage = forAll(errMsgAndRequiredFieldsGen) { x: ErrorMessageAndRequiredFields => val xs = x.errorParams - val ys = List(("Detail: ", xs.detail), ("Hint: ", xs.hint), ("Position: ", xs.position), + val ys = List(("Detail: ", xs.detail), ("Hint: ", xs.hint), ("Position: ", xs.position), ("Internal Position: ", xs.internalPosition), ("Internal Query: ", xs.internalQuery), ("Where: ", xs.where), ("Schema Name: ", xs.schemaName), ("Table Name: ", xs.tableName), ("Column Name: ", xs.columnName), ("Data Type Name: ", xs.dataTypeName), ("Constaint Name: ", @@ -394,9 +394,8 @@ final class PostgresqlMessageSpec extends Specification with ScalaCheck { def is .foldLeft("")((x,y) => x + y._1 + y._2 + "\n") val requiredString = s"${xs.severity} - ${xs.message}. SQLSTATE: ${xs.code}." - val expectedMessage = requiredString + "\n" + optString + val expectedMessage = requiredString + "\n" + optString x.error.toString must_== expectedMessage } } } - diff --git a/core/src/test/scala/roc/postgresql/transport/PacketDecodersSpec.scala b/core/src/test/scala/roc/postgresql/transport/PacketDecodersSpec.scala index 38bbe1e..be397e9 100644 --- a/core/src/test/scala/roc/postgresql/transport/PacketDecodersSpec.scala +++ b/core/src/test/scala/roc/postgresql/transport/PacketDecodersSpec.scala @@ -2,7 +2,6 @@ package roc package postgresql package transport -import cats.data.Xor import java.nio.charset.StandardCharsets import org.scalacheck.Arbitrary.arbitrary import org.scalacheck.Prop.forAll @@ -17,52 +16,52 @@ import roc.postgresql.server.PostgresqlMessage final class PacketDecodersSpec extends Specification with ScalaCheck { def is = s2""" ErrorMessage - must return Xor.Right(ErrorMessage(PostgresqlMessage)) when given a valid Packet ${ErrorMsg().test} - must return Xor.Left(ErrorResponseDecodingFailure) when given an invalid Error Message ${ErrorMsg().testInvalid} - must return Xor.Left(PacketDecodingFailure) when given an invalid Packet ${ErrorMsg().testInvalidPacket} + must return Right(ErrorMessage(PostgresqlMessage)) when given a valid Packet ${ErrorMsg().test} + must return Left(ErrorResponseDecodingFailure) when given an invalid Error Message ${ErrorMsg().testInvalid} + must return Left(PacketDecodingFailure) when given an invalid Packet ${ErrorMsg().testInvalidPacket} CommandComplete - should return Xor.Right(CommandComplete) when given a valid Packet ${CmdComplete().test} - should return Xor.Left(PacketDecodingFailure) when given an invalid Packet ${CmdComplete().testInvalidPacket} + should return Right(CommandComplete) when given a valid Packet ${CmdComplete().test} + should return Left(PacketDecodingFailure) when given an invalid Packet ${CmdComplete().testInvalidPacket} ParameterStatus - should return Xor.Right(ParameterStatus) when given a valid Packet ${ParamStatus().test} - should return Xor.Left(PacketDecodingFailure) when given an invalid Packet ${ParamStatus().testInvalidPacket} + should return Right(ParameterStatus) when given a valid Packet ${ParamStatus().test} + should return Left(PacketDecodingFailure) when given an invalid Packet ${ParamStatus().testInvalidPacket} BackendKeyData - should return Xor.Right(BackendKeyData) when given a valid Packet ${BackendKey().test} - should return Xor.Left(PacketDecodingFailure) when given an invalid Packet ${BackendKey().testInvalidPacket} + should return Right(BackendKeyData) when given a valid Packet ${BackendKey().test} + should return Left(PacketDecodingFailure) when given an invalid Packet ${BackendKey().testInvalidPacket} ReadyForQuery - should return Xor.Right(ReadyForQuery) when given a valid Char ${RFQ().testValid} - should return Xor.Left(ReadyForQueryDecodingFailure) when given an invalid Char ${RFQ().testInvalidChar} - should return Xor.Left(PacketDecodingFailure) when given an invalid Packet ${RFQ().testInvalidPacket} + should return Right(ReadyForQuery) when given a valid Char ${RFQ().testValid} + should return Left(ReadyForQueryDecodingFailure) when given an invalid Char ${RFQ().testInvalidChar} + should return Left(PacketDecodingFailure) when given an invalid Packet ${RFQ().testInvalidPacket} DataRow - should return Xor.Right(DataRow) when given a valid Packet ${DR().testValidPacket} - should return Xor.Left(PacketDecodingFailure) when given an invalid Packet ${DR().testInvalidPacket} + should return Right(DataRow) when given a valid Packet ${DR().testValidPacket} + should return Left(PacketDecodingFailure) when given an invalid Packet ${DR().testInvalidPacket} RowDescription - should return Xor.Right(RowDescription) when given a valid Packet ${RD().testValidPacket} - should return Xor.Left(PacketDecodingFailure) when given an invalid Packet ${RD().testInvalidPacket} + should return Right(RowDescription) when given a valid Packet ${RD().testValidPacket} + should return Left(PacketDecodingFailure) when given an invalid Packet ${RD().testInvalidPacket} should have valid Message when decoding an unknown Format Code ${RD().testUnknownFormatCode} AuthenticationMessages - should return Xor.Right(AuthenticationMessage) when given a valid Message Int ${AM().testValid} - should return Xor.Left(UnknownAuthenticationRequestFailure) when given an Unknown Request ${AM().testUnkownRequestType} - should return Xor.Left(PacketDecodingFailure) when given an invalid Packet ${AM().testInvalidPacket} + should return Right(AuthenticationMessage) when given a valid Message Int ${AM().testValid} + should return Left(UnknownAuthenticationRequestFailure) when given an Unknown Request ${AM().testUnkownRequestType} + should return Left(PacketDecodingFailure) when given an invalid Packet ${AM().testInvalidPacket} NoticeResponseMessage - must return Xor.Right(NoticeResponse(PostgresqlMessage)) when given a valid Packet ${NR().test} - must return Xor.Left(PostgresqlMessageDecodingFailure) when given an invalid PostgresqlMessage ${NR().testInvalid} - must return Xor.Left(PacketDecodingFailure) when given an invalid Packet ${NR().testInvalidPacket} + must return Right(NoticeResponse(PostgresqlMessage)) when given a valid Packet ${NR().test} + must return Left(PostgresqlMessageDecodingFailure) when given an invalid PostgresqlMessage ${NR().testInvalid} + must return Left(PacketDecodingFailure) when given an invalid Packet ${NR().testInvalidPacket} """ case class ErrorMsg() extends generators.ErrorNoticePacketGen { val test = forAll(validErrorPacketContainerGen) { c: ErrorNoticePacketContainer => val message = PostgresqlMessage(c.fields).getOrElse(throw new Exception("Generator Failed")) - decodePacket[ErrorResponse](c.packet) must_== Xor.Right(ErrorResponse(message)) + decodePacket[ErrorResponse](c.packet) must_== Right(ErrorResponse(message)) } val testInvalid = forAll(invalidErrorPacketContainerGen) { c: ErrorNoticePacketContainer => @@ -72,100 +71,100 @@ final class PacketDecodersSpec extends Specification with ScalaCheck { def is = val testInvalidPacket = { val packet = Packet(Some(Message.CommandCompleteByte), Buffer(Array.empty[Byte])) decodePacket[ErrorResponse](packet) must_== - Xor.Left(new PacketDecodingFailure("Readable byte limit exceeded: 0")) + Left(new PacketDecodingFailure("Readable byte limit exceeded: 0")) } } case class CmdComplete() extends generators.CommandCompleteGen { val test = forAll(commandCompleteValidPacket) { (c: CommandCompleteContainer) => - decodePacket[CommandComplete](c.p) must_== Xor.Right(new CommandComplete(c.str)) + decodePacket[CommandComplete](c.p) must_== Right(new CommandComplete(c.str)) } def testInvalidPacket = { val packet = Packet(Some(Message.CommandCompleteByte), Buffer(Array.empty[Byte])) - decodePacket[CommandComplete](packet) must_== - Xor.Left(new PacketDecodingFailure("Readable byte limit exceeded: 0")) + decodePacket[CommandComplete](packet) must_== + Left(new PacketDecodingFailure("Readable byte limit exceeded: 0")) } } case class ParamStatus() extends generators.ParameterStatusGen { val test = forAll { (psc: ParameterStatusContainer) => - decodePacket[ParameterStatus](psc.packet) must_== - Xor.Right(new ParameterStatus(psc.param, psc.value)) + decodePacket[ParameterStatus](psc.packet) must_== + Right(new ParameterStatus(psc.param, psc.value)) } def testInvalidPacket = { val packet = Packet(Some(Message.CommandCompleteByte), Buffer(Array.empty[Byte])) - decodePacket[ParameterStatus](packet) must_== - Xor.Left(new PacketDecodingFailure("Readable byte limit exceeded: 0")) + decodePacket[ParameterStatus](packet) must_== + Left(new PacketDecodingFailure("Readable byte limit exceeded: 0")) } } case class BackendKey() extends generators.BackendKeyGen { val test = forAll { (bkdc: BackendKeyDataContainer) => val backendKeyData = new BackendKeyData(bkdc.processId, bkdc.secretKey) - decodePacket[BackendKeyData](bkdc.packet) must_== Xor.Right(backendKeyData) + decodePacket[BackendKeyData](bkdc.packet) must_== Right(backendKeyData) } def testInvalidPacket = { val packet = Packet(Some(Message.CommandCompleteByte), Buffer(Array.empty[Byte])) - decodePacket[BackendKeyData](packet) must_== - Xor.Left(new PacketDecodingFailure("Not enough readable bytes - Need 4, maximum is 0")) + decodePacket[BackendKeyData](packet) must_== + Left(new PacketDecodingFailure("Not enough readable bytes - Need 4, maximum is 0")) } } case class RFQ() extends generators.ReadyForQueryGen { val testValid = forAll(genValidReadyForQueryContainer) { (rfqc: ReadyForQueryContainer) => rfqc.transactionStatus match { - case 'I' => decodePacket[ReadyForQuery](rfqc.packet) must_== Xor.Right(Idle) - case 'T' => decodePacket[ReadyForQuery](rfqc.packet) must_== Xor.Right(TransactionBlock) - case 'E' => decodePacket[ReadyForQuery](rfqc.packet) must_== Xor.Right(FailedTransactionBlock) + case 'I' => decodePacket[ReadyForQuery](rfqc.packet) must_== Right(Idle) + case 'T' => decodePacket[ReadyForQuery](rfqc.packet) must_== Right(TransactionBlock) + case 'E' => decodePacket[ReadyForQuery](rfqc.packet) must_== Right(FailedTransactionBlock) case c => decodePacket[ReadyForQuery](rfqc.packet) must_== - Xor.Left(new ReadyForQueryDecodingFailure(rfqc.transactionStatus)) + Left(new ReadyForQueryDecodingFailure(rfqc.transactionStatus)) } } val testInvalidChar = forAll(genInvalidReadyForQueryContainer) { (rfqc: ReadyForQueryContainer) => decodePacket[ReadyForQuery](rfqc.packet) must_== - Xor.Left(new ReadyForQueryDecodingFailure(rfqc.transactionStatus)) + Left(new ReadyForQueryDecodingFailure(rfqc.transactionStatus)) } def testInvalidPacket = { val packet = Packet(Some(Message.ReadyForQueryByte), Buffer(Array.empty[Byte])) - decodePacket[ReadyForQuery](packet) must_== - Xor.Left(new PacketDecodingFailure("Readable byte limit exceeded: 0")) + decodePacket[ReadyForQuery](packet) must_== + Left(new PacketDecodingFailure("Readable byte limit exceeded: 0")) } } case class RD() extends generators.RowDescriptionGen { - val testValidPacket = forAll(genRowDescriptionContainer) { (rdc: RowDescriptionContainer) => + val testValidPacket = forAll(genRowDescriptionContainer) { (rdc: RowDescriptionContainer) => decodePacket[RowDescription](rdc.packet) must_== - Xor.Right(RowDescription(rdc.numFields, rdc.fields)) + Right(RowDescription(rdc.numFields, rdc.fields)) } def testInvalidPacket = { val packet = Packet(Some(Message.RowDescriptionByte), Buffer(Array.empty[Byte])) decodePacket[RowDescription](packet) must_== - Xor.Left(new PacketDecodingFailure("Not enough readable bytes - Need 2, maximum is 0")) + Left(new PacketDecodingFailure("Not enough readable bytes - Need 2, maximum is 0")) } - val testUnknownFormatCode = + val testUnknownFormatCode = forAll(genUnknownFormatCodeRDC) { (rdc: RowDescriptionFormatCodeContainer) => decodePacket[RowDescription](rdc.packet) must_== - Xor.Left(new PacketDecodingFailure(s"Unknown format code ${rdc.formatCode}.")) + Left(new PacketDecodingFailure(s"Unknown format code ${rdc.formatCode}.")) } } case class DR() extends generators.DataRowGen { val testValidPacket = forAll { x: DataRowContainer => - decodePacket[DataRow](x.packet) must_== Xor.Right(x.dataRow) + decodePacket[DataRow](x.packet) must_== Right(x.dataRow) } - val testInvalidPacket = { + val testInvalidPacket = { val packet = Packet(Some(Message.DataRowByte), Buffer(Array.empty[Byte])) decodePacket[DataRow](packet) must_== - Xor.Left(new PacketDecodingFailure("Not enough readable bytes - Need 2, maximum is 0")) + Left(new PacketDecodingFailure("Not enough readable bytes - Need 2, maximum is 0")) } } @@ -173,36 +172,36 @@ final class PacketDecodersSpec extends Specification with ScalaCheck { def is = val testValid = forAll(genValidAuthMessageContainer) { (amc: AuthMessageContainer) => val decodedPacket = decodePacket[AuthenticationMessage](amc.packet) amc.requestType match { - case 0 => decodedPacket must_== Xor.Right(AuthenticationOk) - case 2 => decodedPacket must_== Xor.Right(AuthenticationKerberosV5) - case 3 => decodedPacket must_== Xor.Right(AuthenticationClearTxtPasswd) - case 5 => decodedPacket must_== Xor.Right(new AuthenticationMD5Passwd(amc.salt)) - case 6 => decodedPacket must_== Xor.Right(AuthenticationSCMCredential) - case 7 => decodedPacket must_== Xor.Right(AuthenticationGSS) - case 8 => decodedPacket must_== Xor.Right(new AuthenticationGSSContinue(amc.authBytes)) - case 9 => decodedPacket must_== Xor.Right(AuthenticationSSPI) + case 0 => decodedPacket must_== Right(AuthenticationOk) + case 2 => decodedPacket must_== Right(AuthenticationKerberosV5) + case 3 => decodedPacket must_== Right(AuthenticationClearTxtPasswd) + case 5 => decodedPacket must_== Right(new AuthenticationMD5Passwd(amc.salt)) + case 6 => decodedPacket must_== Right(AuthenticationSCMCredential) + case 7 => decodedPacket must_== Right(AuthenticationGSS) + case 8 => decodedPacket must_== Right(new AuthenticationGSSContinue(amc.authBytes)) + case 9 => decodedPacket must_== Right(AuthenticationSSPI) case x => decodedPacket must_== - Xor.Left(new UnknownAuthenticationRequestFailure(amc.requestType)) + Left(new UnknownAuthenticationRequestFailure(amc.requestType)) } } - val testUnkownRequestType = + val testUnkownRequestType = forAll(genUnknownReqAuthMessageContainer) { (amc: AuthMessageContainer) => decodePacket[AuthenticationMessage](amc.packet) must_== - Xor.Left(new UnknownAuthenticationRequestFailure(amc.requestType)) + Left(new UnknownAuthenticationRequestFailure(amc.requestType)) } val testInvalidPacket = { val packet = Packet(Some(Message.AuthenticationMessageByte), Buffer(Array.empty[Byte])) decodePacket[AuthenticationMessage](packet) must_== - Xor.Left(new PacketDecodingFailure("Not enough readable bytes - Need 4, maximum is 0")) + Left(new PacketDecodingFailure("Not enough readable bytes - Need 4, maximum is 0")) } } case class NR() extends generators.ErrorNoticePacketGen { val test = forAll(validNoticePacketContainerGen) { c: ErrorNoticePacketContainer => val message = PostgresqlMessage(c.fields).getOrElse(throw new Exception("Generator Failed")) - decodePacket[NoticeResponse](c.packet) must_== Xor.Right(NoticeResponse(message)) + decodePacket[NoticeResponse](c.packet) must_== Right(NoticeResponse(message)) } val testInvalid = forAll(invalidNoticePacketContainerGen) { c: ErrorNoticePacketContainer => @@ -212,7 +211,7 @@ final class PacketDecodersSpec extends Specification with ScalaCheck { def is = val testInvalidPacket = { val packet = Packet(Some(Message.NoticeResponseByte), Buffer(Array.empty[Byte])) decodePacket[NoticeResponse](packet) must_== - Xor.Left(new PacketDecodingFailure("Readable byte limit exceeded: 0")) + Left(new PacketDecodingFailure("Readable byte limit exceeded: 0")) } } } diff --git a/types/src/main/scala/roc/types/decoders.scala b/types/src/main/scala/roc/types/decoders.scala index d23c8dd..2928021 100644 --- a/types/src/main/scala/roc/types/decoders.scala +++ b/types/src/main/scala/roc/types/decoders.scala @@ -1,7 +1,7 @@ package roc package types -import cats.data.{Validated, Xor} +import cats.data.Validated import io.netty.buffer.Unpooled import java.nio.ByteBuffer import java.nio.charset.StandardCharsets @@ -14,7 +14,7 @@ import roc.types.failures._ object decoders { - implicit def optionElementDecoder[A](implicit f: ElementDecoder[A]) = + implicit def optionElementDecoder[A](implicit f: ElementDecoder[A]) = new ElementDecoder[Option[A]] { def textDecoder(text: String): Option[A] = Some(f.textDecoder(text)) def binaryDecoder(bytes: Array[Byte]): Option[A] = Some(f.binaryDecoder(bytes)) @@ -28,13 +28,13 @@ object decoders { } implicit val shortElementDecoder: ElementDecoder[Short] = new ElementDecoder[Short] { - def textDecoder(text: String): Short = Xor.catchNonFatal( + def textDecoder(text: String): Short = Either.catchNonFatal( text.toShort ).fold( {l => throw new ElementDecodingFailure("SHORT", l)}, {r => r} ) - def binaryDecoder(bytes: Array[Byte]): Short = Xor.catchNonFatal({ + def binaryDecoder(bytes: Array[Byte]): Short = Either.catchNonFatal({ val buffer = Unpooled.buffer(2) buffer.writeBytes(bytes.take(2)) buffer.readShort @@ -46,13 +46,13 @@ object decoders { } implicit val intElementDecoder: ElementDecoder[Int] = new ElementDecoder[Int] { - def textDecoder(text: String): Int = Xor.catchNonFatal( + def textDecoder(text: String): Int = Either.catchNonFatal( text.toInt ).fold( {l => throw new ElementDecodingFailure("INT", l)}, {r => r} ) - def binaryDecoder(bytes: Array[Byte]): Int = Xor.catchNonFatal({ + def binaryDecoder(bytes: Array[Byte]): Int = Either.catchNonFatal({ val buffer = Unpooled.buffer(4) buffer.writeBytes(bytes.take(4)) buffer.readInt @@ -64,13 +64,13 @@ object decoders { } implicit val longElementDecoder: ElementDecoder[Long] = new ElementDecoder[Long] { - def textDecoder(text: String): Long = Xor.catchNonFatal( + def textDecoder(text: String): Long = Either.catchNonFatal( text.toLong ).fold( {l => throw new ElementDecodingFailure("LONG", l)}, {r => r} ) - def binaryDecoder(bytes: Array[Byte]): Long = Xor.catchNonFatal({ + def binaryDecoder(bytes: Array[Byte]): Long = Either.catchNonFatal({ val buffer = Unpooled.buffer(8) buffer.writeBytes(bytes.take(8)) buffer.readLong @@ -78,17 +78,17 @@ object decoders { {l => throw new ElementDecodingFailure("LONG", l)}, {r => r} ) - def nullDecoder: Long = throw new NullDecodedFailure("LONG") + def nullDecoder: Long = throw new NullDecodedFailure("LONG") } implicit val floatElementDecoder: ElementDecoder[Float] = new ElementDecoder[Float] { - def textDecoder(text: String): Float = Xor.catchNonFatal( + def textDecoder(text: String): Float = Either.catchNonFatal( text.toFloat ).fold( {l => throw new ElementDecodingFailure("FLOAT", l)}, {r => r} ) - def binaryDecoder(bytes: Array[Byte]): Float = Xor.catchNonFatal({ + def binaryDecoder(bytes: Array[Byte]): Float = Either.catchNonFatal({ val buffer = Unpooled.buffer(4) buffer.writeBytes(bytes.take(4)) buffer.readFloat @@ -100,13 +100,13 @@ object decoders { } implicit val doubleElementDecoder: ElementDecoder[Double] = new ElementDecoder[Double] { - def textDecoder(text: String): Double = Xor.catchNonFatal( + def textDecoder(text: String): Double = Either.catchNonFatal( text.toDouble ).fold( {l => throw new ElementDecodingFailure("DOUBLE", l)}, {r => r} ) - def binaryDecoder(bytes: Array[Byte]): Double = Xor.catchNonFatal({ + def binaryDecoder(bytes: Array[Byte]): Double = Either.catchNonFatal({ val buffer = Unpooled.buffer(8) buffer.writeBytes(bytes.take(8)) buffer.readDouble @@ -118,14 +118,14 @@ object decoders { } implicit val booleanElementDecoder: ElementDecoder[Boolean] = new ElementDecoder[Boolean] { - def textDecoder(text: String): Boolean = Xor.catchNonFatal(text.head match { + def textDecoder(text: String): Boolean = Either.catchNonFatal(text.head match { case 't' => true case 'f' => false }).fold( {l => throw new ElementDecodingFailure("BOOLEAN", l)}, {r => r} ) - def binaryDecoder(bytes: Array[Byte]): Boolean = Xor.catchNonFatal(bytes.head match { + def binaryDecoder(bytes: Array[Byte]): Boolean = Either.catchNonFatal(bytes.head match { case 0x00 => false case 0x01 => true }).fold( @@ -136,11 +136,11 @@ object decoders { } implicit val charElementDecoder: ElementDecoder[Char] = new ElementDecoder[Char] { - def textDecoder(text: String): Char = Xor.catchNonFatal(text.head.toChar).fold( + def textDecoder(text: String): Char = Either.catchNonFatal(text.head.toChar).fold( {l => throw new ElementDecodingFailure("CHAR", l)}, {r => r} ) - def binaryDecoder(bytes: Array[Byte]): Char = Xor.catchNonFatal(bytes.head.toChar).fold( + def binaryDecoder(bytes: Array[Byte]): Char = Either.catchNonFatal(bytes.head.toChar).fold( {l => throw new ElementDecodingFailure("CHAR", l)}, {r => r} ) @@ -165,11 +165,11 @@ object decoders { } implicit val dateElementDecoders: ElementDecoder[Date] = new ElementDecoder[Date] { - def textDecoder(text: String): Date = Xor.catchNonFatal(LocalDate.parse(text)).fold( + def textDecoder(text: String): Date = Either.catchNonFatal(LocalDate.parse(text)).fold( {l => throw new ElementDecodingFailure("DATE", l)}, {r => r} ) - def binaryDecoder(bytes: Array[Byte]): Date = Xor.catchNonFatal({ + def binaryDecoder(bytes: Array[Byte]): Date = Either.catchNonFatal({ val text = new String(bytes, StandardCharsets.UTF_8) LocalDate.parse(text) }).fold( @@ -180,11 +180,11 @@ object decoders { } implicit val localTimeElementDecoders: ElementDecoder[Time] = new ElementDecoder[Time] { - def textDecoder(text: String): Time = Xor.catchNonFatal(LocalTime.parse(text)).fold( + def textDecoder(text: String): Time = Either.catchNonFatal(LocalTime.parse(text)).fold( {l => throw new ElementDecodingFailure("TIME", l)}, {r => r} ) - def binaryDecoder(bytes: Array[Byte]): Time = Xor.catchNonFatal({ + def binaryDecoder(bytes: Array[Byte]): Time = Either.catchNonFatal({ val text = new String(bytes, StandardCharsets.UTF_8) LocalTime.parse(text) }).fold( @@ -194,20 +194,20 @@ object decoders { def nullDecoder: Time = throw new NullDecodedFailure("TIME") } - implicit val zonedDateTimeElementDecoders: ElementDecoder[TimestampWithTZ] = + implicit val zonedDateTimeElementDecoders: ElementDecoder[TimestampWithTZ] = new ElementDecoder[TimestampWithTZ] { private val zonedDateTimeFmt = new DateTimeFormatterBuilder() .appendPattern("yyyy-MM-dd HH:mm:ss") .appendFraction(ChronoField.MICRO_OF_SECOND, 0, 6, true) .appendOptional(DateTimeFormatter.ofPattern("X")) .toFormatter() - def textDecoder(text: String): TimestampWithTZ = Xor.catchNonFatal({ + def textDecoder(text: String): TimestampWithTZ = Either.catchNonFatal({ ZonedDateTime.parse(text, zonedDateTimeFmt) }).fold( {l => throw new ElementDecodingFailure("TIMESTAMP WITH TIME ZONE", l)}, {r => r} ) - def binaryDecoder(bytes: Array[Byte]): TimestampWithTZ = Xor.catchNonFatal({ + def binaryDecoder(bytes: Array[Byte]): TimestampWithTZ = Either.catchNonFatal({ val text = new String(bytes, StandardCharsets.UTF_8) ZonedDateTime.parse(text, zonedDateTimeFmt) }).fold(