From be58b8334ba49974764e6f567234aae14fa3c9e3 Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Tue, 9 Jul 2024 17:30:55 +0200 Subject: [PATCH] Implement a less brain-rot version of dropUnusedMap The new implementation doesn't require an `AvocADO` instance to work --- avocADO/src/main/scala/ado.scala | 29 ++++ avocADO/src/main/scala/macros.scala | 128 +++++++++++++-- avocADO/src/test/scala/DropMapTest.scala | 190 +++++++++++++++++++++++ 3 files changed, 333 insertions(+), 14 deletions(-) create mode 100644 avocADO/src/test/scala/DropMapTest.scala diff --git a/avocADO/src/main/scala/ado.scala b/avocADO/src/main/scala/ado.scala index 3c5cb5e..2a8a948 100644 --- a/avocADO/src/main/scala/ado.scala +++ b/avocADO/src/main/scala/ado.scala @@ -73,3 +73,32 @@ trait AvocADO[F[_]] { def zip[A, B](fa: F[A], fb: F[B]): F[(A, B)] def flatMap[A, B](fa: F[A], f: A => F[B]): F[B] } + +/** + * Drops unused trailing map calls in a for-comprehension. Helps with making for-comprehensions stack-safe. + * Example usage: + * ```scala + * dropUnusedMap { + * for { + * a <- doSth() + * _ <- doSideEffectAndReturnUnit(a) + * } yield () + * } + * ``` + * + * The above code will be transformed to code essentially equivalent to: + * ```scala + * doSth().flatMap(a => doSideEffectAndReturnUnit(a)) + * ``` + * + * instead of the normal for-comprehension desugaring: + * ```scala + * doSth().map(a => doSideEffectAndReturnUnit(a)).map(_ => ()) + * ``` + * + * Handled cases: + * - returning `()` from the for-comprehension, when the last generator expression also binds to `Unit` + * - returning the same variable reference as the last generator expression + */ +inline def dropUnusedMap[F[_], A](inline comp: F[A]): F[A] = + ${ macros.dropUnusedMapImpl[F, A]('comp) } diff --git a/avocADO/src/main/scala/macros.scala b/avocADO/src/main/scala/macros.scala index bfddda3..08be51e 100644 --- a/avocADO/src/main/scala/macros.scala +++ b/avocADO/src/main/scala/macros.scala @@ -9,6 +9,9 @@ private[avocado] object macros { def adoImpl[F[_]: Type, A: Type](compExpr: Expr[F[A]], instanceExpr: Expr[AvocADO[F]])(using Quotes): Expr[F[A]] = new ADOImpl(using quotes).adoImpl(compExpr, instanceExpr) + def dropUnusedMapImpl[F[_]: Type, A: Type](compExpr: Expr[F[A]])(using Quotes): Expr[F[A]] = + new ADOImpl(using quotes).dropUnusedMapImpl1(compExpr) + class ADOImpl(using Quotes) { import quotes.reflect.* @@ -34,7 +37,56 @@ private[avocado] object macros { private def ctx(using context: Context): Context = context + def dropUnusedMapImpl[F[_]: Type, A: Type](compExpr: Expr[F[A]], instanceExpr: Expr[AvocADO[F]])(using Quotes): Expr[F[A]] = { + implCommon(compExpr, instanceExpr)(doPar = false, doDropMap = true) + } + + def dropUnusedMapImpl1[F[_]: Type, A: Type](compExpr: Expr[F[A]]): Expr[F[A]] = { + given Context = Context(compExpr.asTerm, TypeRepr.of[F]) // This instance is wrong on purpose, with the assumption that it won't be used + def doDrop(expr: Term, arg: Term): Option[Term] = { + arg match { + case Lambda(List(param), body) + if isConstUnitBody(param, body) && expr.tpe.widen <:< ctx.fTpe.appliedTo(TypeRepr.of[Unit]) => + Some(expr) + case Lambda(List(param), body) + if isIdentityBody(param, body) => + Some(expr) + case _ => + None + } + } + def isConstUnitBody(param: ValDef, tree: Tree): Boolean = tree match { + case Block(Nil, body) => isConstUnitBody(param, body) + case Match(scrutinee, List(CaseDef(_, _, body))) if scrutinee.symbol == param.symbol => isConstUnitBody(param, body) + case Literal(UnitConstant()) => true + case _ => false + } + def isIdentityBody(param: ValDef, tree: Tree): Boolean = tree match { + case Block(Nil, body) => isIdentityBody(param, body) + case Ident(name) if name == param.name => true + case _ => false + } + object dropUnusedMapMap extends TreeMap { + override def transformTerm(tree: Term)(owner: Symbol): Term = tree match { + case NormalAllowed(expr, methodName, typeArgs, arg) if methodName == "map" => + doDrop(expr, arg).fold(super.transformTerm(tree)(owner))(identity) + case WithImplicitsAllowed(expr, args, methodName, typeArgs, arg) if methodName == "map" => + doDrop(expr, arg).fold(super.transformTerm(tree)(owner))(identity) + case FromTypeclassAllowed(expr, evidences, methodName, typeArgs, arg) if methodName == "map" => + doDrop(expr, arg).fold(super.transformTerm(tree)(owner))(identity) + case _ => + super.transformTerm(tree)(owner) + } + } + + dropUnusedMapMap.transformTerm(compExpr.asTerm)(Symbol.spliceOwner).asExprOf[F[A]] + } + def adoImpl[F[_]: Type, A: Type](compExpr: Expr[F[A]], instanceExpr: Expr[AvocADO[F]])(using Quotes): Expr[F[A]] = { + implCommon(compExpr, instanceExpr)(doPar = true, doDropMap = false) + } + + def implCommon[F[_]: Type, A: Type](compExpr: Expr[F[A]], instanceExpr: Expr[AvocADO[F]])(doPar: Boolean, doDropMap: Boolean)(using Quotes): Expr[F[A]] = { val exprTree = compExpr.asTerm match case Inlined(_, _, tree) => tree match case Block(Nil, expr) => expr @@ -51,30 +103,70 @@ private[avocado] object macros { case binding => (binding, getBindingDependencies(binding.tree, bindingVals)) } - connectBindings(bindingsWithDependencies, res).asExprOf[F[A]] + val splitFn = if doPar then splitToZip else splitByOne + val dropMapFn = if doDropMap then maybeDropMap else (_: Term, _: Tree, _: Term) => None + + connectBindings(bindingsWithDependencies, res)(splitFn, dropMapFn).asExprOf[F[A]] } - private def connectBindings(bindings: List[(Binding, Set[Symbol])], res: Term)(using Context): Tree = { - def go(bindings: List[(Binding, Set[Symbol])], zipped: List[(Tree, TypeRepr)], acc: Term, lastBinding: Binding): Term = bindings match { + private def connectBindings( + bindings: List[(Binding, Set[Symbol])], + res: Term + )( + splitFn: List[(Binding, Set[Symbol])] => (List[(Binding, Set[Symbol])], List[(Binding, Set[Symbol])], Binding), + dropMapFn: (Term, Tree, Term) => Option[Term] + )(using Context): Tree = { + def go(bindings: List[(Binding, Set[Symbol])], zipped: List[(Tree, TypeRepr)], acc: Term, lastBinding: Binding, res: Term): Term = bindings match { case Nil => val arg = funFromZipped(zipped, res, Symbol.spliceOwner) ctx.instance .select(ctx.instance.tpe.typeSymbol.methodMember(lastBinding.methodName).head) .appliedToTypes(List(typeReprForBindings(zipped), adaptTpeForMethod(res, lastBinding.methodName))) .appliedToArgs(List(acc, arg)) + case head :: Nil => + dropMapFn(head._1.tree, head._1.pattern, res) match { + case Some(prevExpr) => + go(Nil, zipped, acc, lastBinding, head._1.tree) + case None => + makeNonFinalCall(bindings, zipped, acc, lastBinding, res) + } case _ => - val (toZip, rest, newLastBinding) = splitToZip(bindings) - val body = go(rest, toZip.map(b => b._1.pattern -> b._1.tpe), zipExprs(toZip.map(_._1), Symbol.spliceOwner), newLastBinding) - val arg = funFromZipped(zipped, body, Symbol.spliceOwner) - val tpes = lastBinding.typeArgs.map(_.widen) - ctx.instance - .select(ctx.instance.tpe.typeSymbol.methodMember(lastBinding.methodName).head) - .appliedToTypes(List(typeReprForBindings(zipped), adaptTpeForMethod(body, lastBinding.methodName))) - .appliedToArgs(List(acc, arg)) + makeNonFinalCall(bindings, zipped, acc, lastBinding, res) + } + + def makeNonFinalCall(bindings: List[(Binding, Set[Symbol])], zipped: List[(Tree, TypeRepr)], acc: Term, lastBinding: Binding, res: Term): Term = { + val (toZip, rest, newLastBinding) = splitFn(bindings) + val body = go(rest, toZip.map(b => b._1.pattern -> b._1.tpe), zipExprs(toZip.map(_._1), Symbol.spliceOwner), newLastBinding, res) + val arg = funFromZipped(zipped, body, Symbol.spliceOwner) + val tpes = lastBinding.typeArgs.map(_.widen) + ctx.instance + .select(ctx.instance.tpe.typeSymbol.methodMember(lastBinding.methodName).head) + .appliedToTypes(List(typeReprForBindings(zipped), adaptTpeForMethod(body, lastBinding.methodName))) + .appliedToArgs(List(acc, arg)) + } + + val (toZip, rest, lastMethod) = splitFn(bindings) + go(rest, toZip.map(b => b._1.pattern -> b._1.tpe), zipExprs(toZip.map(_._1), Symbol.spliceOwner), lastMethod, res) + } + + private def maybeDropMap(prevTerm: Term, prevPattern: Tree, res: Term)(using Context): Option[Term] = { + res match { + case Literal(UnitConstant()) if extractTypeFromApplicative(prevTerm.tpe).widen =:= TypeRepr.of[Unit] => + Some(prevTerm) + case _ if eqPrevPatternRef(prevPattern, res) => + Some(prevTerm) + case _ => + None } + } - val (toZip, rest, lastMethod) = splitToZip(bindings) - go(rest, toZip.map(b => b._1.pattern -> b._1.tpe), zipExprs(toZip.map(_._1), Symbol.spliceOwner), lastMethod) + private def eqPrevPatternRef(prevPattern: Tree, res: Term): Boolean = { + (prevPattern, res) match { + case (valdef: ValDef, ident: Ident) => + valdef.symbol == ident.symbol + case _ => + false + } } private def adaptTpeForMethod(arg: Term, methodName: String): TypeRepr = @@ -127,8 +219,16 @@ private[avocado] object macros { (List(head), tail, head._1) case _ => throwGenericError() + } + } + + private def splitByOne(bindings: List[(Binding, Set[Symbol])]): (List[(Binding, Set[Symbol])], List[(Binding, Set[Symbol])], Binding) = { + bindings match { + case head :: tail => + (List(head), tail, head._1) + case _ => + throwGenericError() } - } private val tuple2: Term = Ref(Symbol.requiredModule("scala.Tuple2")) diff --git a/avocADO/src/test/scala/DropMapTest.scala b/avocADO/src/test/scala/DropMapTest.scala new file mode 100644 index 0000000..a0a51d4 --- /dev/null +++ b/avocADO/src/test/scala/DropMapTest.scala @@ -0,0 +1,190 @@ +package avocado.tests + +import avocado.* + +class DropMapTest extends munit.FunSuite { + class myOptionPackage(doOnMap: => Unit) { + sealed trait MyOption[+A] { + def map[B](f: A => B): MyOption[B] = this match { + case MySome(a) => + doOnMap + MySome(f(a)) + case MyNone => MyNone + } + def flatMap[B](f: A => MyOption[B]): MyOption[B] = this match { + case MySome(a) => f(a) + case MyNone => MyNone + } + def zip[B](that: MyOption[B]): MyOption[(A, B)] = (this, that) match { + case (MySome(a), MySome(b)) => MySome((a, b)) + case _ => MyNone + } + def value: Option[A] = this match { + case MySome(a) => Some(a) + case MyNone => None + } + } + case class MySome[A](a: A) extends MyOption[A] + case object MyNone extends MyOption[Nothing] + } + + test("don't drop map in a simple case") { + val (resOrg, mapUnusedResOrg) = { + var mapUsed = 0 + val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 }) + import myOption.* + (for { + a <- MySome(1) + b <- MySome(2) + } yield a + b + ).value -> mapUsed + } + val (res, mapUnusedRes) = { + var mapUsed = 0 + val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 }) + import myOption.* + dropUnusedMap( + for { + a <- MySome(1) + b <- MySome(2) + } yield a + b + ).value -> mapUsed + } + assertEquals(res, resOrg) + assert(mapUnusedRes == mapUnusedResOrg) + } + + test("drop map with same var ref as result") { + val (resOrg, mapUnusedResOrg) = { + var mapUsed = 0 + val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 }) + import myOption.* + (for { + a <- MySome(1) + b <- MySome(a) + } yield b + ).value -> mapUsed + } + val (res, mapUnusedRes) = { + var mapUsed = 0 + val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 }) + import myOption.* + dropUnusedMap( + for { + a <- MySome(1) + b <- MySome(a) + } yield b + ).value -> mapUsed + } + assertEquals(res, resOrg) + assert(mapUnusedRes < mapUnusedResOrg) + } + + test("drop map with unit result and wildcard last pattern") { + val (resOrg, mapUnusedResOrg) = { + var mapUsed = 0 + val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 }) + import myOption.* + (for { + a <- MySome(1) + _ <- MySome(()) + } yield () + ).value -> mapUsed + } + val (res, mapUnusedRes) = { + var mapUsed = 0 + val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 }) + import myOption.* + dropUnusedMap( + for { + a <- MySome(1) + _ <- MySome(()) + } yield () + ).value -> mapUsed + } + assertEquals(res, resOrg) + assert(mapUnusedRes < mapUnusedResOrg) + } + + test("drop map with unit result and named last pattern") { + val (resOrg, mapUnusedResOrg) = { + var mapUsed = 0 + val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 }) + import myOption.* + (for { + a <- MySome(1) + b <- MySome(()) + } yield () + ).value -> mapUsed + } + val (res, mapUnusedRes) = { + var mapUsed = 0 + val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 }) + import myOption.* + dropUnusedMap( + for { + a <- MySome(1) + b <- MySome(()) + } yield () + ).value -> mapUsed + } + assertEquals(res, resOrg) + assert(mapUnusedRes < mapUnusedResOrg) + } + + test("drop map with unit result and wildcard last pattern with alias in the middle") { + val (resOrg, mapUnusedResOrg) = { + var mapUsed = 0 + val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 }) + import myOption.* + (for { + a <- MySome(1) + b = a + _ <- MySome(()) + } yield () + ).value -> mapUsed + } + val (res, mapUnusedRes) = { + var mapUsed = 0 + val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 }) + import myOption.* + dropUnusedMap( + for { + a <- MySome(1) + b = a + _ <- MySome(()) + } yield () + ).value -> mapUsed + } + assertEquals(res, resOrg) + assert(mapUnusedRes < mapUnusedResOrg) + } + + test("drop map with unit result and named last pattern with alias in the middle") { + val (resOrg, mapUnusedResOrg) = { + var mapUsed = 0 + val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 }) + import myOption.* + (for { + a <- MySome(1) + b = a + c <- MySome(()) + } yield () + ).value -> mapUsed + } + val (res, mapUnusedRes) = { + var mapUsed = 0 + val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 }) + import myOption.* + dropUnusedMap( + for { + a <- MySome(1) + b = a + c <- MySome(()) + } yield () + ).value -> mapUsed + } + assertEquals(res, resOrg) + assert(mapUnusedRes < mapUnusedResOrg) + } +}