Skip to content

Commit

Permalink
Allow to beta reduce curried function applications in quotes reflect
Browse files Browse the repository at this point in the history
Previously, the curried functions with multiple applications
were not able to be beta-reduced in any way, which was unexpected.
Now we allow reducing any number of top-level function applications
for a curried function. This was also made clearer in the documentation
for the affected (Expr.betaReduce and Term.betaReduce) methods.
  • Loading branch information
jchyb committed Jul 3, 2023
1 parent cdc9fe8 commit 9cd22f3
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 10 deletions.
12 changes: 12 additions & 0 deletions compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,18 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
for e <- betaReduce(expr) yield tpd.cpy.Block(tree)(Nil, e)
case tpd.Inlined(_, Nil, expr) =>
betaReduce(expr)
case tpd.Apply(tpd.Select(expr: Apply, nme), args) =>
betaReduce(expr).map { expr1 =>
dotc.transform.BetaReduce(
tpd.Apply(tpd.Select(expr1, nme), args)
).withSpan(tree.span)
}
case tpd.Apply(tpd.TypeApply(tpd.Select(expr: Apply, nme), tpts), args) =>
betaReduce(expr).map { expr1 =>
dotc.transform.BetaReduce(
tpd.Apply(tpd.TypeApply(tpd.Select(expr1, nme), tpts), args)
).withSpan(tree.span)
}
case _ =>
val tree1 = dotc.transform.BetaReduce(tree)
if tree1 eq tree then None
Expand Down
32 changes: 27 additions & 5 deletions library/src/scala/quoted/Expr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,34 @@ abstract class Expr[+T] private[scala] ()
object Expr {

/** `e.betaReduce` returns an expression that is functionally equivalent to `e`,
* however if `e` is of the form `((y1, ..., yn) => e2)(e1, ..., en)`
* then it optimizes this the top most call by returning the result of beta-reducing the application.
* Otherwise returns `expr`.
* however if `e` is of the form `((y1, ..., yn) => e2)(e1, ..., en)`
* then it optimizes the top most call by returning the result of beta-reducing the application.
* Similarly, all outermost curried function applications will be beta-reduced, if possible.
* Otherwise returns `expr`.
*
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
* Some bindings may be elided as an early optimization.
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
* Some bindings may be elided as an early optimization.
*
* Example:
* ```scala sc:nocompile
* ([X1, Y1, ...] => (x1, y1, ...) => ... => [Xn, Yn, ...] => (xn, yn, ...) => f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...))).apply[Tx1, Ty1, ...](myX1, myY1, ...)....apply[Txn, Tyn, ...](myXn, myYn, ...)
* ```
* will be reduced to
* ```scala sc:nocompile
* type X1 = Tx1
* type Y1 = Ty1
* ...
* val x1 = myX1
* val y1 = myY1
* ...
* type Xn = Txn
* type Yn = Tyn
* ...
* val xn = myXn
* val yn = myYn
* ...
* f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...)
* ```
*/
def betaReduce[T](expr: Expr[T])(using Quotes): Expr[T] =
import quotes.reflect._
Expand Down
33 changes: 28 additions & 5 deletions library/src/scala/quoted/Quotes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -751,14 +751,37 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
/** Methods of the module object `val Term` */
trait TermModule { this: Term.type =>

/** Returns a term that is functionally equivalent to `t`,
/** Returns a term that is functionally equivalent to `t`,
* however if `t` is of the form `((y1, ..., yn) => e2)(e1, ..., en)`
* then it optimizes this the top most call by returning the `Some`
* with the result of beta-reducing the application.
* then it optimizes the top most call by returning the `Some`
* with the result of beta-reducing the function application.
* Similarly, all outermost curried function applications will be
* beta-reduced, if possible.
* Otherwise returns `None`.
*
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
* Some bindings may be elided as an early optimization.
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
* Some bindings may be elided as an early optimization.
*
* Example:
* ```scala sc:nocompile
* ([X1, Y1, ...] => (x1, y1, ...) => ... => [Xn, Yn, ...] => (xn, yn, ...) => f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...))).apply[Tx1, Ty1, ...](myX1, myY1, ...)....apply[Txn, Tyn, ...](myXn, myYn, ...)
* ```
* will be reduced to
* ```scala sc:nocompile
* type X1 = Tx1
* type Y1 = Ty1
* ...
* val x1 = myX1
* val y1 = myY1
* ...
* type Xn = Txn
* type Yn = Tyn
* ...
* val xn = myXn
* val yn = myYn
* ...
* f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...)
* ```
*/
def betaReduce(term: Term): Option[Term]

Expand Down
76 changes: 76 additions & 0 deletions tests/pos-macros/i17506/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import scala.quoted._

def assertBetaReduction(using Quotes)(applied: Expr[Any], expected: String): quotes.reflect.Term =
import quotes.reflect._
val reducedMaybe = Term.betaReduce(applied.asTerm)
assert(reducedMaybe.isDefined)
val reduced = reducedMaybe.get
assert(reduced.show == expected)
reduced

inline def regularCurriedCtxFun2BetaReduceTest(inline f: Foo ?=> Bar ?=> Int): Unit =
${regularCurriedCtxFun2BetaReduceTestImpl('f)}
def regularCurriedCtxFun2BetaReduceTestImpl(f: Expr[Foo ?=> Bar ?=> Int])(using Quotes): Expr[Int] =
val expected =
"""|{
| val evidence$3: Bar = new Bar()
| val evidence$2: Foo = new Foo()
| 123
|}""".stripMargin
val applied = '{$f(using new Foo())(using new Bar())}
assertBetaReduction(applied, expected).asExprOf[Int]

inline def regularCurriedFun2BetaReduceTest(inline f: Foo => Bar => Int): Int =
${regularCurriedFun2BetaReduceTestImpl('f)}
def regularCurriedFun2BetaReduceTestImpl(f: Expr[Foo => Bar => Int])(using Quotes): Expr[Int] =
val expected =
"""|{
| val b: Bar = new Bar()
| val f: Foo = new Foo()
| 123
|}""".stripMargin
val applied = '{$f(new Foo())(new Bar())}
assertBetaReduction(applied, expected).asExprOf[Int]

inline def typeParamCurriedFun2BetaReduceTest(inline f: [A] => A => [B] => B => Unit): Unit =
${typeParamCurriedFun2BetaReduceTestImpl('f)}
def typeParamCurriedFun2BetaReduceTestImpl(f: Expr[[A] => (a: A) => [B] => (b: B) => Unit])(using Quotes): Expr[Unit] =
val expected =
"""|{
| type Y = Bar
| val y: Bar = new Bar()
| type X = Foo
| val x: Foo = new Foo()
| typeParamFun2[Y, X](y, x)
|}""".stripMargin
val applied = '{$f.apply[Foo](new Foo()).apply[Bar](new Bar())}
assertBetaReduction(applied, expected).asExprOf[Unit]

inline def regularCurriedFun3BetaReduceTest(inline f: Foo => Bar => Baz => Int): Int =
${regularCurriedFun3BetaReduceTestImpl('f)}
def regularCurriedFun3BetaReduceTestImpl(f: Expr[Foo => Bar => Baz => Int])(using Quotes): Expr[Int] =
val expected =
"""|{
| val i: Baz = new Baz()
| val b: Bar = new Bar()
| val f: Foo = new Foo()
| 123
|}""".stripMargin
val applied = '{$f(new Foo())(new Bar())(new Baz())}
assertBetaReduction(applied, expected).asExprOf[Int]

inline def typeParamCurriedFun3BetaReduceTest(inline f: [A] => A => [B] => B => [C] => C => Unit): Unit =
${typeParamCurriedFun3BetaReduceTestImpl('f)}
def typeParamCurriedFun3BetaReduceTestImpl(f: Expr[[A] => A => [B] => B => [C] => C => Unit])(using Quotes): Expr[Unit] =
val expected =
"""|{
| type Z = Baz
| val z: Baz = new Baz()
| type Y = Bar
| val y: Bar = new Bar()
| type X = Foo
| val x: Foo = new Foo()
| typeParamFun3[Z, Y, X](z, y, x)
|}""".stripMargin
val applied = '{$f.apply[Foo](new Foo()).apply[Bar](new Bar()).apply[Baz](new Baz())}
assertBetaReduction(applied, expected).asExprOf[Unit]
15 changes: 15 additions & 0 deletions tests/pos-macros/i17506/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
class Foo
class Bar
class Baz

@main def run() =
def typeParamFun2[A, B](a: A, b: B): Unit = println(a.toString + " " + b.toString)
def typeParamFun3[A, B, C](a: A, b: B, c: C): Unit = println(a.toString + " " + b.toString)

regularCurriedCtxFun2BetaReduceTest((f: Foo) ?=> (b: Bar) ?=> 123)
regularCurriedCtxFun2BetaReduceTest(123)
regularCurriedFun2BetaReduceTest(((f: Foo) => (b: Bar) => 123))
typeParamCurriedFun2BetaReduceTest([X] => (x: X) => [Y] => (y: Y) => typeParamFun2[Y, X](y, x))

regularCurriedFun3BetaReduceTest((f: Foo) => (b: Bar) => (i: Baz) => 123)
typeParamCurriedFun3BetaReduceTest([X] => (x: X) => [Y] => (y: Y) => [Z] => (z: Z) => typeParamFun3[Z, Y, X](z, y, x))

0 comments on commit 9cd22f3

Please sign in to comment.