Skip to content

Commit

Permalink
Implement a less brain-rot version of dropUnusedMap
Browse files Browse the repository at this point in the history
The new implementation doesn't require an `AvocADO` instance to work
  • Loading branch information
KacperFKorban committed Jul 9, 2024
1 parent 8e54d97 commit be58b83
Show file tree
Hide file tree
Showing 3 changed files with 333 additions and 14 deletions.
29 changes: 29 additions & 0 deletions avocADO/src/main/scala/ado.scala
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,32 @@ trait AvocADO[F[_]] {
def zip[A, B](fa: F[A], fb: F[B]): F[(A, B)]
def flatMap[A, B](fa: F[A], f: A => F[B]): F[B]
}

/**
* Drops unused trailing map calls in a for-comprehension. Helps with making for-comprehensions stack-safe.
* Example usage:
* ```scala
* dropUnusedMap {
* for {
* a <- doSth()
* _ <- doSideEffectAndReturnUnit(a)
* } yield ()
* }
* ```
*
* The above code will be transformed to code essentially equivalent to:
* ```scala
* doSth().flatMap(a => doSideEffectAndReturnUnit(a))
* ```
*
* instead of the normal for-comprehension desugaring:
* ```scala
* doSth().map(a => doSideEffectAndReturnUnit(a)).map(_ => ())
* ```
*
* Handled cases:
* - returning `()` from the for-comprehension, when the last generator expression also binds to `Unit`
* - returning the same variable reference as the last generator expression
*/
inline def dropUnusedMap[F[_], A](inline comp: F[A]): F[A] =
${ macros.dropUnusedMapImpl[F, A]('comp) }
128 changes: 114 additions & 14 deletions avocADO/src/main/scala/macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ private[avocado] object macros {
def adoImpl[F[_]: Type, A: Type](compExpr: Expr[F[A]], instanceExpr: Expr[AvocADO[F]])(using Quotes): Expr[F[A]] =
new ADOImpl(using quotes).adoImpl(compExpr, instanceExpr)

def dropUnusedMapImpl[F[_]: Type, A: Type](compExpr: Expr[F[A]])(using Quotes): Expr[F[A]] =
new ADOImpl(using quotes).dropUnusedMapImpl1(compExpr)

class ADOImpl(using Quotes) {
import quotes.reflect.*

Expand All @@ -34,7 +37,56 @@ private[avocado] object macros {

private def ctx(using context: Context): Context = context

def dropUnusedMapImpl[F[_]: Type, A: Type](compExpr: Expr[F[A]], instanceExpr: Expr[AvocADO[F]])(using Quotes): Expr[F[A]] = {
implCommon(compExpr, instanceExpr)(doPar = false, doDropMap = true)
}

def dropUnusedMapImpl1[F[_]: Type, A: Type](compExpr: Expr[F[A]]): Expr[F[A]] = {
given Context = Context(compExpr.asTerm, TypeRepr.of[F]) // This instance is wrong on purpose, with the assumption that it won't be used
def doDrop(expr: Term, arg: Term): Option[Term] = {
arg match {
case Lambda(List(param), body)
if isConstUnitBody(param, body) && expr.tpe.widen <:< ctx.fTpe.appliedTo(TypeRepr.of[Unit]) =>
Some(expr)
case Lambda(List(param), body)
if isIdentityBody(param, body) =>
Some(expr)
case _ =>
None
}
}
def isConstUnitBody(param: ValDef, tree: Tree): Boolean = tree match {
case Block(Nil, body) => isConstUnitBody(param, body)
case Match(scrutinee, List(CaseDef(_, _, body))) if scrutinee.symbol == param.symbol => isConstUnitBody(param, body)
case Literal(UnitConstant()) => true
case _ => false
}
def isIdentityBody(param: ValDef, tree: Tree): Boolean = tree match {
case Block(Nil, body) => isIdentityBody(param, body)
case Ident(name) if name == param.name => true
case _ => false
}
object dropUnusedMapMap extends TreeMap {
override def transformTerm(tree: Term)(owner: Symbol): Term = tree match {
case NormalAllowed(expr, methodName, typeArgs, arg) if methodName == "map" =>
doDrop(expr, arg).fold(super.transformTerm(tree)(owner))(identity)
case WithImplicitsAllowed(expr, args, methodName, typeArgs, arg) if methodName == "map" =>
doDrop(expr, arg).fold(super.transformTerm(tree)(owner))(identity)
case FromTypeclassAllowed(expr, evidences, methodName, typeArgs, arg) if methodName == "map" =>
doDrop(expr, arg).fold(super.transformTerm(tree)(owner))(identity)
case _ =>
super.transformTerm(tree)(owner)
}
}

dropUnusedMapMap.transformTerm(compExpr.asTerm)(Symbol.spliceOwner).asExprOf[F[A]]
}

def adoImpl[F[_]: Type, A: Type](compExpr: Expr[F[A]], instanceExpr: Expr[AvocADO[F]])(using Quotes): Expr[F[A]] = {
implCommon(compExpr, instanceExpr)(doPar = true, doDropMap = false)
}

def implCommon[F[_]: Type, A: Type](compExpr: Expr[F[A]], instanceExpr: Expr[AvocADO[F]])(doPar: Boolean, doDropMap: Boolean)(using Quotes): Expr[F[A]] = {
val exprTree = compExpr.asTerm match
case Inlined(_, _, tree) => tree match
case Block(Nil, expr) => expr
Expand All @@ -51,30 +103,70 @@ private[avocado] object macros {
case binding => (binding, getBindingDependencies(binding.tree, bindingVals))
}

connectBindings(bindingsWithDependencies, res).asExprOf[F[A]]
val splitFn = if doPar then splitToZip else splitByOne
val dropMapFn = if doDropMap then maybeDropMap else (_: Term, _: Tree, _: Term) => None

connectBindings(bindingsWithDependencies, res)(splitFn, dropMapFn).asExprOf[F[A]]
}

private def connectBindings(bindings: List[(Binding, Set[Symbol])], res: Term)(using Context): Tree = {
def go(bindings: List[(Binding, Set[Symbol])], zipped: List[(Tree, TypeRepr)], acc: Term, lastBinding: Binding): Term = bindings match {
private def connectBindings(
bindings: List[(Binding, Set[Symbol])],
res: Term
)(
splitFn: List[(Binding, Set[Symbol])] => (List[(Binding, Set[Symbol])], List[(Binding, Set[Symbol])], Binding),
dropMapFn: (Term, Tree, Term) => Option[Term]
)(using Context): Tree = {
def go(bindings: List[(Binding, Set[Symbol])], zipped: List[(Tree, TypeRepr)], acc: Term, lastBinding: Binding, res: Term): Term = bindings match {
case Nil =>
val arg = funFromZipped(zipped, res, Symbol.spliceOwner)
ctx.instance
.select(ctx.instance.tpe.typeSymbol.methodMember(lastBinding.methodName).head)
.appliedToTypes(List(typeReprForBindings(zipped), adaptTpeForMethod(res, lastBinding.methodName)))
.appliedToArgs(List(acc, arg))
case head :: Nil =>
dropMapFn(head._1.tree, head._1.pattern, res) match {
case Some(prevExpr) =>
go(Nil, zipped, acc, lastBinding, head._1.tree)
case None =>
makeNonFinalCall(bindings, zipped, acc, lastBinding, res)
}
case _ =>
val (toZip, rest, newLastBinding) = splitToZip(bindings)
val body = go(rest, toZip.map(b => b._1.pattern -> b._1.tpe), zipExprs(toZip.map(_._1), Symbol.spliceOwner), newLastBinding)
val arg = funFromZipped(zipped, body, Symbol.spliceOwner)
val tpes = lastBinding.typeArgs.map(_.widen)
ctx.instance
.select(ctx.instance.tpe.typeSymbol.methodMember(lastBinding.methodName).head)
.appliedToTypes(List(typeReprForBindings(zipped), adaptTpeForMethod(body, lastBinding.methodName)))
.appliedToArgs(List(acc, arg))
makeNonFinalCall(bindings, zipped, acc, lastBinding, res)
}

def makeNonFinalCall(bindings: List[(Binding, Set[Symbol])], zipped: List[(Tree, TypeRepr)], acc: Term, lastBinding: Binding, res: Term): Term = {
val (toZip, rest, newLastBinding) = splitFn(bindings)
val body = go(rest, toZip.map(b => b._1.pattern -> b._1.tpe), zipExprs(toZip.map(_._1), Symbol.spliceOwner), newLastBinding, res)
val arg = funFromZipped(zipped, body, Symbol.spliceOwner)
val tpes = lastBinding.typeArgs.map(_.widen)
ctx.instance
.select(ctx.instance.tpe.typeSymbol.methodMember(lastBinding.methodName).head)
.appliedToTypes(List(typeReprForBindings(zipped), adaptTpeForMethod(body, lastBinding.methodName)))
.appliedToArgs(List(acc, arg))
}

val (toZip, rest, lastMethod) = splitFn(bindings)
go(rest, toZip.map(b => b._1.pattern -> b._1.tpe), zipExprs(toZip.map(_._1), Symbol.spliceOwner), lastMethod, res)
}

private def maybeDropMap(prevTerm: Term, prevPattern: Tree, res: Term)(using Context): Option[Term] = {
res match {
case Literal(UnitConstant()) if extractTypeFromApplicative(prevTerm.tpe).widen =:= TypeRepr.of[Unit] =>
Some(prevTerm)
case _ if eqPrevPatternRef(prevPattern, res) =>
Some(prevTerm)
case _ =>
None
}
}

val (toZip, rest, lastMethod) = splitToZip(bindings)
go(rest, toZip.map(b => b._1.pattern -> b._1.tpe), zipExprs(toZip.map(_._1), Symbol.spliceOwner), lastMethod)
private def eqPrevPatternRef(prevPattern: Tree, res: Term): Boolean = {
(prevPattern, res) match {
case (valdef: ValDef, ident: Ident) =>
valdef.symbol == ident.symbol
case _ =>
false
}
}

private def adaptTpeForMethod(arg: Term, methodName: String): TypeRepr =
Expand Down Expand Up @@ -127,8 +219,16 @@ private[avocado] object macros {
(List(head), tail, head._1)
case _ =>
throwGenericError()
}
}

private def splitByOne(bindings: List[(Binding, Set[Symbol])]): (List[(Binding, Set[Symbol])], List[(Binding, Set[Symbol])], Binding) = {
bindings match {
case head :: tail =>
(List(head), tail, head._1)
case _ =>
throwGenericError()
}

}

private val tuple2: Term = Ref(Symbol.requiredModule("scala.Tuple2"))
Expand Down
190 changes: 190 additions & 0 deletions avocADO/src/test/scala/DropMapTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
package avocado.tests

import avocado.*

class DropMapTest extends munit.FunSuite {
class myOptionPackage(doOnMap: => Unit) {
sealed trait MyOption[+A] {
def map[B](f: A => B): MyOption[B] = this match {
case MySome(a) =>
doOnMap
MySome(f(a))
case MyNone => MyNone
}
def flatMap[B](f: A => MyOption[B]): MyOption[B] = this match {
case MySome(a) => f(a)
case MyNone => MyNone
}
def zip[B](that: MyOption[B]): MyOption[(A, B)] = (this, that) match {
case (MySome(a), MySome(b)) => MySome((a, b))
case _ => MyNone
}
def value: Option[A] = this match {
case MySome(a) => Some(a)
case MyNone => None
}
}
case class MySome[A](a: A) extends MyOption[A]
case object MyNone extends MyOption[Nothing]
}

test("don't drop map in a simple case") {
val (resOrg, mapUnusedResOrg) = {
var mapUsed = 0
val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 })
import myOption.*
(for {
a <- MySome(1)
b <- MySome(2)
} yield a + b
).value -> mapUsed
}
val (res, mapUnusedRes) = {
var mapUsed = 0
val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 })
import myOption.*
dropUnusedMap(
for {
a <- MySome(1)
b <- MySome(2)
} yield a + b
).value -> mapUsed
}
assertEquals(res, resOrg)
assert(mapUnusedRes == mapUnusedResOrg)
}

test("drop map with same var ref as result") {
val (resOrg, mapUnusedResOrg) = {
var mapUsed = 0
val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 })
import myOption.*
(for {
a <- MySome(1)
b <- MySome(a)
} yield b
).value -> mapUsed
}
val (res, mapUnusedRes) = {
var mapUsed = 0
val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 })
import myOption.*
dropUnusedMap(
for {
a <- MySome(1)
b <- MySome(a)
} yield b
).value -> mapUsed
}
assertEquals(res, resOrg)
assert(mapUnusedRes < mapUnusedResOrg)
}

test("drop map with unit result and wildcard last pattern") {
val (resOrg, mapUnusedResOrg) = {
var mapUsed = 0
val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 })
import myOption.*
(for {
a <- MySome(1)
_ <- MySome(())
} yield ()
).value -> mapUsed
}
val (res, mapUnusedRes) = {
var mapUsed = 0
val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 })
import myOption.*
dropUnusedMap(
for {
a <- MySome(1)
_ <- MySome(())
} yield ()
).value -> mapUsed
}
assertEquals(res, resOrg)
assert(mapUnusedRes < mapUnusedResOrg)
}

test("drop map with unit result and named last pattern") {
val (resOrg, mapUnusedResOrg) = {
var mapUsed = 0
val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 })
import myOption.*
(for {
a <- MySome(1)
b <- MySome(())
} yield ()
).value -> mapUsed
}
val (res, mapUnusedRes) = {
var mapUsed = 0
val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 })
import myOption.*
dropUnusedMap(
for {
a <- MySome(1)
b <- MySome(())
} yield ()
).value -> mapUsed
}
assertEquals(res, resOrg)
assert(mapUnusedRes < mapUnusedResOrg)
}

test("drop map with unit result and wildcard last pattern with alias in the middle") {
val (resOrg, mapUnusedResOrg) = {
var mapUsed = 0
val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 })
import myOption.*
(for {
a <- MySome(1)
b = a
_ <- MySome(())
} yield ()
).value -> mapUsed
}
val (res, mapUnusedRes) = {
var mapUsed = 0
val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 })
import myOption.*
dropUnusedMap(
for {
a <- MySome(1)
b = a
_ <- MySome(())
} yield ()
).value -> mapUsed
}
assertEquals(res, resOrg)
assert(mapUnusedRes < mapUnusedResOrg)
}

test("drop map with unit result and named last pattern with alias in the middle") {
val (resOrg, mapUnusedResOrg) = {
var mapUsed = 0
val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 })
import myOption.*
(for {
a <- MySome(1)
b = a
c <- MySome(())
} yield ()
).value -> mapUsed
}
val (res, mapUnusedRes) = {
var mapUsed = 0
val myOption = new myOptionPackage({ mapUsed = mapUsed + 1 })
import myOption.*
dropUnusedMap(
for {
a <- MySome(1)
b = a
c <- MySome(())
} yield ()
).value -> mapUsed
}
assertEquals(res, resOrg)
assert(mapUnusedRes < mapUnusedResOrg)
}
}

0 comments on commit be58b83

Please sign in to comment.