diff --git a/modules/core/src/main/scala/io/github/irevive/union/derivation/UnionDerivation.scala b/modules/core/src/main/scala/io/github/irevive/union/derivation/UnionDerivation.scala index 0339dac..85cbb49 100644 --- a/modules/core/src/main/scala/io/github/irevive/union/derivation/UnionDerivation.scala +++ b/modules/core/src/main/scala/io/github/irevive/union/derivation/UnionDerivation.scala @@ -15,34 +15,132 @@ object UnionDerivation { import quotes.reflect.* def deriveImpl[A: Type]: Expr[F[A]] = { + given Diagnostic = Diagnostic(TypeRepr.of[F], TypeRepr.of[A].dealias) + val tpe: TypeRepr = TypeRepr.of[A] tpe.dealias match { case o: OrType => + val paramType = findParamType val abstractMethod = findAbstractMethod val collectedTypes = collectTypes(o) - val mt = MethodType(List("a"))(_ => List(tpe), _ => abstractMethod.returnTpt.tpe) + val params = collectParams(abstractMethod, paramType.tpe) + + val lambdaType = MethodType(params.map(_.name))( + _ => params.map(p => if (p.isPoly) tpe else p.typeRepr), + _ => abstractMethod.returnTpt.tpe + ) - val lambda = - Lambda(Symbol.spliceOwner, mt, (meth, arg) => body(arg.head.asExprOf[A], collectedTypes, abstractMethod.name)) + val lambda = Lambda( + Symbol.spliceOwner, + lambdaType, + (_, args) => body(collectedTypes, params, args, abstractMethod.name) + ) // transform lambda to an instance of the typeclass val instanceTree = lambda match { - case Block(body, Closure(meth, _)) => - Block(body, Closure(meth, Some(TypeRepr.of[F].appliedTo(tpe)))) + case Block(body, Closure(method, _)) => + Block(body, Closure(method, Some(TypeRepr.of[F].appliedTo(tpe)))) } instanceTree.asExprOf[F[A]] case other => - report.errorAndAbort(s"Cannot derive a typeclass for the ${tpe.show}. Only Union type is supported") + errorAndAbort("only Union type is supported.") } } + private final case class Diagnostic(typeclass: TypeRepr, targetType: TypeRepr) + + private final case class MethodParam( + name: String, + typeRepr: TypeRepr, + isPoly: Boolean // whether param appear in the polymorphic position, e.g. (a: A) + ) + + private def collectParams(method: DefDef, paramType: TypeRepr)(using Diagnostic): List[MethodParam] = + method.paramss match { + case TermParamClause(params) :: Nil => + val all = params.map { param => + MethodParam(param.name, param.tpt.tpe, param.tpt.tpe == paramType) + } + + val typed = all.filter(_.isPoly) + + if (typed.size == 1) { + all + } else if (typed.isEmpty) { + errorAndAbort( + "the abstract method without the polymorphic param isn't supported.", + Some( + """check the example below where the instance cannot be derived + | + |trait Typeclass[A] { + | def magic(a: Int): String + | // ^ + | // Polymorphic param of type A is missing + |}""".stripMargin + ) + ) + } else { + errorAndAbort( + s"the abstract method has multiple polymorphic params of the same parametrized type: ${typed.map(_.name).mkString(", ")}.", + Some("""check the example below where the instance cannot be derived + | + |trait Typeclass[A] { + | def magic(a1: A, b: Int, a2: A): String + | // ^ ^ + | // Polymorphic type A appears in two positions + |}""".stripMargin) + ) + } + + case Nil => + errorAndAbort( + "the abstract method without the polymorphic param isn't supported.", + Some( + """check the example below where the instance cannot be derived + | + |trait Typeclass[A] { + | def magic: String + | // ^ + | // Polymorphic param of type A is missing + |}""".stripMargin + ) + ) + + case _ => + errorAndAbort( + "the curried abstract method isn't supported.", + Some( + """check the example below where the instance cannot be derived + | + |trait Typeclass[A] { + | def magic(a: A)(b: Int): String + | // ^ + | // Curried functions aren't supported + |}""".stripMargin + ) + ) + } + + // required exactly one type param + private def findParamType(using Diagnostic): TypeTree = + TypeRepr.of[F].typeSymbol.declaredTypes match { + case head :: Nil => + TypeIdent(head) + + case Nil => + errorAndAbort("The typeclass doesn't have a type parameter") + + case _ => + errorAndAbort("The typeclass has multiple type parameters") + } + /** * Looks-up for an abstract method in F[_] */ - private def findAbstractMethod: DefDef = { + private def findAbstractMethod(using Diagnostic): DefDef = { val tcl: TypeRepr = TypeRepr.of[F] val methods = tcl.typeSymbol.declaredMethods.filter(_.isDefDef).map(_.tree).collect { @@ -51,16 +149,17 @@ object UnionDerivation { methods match { case Nil => - report.errorAndAbort( - s"""Cannot detect an abstract method in ${tcl.typeSymbol}. `scalacOptions += "-Yretain-trees"` may solve the issue""" + errorAndAbort( + "cannot detect an abstract method in the typeclass.", + Some("""`scalacOptions += "-Yretain-trees"` may solve the issue.""") ) case head :: Nil => head case other => - report.errorAndAbort( - s"More than one abstract method detected in ${tcl.typeSymbol}: ${other.map(_.name).mkString(", ")}. Automatic derivation is impossible" + errorAndAbort( + s"more than one abstract method is detected: ${other.map(_.name).mkString(", ")}." ) } } @@ -70,33 +169,50 @@ object UnionDerivation { * * The * {{{ - * if (value.isInstanceOf[Int]) summon[Show[Int]].show(value) - * else if (value.isInstanceOf[String]) summon[Show[String]].show(value) + * if (value.isInstanceOf[Int]) summon[Typeclass[Int]].magic(value) + * else if (value.isInstanceOf[String]) summon[Typeclass[String]].magic(value) * else sys.error("Impossible") // impossible state * }}} * - * @param t - * the input value of the method * @param knownTypes * the known member types of the union + * @param params + * the list of function parameter + * @param lambdaArgs + * the list of lambda args * @param method * the name of the typeclass method to apply * @tparam A * the input type - * @tparam R - * the output type of the method - * @return */ - private def body[A](t: Expr[A], knownTypes: List[TypeRepr], method: String): Term = { - val selector: Term = t.asTerm + private def body[A: Type]( + knownTypes: List[TypeRepr], + params: List[MethodParam], + lambdaArgs: List[Tree], + method: String + )(using Diagnostic): Term = { + + val selector: Term = params + .zip(lambdaArgs) + .collectFirst { case (param, arg) if param.isPoly => arg } + .getOrElse(errorAndAbort("cannot find poly param in the list of lambda arguments.")) + .asExprOf[A] + .asTerm val ifBranches: List[(Term, Term)] = knownTypes.map { tpe => - val identifier = TypeIdent(tpe.typeSymbol) - val condition = TypeApply(Select.unique(selector, "isInstanceOf"), identifier :: Nil) - val tcl = lookupImplicit(tpe) - val castedValue = Select.unique(selector, "asInstanceOf").appliedToType(tpe) + val identifier = TypeIdent(tpe.typeSymbol) + val condition = TypeApply(Select.unique(selector, "isInstanceOf"), identifier :: Nil) + val tcl = lookupImplicit(tpe) - val action: Term = Apply(Select.unique(tcl, method), castedValue :: Nil) + val args: List[Term] = params.zip(lambdaArgs).map { + case (param, arg) if param.isPoly => + Select.unique(selector, "asInstanceOf").appliedToType(tpe) + + case (_, arg) => + arg.asExpr.asTerm + } + + val action: Term = Select.unique(tcl, method).appliedToArgs(args) (condition, action) } @@ -122,12 +238,12 @@ object UnionDerivation { /** * Looks-up for an instance of `F[A]` for the provided type */ - private def lookupImplicit(t: TypeRepr): Term = { + private def lookupImplicit(t: TypeRepr)(using Diagnostic): Term = { val typeclassTpe = TypeRepr.of[F] val tclTpe = typeclassTpe.appliedTo(t) Implicits.search(tclTpe) match { case success: ImplicitSearchSuccess => success.tree - case failure: ImplicitSearchFailure => report.errorAndAbort(failure.explanation) + case failure: ImplicitSearchFailure => errorAndAbort(failure.explanation) } } @@ -141,5 +257,11 @@ object UnionDerivation { case Nil => ('{ throw RuntimeException("Unhandled condition encountered during derivation") }).asTerm } + + private def errorAndAbort(reason: String, hint: Option[String] = None)(using d: Diagnostic): Nothing = + report.errorAndAbort( + s"""UnionDerivation cannot derive an instance of ${d.typeclass.typeSymbol} for the type `${d.targetType.show}`. + |Reason: $reason""".stripMargin + hint.map(fix => s"\nHint: $fix").getOrElse("") + "\n\n" + ) } } diff --git a/modules/core/src/test/scala/io/github/irevive/union/derivation/ShowDerivationSuite.scala b/modules/core/src/test/scala/io/github/irevive/union/derivation/ShowDerivationSuite.scala index bc30773..e4de397 100644 --- a/modules/core/src/test/scala/io/github/irevive/union/derivation/ShowDerivationSuite.scala +++ b/modules/core/src/test/scala/io/github/irevive/union/derivation/ShowDerivationSuite.scala @@ -24,7 +24,10 @@ class ShowDerivationSuite extends munit.FunSuite { test("fail derivation for a non-union type") { val expected = """ - |error: Cannot derive a typeclass for the scala.Int. Only Union type is supported + |error: + |UnionDerivation cannot derive an instance of trait Show for the type `scala.Int`. + |Reason: only Union type is supported. + | | assertNoDiff(compileErrors("Show.deriveUnion[Int]"), expected) | ^ | @@ -36,7 +39,10 @@ class ShowDerivationSuite extends munit.FunSuite { test("fail derivation if an instance of a typeclass is missing for a member type") { val expected = """ - |error: no implicit values were found that match type io.github.irevive.union.derivation.ShowDerivationSuite.Show[Double] + |error: + |UnionDerivation cannot derive an instance of trait Show for the type `scala.Int | scala.Predef.String | scala.Double`. + |Reason: no implicit values were found that match type io.github.irevive.union.derivation.ShowDerivationSuite.Show[Double] + | | assertNoDiff(compileErrors("Show.deriveUnion[Int | String | Double]"), expected) | ^ |""".stripMargin diff --git a/modules/core/src/test/scala/io/github/irevive/union/derivation/UnionDerivationSuite.scala b/modules/core/src/test/scala/io/github/irevive/union/derivation/UnionDerivationSuite.scala index c4ae0d8..b599eac 100644 --- a/modules/core/src/test/scala/io/github/irevive/union/derivation/UnionDerivationSuite.scala +++ b/modules/core/src/test/scala/io/github/irevive/union/derivation/UnionDerivationSuite.scala @@ -9,6 +9,22 @@ class UnionDerivationSuite extends munit.FunSuite { def show(a: A): String } + trait MultipleParamsSameType[A] { + def magic(a: A, b: Int, c: A): Int + } + + trait NoParams[A] { + def magic: A + } + + trait UnusedTypeParam[A] { + def magic(a: Int): String + } + + trait Curried[A] { + def magic(a: A)(b: Int): String + } + trait Typeclass[A] { def magic(a: A): Int } @@ -29,7 +45,10 @@ class UnionDerivationSuite extends munit.FunSuite { test("fail derivation for a non-union type") { val expected = """ - |error: Cannot derive a typeclass for the scala.Int. Only Union type is supported + |error: + |UnionDerivation cannot derive an instance of trait Typeclass for the type `scala.Int`. + |Reason: only Union type is supported. + | | assertNoDiff(compileErrors("UnionDerivation.derive[Typeclass, Int]"), expected) | ^ | @@ -41,7 +60,10 @@ class UnionDerivationSuite extends munit.FunSuite { test("fail derivation if an instance of a typeclass is missing for a member type") { val expected = """ - |error: no implicit values were found that match type UnionDerivationSuite.this.Typeclass[Double] + |error: + |UnionDerivation cannot derive an instance of trait Typeclass for the type `scala.Int | scala.Predef.String | scala.Double`. + |Reason: no implicit values were found that match type UnionDerivationSuite.this.Typeclass[Double] + | | assertNoDiff(compileErrors("UnionDerivation.derive[Typeclass, Int | String | Double]"), expected) | ^ |""".stripMargin @@ -52,7 +74,10 @@ class UnionDerivationSuite extends munit.FunSuite { test("fail derivation if a typeclass has more than one abstract methods") { val expected = """ - |error: More than one abstract method detected in trait MultipleMethods: magic, show. Automatic derivation is impossible + |error: + |UnionDerivation cannot derive an instance of trait MultipleMethods for the type `scala.Int | scala.Predef.String`. + |Reason: more than one abstract method is detected: magic, show. + | | assertNoDiff(compileErrors("UnionDerivation.derive[MultipleMethods, Int | String]"), expected) | ^ |""".stripMargin @@ -63,7 +88,11 @@ class UnionDerivationSuite extends munit.FunSuite { test("fail derivation if a typeclass does not have abstract methods") { val expected = """ - |error: Cannot detect an abstract method in trait SimpleTrait. `scalacOptions += "-Yretain-trees"` may solve the issue + |error: + |UnionDerivation cannot derive an instance of trait SimpleTrait for the type `scala.Predef.String | scala.Int`. + |Reason: cannot detect an abstract method in the typeclass. + |Hint: `scalacOptions += "-Yretain-trees"` may solve the issue. + | | assertNoDiff(compileErrors("UnionDerivation.derive[SimpleTrait, String | Int]"), expected) | ^ |""".stripMargin @@ -71,4 +100,118 @@ class UnionDerivationSuite extends munit.FunSuite { assertNoDiff(compileErrors("UnionDerivation.derive[SimpleTrait, String | Int]"), expected) } + test("fail derivation if a typeclass function has multiple polymorphic params of the same type") { + val expected = + """ + |error: + |UnionDerivation cannot derive an instance of trait MultipleParamsSameType for the type `scala.Predef.String | scala.Int`. + |Reason: the abstract method has multiple polymorphic params of the same parametrized type: a, c. + |Hint: check the example below where the instance cannot be derived + | + |trait Typeclass[A] { + | def magic(a1: A, b: Int, a2: A): String + | // ^ ^ + | // Polymorphic type A appears in two positions + |} + | + | assertNoDiff(compileErrors("UnionDerivation.derive[MultipleParamsSameType, String | Int]"), expected) + | ^ + |""".stripMargin + + assertNoDiff(compileErrors("UnionDerivation.derive[MultipleParamsSameType, String | Int]"), expected) + } + + test("fail derivation if a typeclass function doesn't have params") { + val expected = + """ + |error: + |UnionDerivation cannot derive an instance of trait NoParams for the type `scala.Predef.String | scala.Int`. + |Reason: the abstract method without the polymorphic param isn't supported. + |Hint: check the example below where the instance cannot be derived + | + |trait Typeclass[A] { + | def magic: String + | // ^ + | // Polymorphic param of type A is missing + |} + | + | assertNoDiff(compileErrors("UnionDerivation.derive[NoParams, String | Int]"), expected) + | ^ + |""".stripMargin + + assertNoDiff(compileErrors("UnionDerivation.derive[NoParams, String | Int]"), expected) + } + + test("fail derivation if a typeclass function doesn't use type parameter") { + val expected = + """ + |error: + |UnionDerivation cannot derive an instance of trait UnusedTypeParam for the type `scala.Predef.String | scala.Int`. + |Reason: the abstract method without the polymorphic param isn't supported. + |Hint: check the example below where the instance cannot be derived + | + |trait Typeclass[A] { + | def magic(a: Int): String + | // ^ + | // Polymorphic param of type A is missing + |} + | + | assertNoDiff(compileErrors("UnionDerivation.derive[UnusedTypeParam, String | Int]"), expected) + | ^ + |""".stripMargin + + assertNoDiff(compileErrors("UnionDerivation.derive[UnusedTypeParam, String | Int]"), expected) + } + + test("fail derivation if a typeclass function is curried") { + val expected = + """ + |error: + |UnionDerivation cannot derive an instance of trait Curried for the type `scala.Predef.String | scala.Int`. + |Reason: the curried abstract method isn't supported. + |Hint: check the example below where the instance cannot be derived + | + |trait Typeclass[A] { + | def magic(a: A)(b: Int): String + | // ^ + | // Curried functions aren't supported + |} + | + | assertNoDiff(compileErrors("UnionDerivation.derive[Curried, String | Int]"), expected) + | ^ + |""".stripMargin + + assertNoDiff(compileErrors("UnionDerivation.derive[Curried, String | Int]"), expected) + } + + test("multi param - derive a typeclass for a union type") { + trait MultipleParams[A] { + def multipleParams(a: A, b: String, c: Int): String + } + + given MultipleParams[Int] = (a, b, c) => a.toString + "->" + b + "->" + c + given MultipleParams[String] = (a, b, c) => a + "=>" + b + "=>" + c + + type UnionType = Int | String + val unionTypeGiven: MultipleParams[UnionType] = UnionDerivation.derive[MultipleParams, UnionType] + + assertEquals(unionTypeGiven.multipleParams(1, "!", 42), "1->!->42") + assertEquals(unionTypeGiven.multipleParams("some-string-value", "?", 42), "some-string-value=>?=>42") + } + + test("multi param - polymorphic param in the end") { + trait MultipleParams[A] { + def multipleParams(b: String, c: Int, a: A): String + } + + given MultipleParams[Int] = (a, b, c) => a + "->" + b + "->" + c.toString + given MultipleParams[String] = (a, b, c) => a + "=>" + b + "=>" + c + + type UnionType = Int | String + val unionTypeGiven: MultipleParams[UnionType] = UnionDerivation.derive[MultipleParams, UnionType] + + assertEquals(unionTypeGiven.multipleParams("!", 42, 1), "!->42->1") + assertEquals(unionTypeGiven.multipleParams("?", 42, "@"), "?=>42=>@") + } + }