Skip to content

Commit

Permalink
Add support for companion in MacroAnnotations
Browse files Browse the repository at this point in the history
  • Loading branch information
hamzaremmal committed Apr 23, 2024
1 parent 54d67e0 commit 58fe2ac
Show file tree
Hide file tree
Showing 75 changed files with 630 additions and 350 deletions.
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/CompilationUnit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import config.{SourceVersion, Feature}
import StdNames.nme
import scala.annotation.internal.sharable
import scala.util.control.NoStackTrace
import transform.MacroAnnotations
import transform.MacroAnnotations.isMacroAnnotation

class CompilationUnit protected (val source: SourceFile, val info: CompilationUnitInfo | Null) {

Expand Down Expand Up @@ -193,7 +193,7 @@ object CompilationUnit {
case _ =>
case _ =>
for annot <- tree.symbol.annotations do
if MacroAnnotations.isMacroAnnotation(annot) then
if annot.isMacroAnnotation then
ctx.compilationUnit.hasMacroAnnotations = true
traverseChildren(tree)
}
Expand Down
76 changes: 76 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/TreeMapWithTrackedStats.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package dotty.tools.dotc
package ast

import tpd.*
import core.Contexts.*
import core.Symbols.*
import util.Property

import scala.collection.mutable

// It is safe to assume that the companion of a tree is in the same scope
// Therefore, when expanding MacroAnnotations, we will only keep track of
// the trees in the same scope as the current transformed tree

abstract class TreeMapWithTrackedStats extends TreeMapWithImplicits:

import TreeMapWithTrackedStats.*

/** Fetch the corresponding tracked tree for a given symbol */
protected final def getTracked(sym: Symbol)(using Context): Option[MemberDef] =
for trees <- ctx.property(TrackedTrees)
tree <- trees.get(sym)
yield tree

/** Update the tracked trees */
protected final def updateTracked(tree: Tree)(using Context): Tree =
tree match
case tree: MemberDef =>
trackedTrees.update(tree.symbol, tree)
tree
case _ => tree
end updateTracked

/** Process a list of trees and give the priority to trakced trees */
private final def withUpdatedTrackedTrees(stats: List[Tree])(using Context) =
val trackedTrees = TreeMapWithTrackedStats.trackedTrees
stats.mapConserve:
case tree: MemberDef if trackedTrees.contains(tree.symbol) =>
trackedTrees(tree.symbol)
case stat => stat

override def transform(tree: Tree)(using Context): Tree =
tree match
case PackageDef(_, stats) =>
inContext(trackedDefinitionsCtx(stats)): // Step I: Collect and memoize all the definition trees
// Step II: Transform the tree
val pkg@PackageDef(pid, stats) = super.transform(tree): @unchecked
// Step III: Reconcile between the symbols in syms and the tree
cpy.PackageDef(pkg)(pid = pid, stats = withUpdatedTrackedTrees(stats))
case block: Block =>
inContext(trackedDefinitionsCtx(block.stats)): // Step I: Collect all the member definitions in the block
// Step II: Transform the tree
val b@Block(stats, expr) = super.transform(tree): @unchecked
// Step III: Reconcile between the symbols in syms and the tree
cpy.Block(b)(expr = expr, stats = withUpdatedTrackedTrees(stats))
case TypeDef(_, impl: Template) =>
inContext(trackedDefinitionsCtx(impl.body)): // Step I: Collect and memoize all the stats
// Step II: Transform the tree
val newTree@TypeDef(name, impl: Template) = super.transform(tree): @unchecked
// Step III: Reconcile between the symbols in syms and the tree
cpy.TypeDef(newTree)(rhs = cpy.Template(impl)(body = withUpdatedTrackedTrees(impl.body)))
case _ => super.transform(tree)

end TreeMapWithTrackedStats

object TreeMapWithTrackedStats:
private val TrackedTrees = new Property.Key[mutable.Map[Symbol, tpd.MemberDef]]

/** Fetch the tracked trees in the cuurent context */
private def trackedTrees(using Context): mutable.Map[Symbol, MemberDef] =
ctx.property(TrackedTrees).get

/** Build a context and track the provided MemberDef trees */
private def trackedDefinitionsCtx(stats: List[Tree])(using Context): Context =
val treesToTrack = stats.collect { case m: MemberDef => (m.symbol, m) }
ctx.fresh.setProperty(TrackedTrees, mutable.Map(treesToTrack*))
91 changes: 60 additions & 31 deletions compiler/src/dotty/tools/dotc/transform/Inlining.scala
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
package dotty.tools.dotc
package transform

import ast.tpd
import ast.Trees.*
import ast.TreeMapWithTrackedStats
import core.*
import Flags.*
import Decorators.*
import Contexts.*
import Symbols.*
import Decorators.*
import config.Printers.inlining
import DenotTransformers.IdentityDenotTransformer
import MacroAnnotations.hasMacroAnnotation
import inlines.Inlines
import quoted.*
import staging.StagingLevel
import util.Property

import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.ast.Trees.*
import dotty.tools.dotc.quoted.*
import dotty.tools.dotc.inlines.Inlines
import dotty.tools.dotc.ast.TreeMapWithImplicits
import dotty.tools.dotc.core.DenotTransformers.IdentityDenotTransformer
import dotty.tools.dotc.staging.StagingLevel

import scala.collection.mutable.ListBuffer
import scala.collection.mutable

/** Inlines all calls to inline methods that are not in an inline method or a quote */
class Inlining extends MacroTransform, IdentityDenotTransformer {
Expand Down Expand Up @@ -56,38 +60,21 @@ class Inlining extends MacroTransform, IdentityDenotTransformer {

def newTransformer(using Context): Transformer = new Transformer {
override def transform(tree: tpd.Tree)(using Context): tpd.Tree =
new InliningTreeMap().transform(tree)
InliningTreeMap().transform(tree)
}

private class InliningTreeMap extends TreeMapWithImplicits {
private class InliningTreeMap extends TreeMapWithTrackedStats {

/** List of top level classes added by macro annotation in a package object.
* These are added to the PackageDef that owns this particular package object.
*/
private val newTopClasses = MutableSymbolMap[ListBuffer[Tree]]()
private val newTopClasses = MutableSymbolMap[mutable.ListBuffer[Tree]]()

override def transform(tree: Tree)(using Context): Tree = {
tree match
case tree: MemberDef =>
if tree.symbol.is(Inline) then tree
else if tree.symbol.is(Param) then super.transform(tree)
else if
!tree.symbol.isPrimaryConstructor
&& StagingLevel.level == 0
&& MacroAnnotations.hasMacroAnnotation(tree.symbol)
then
val trees = (new MacroAnnotations(self)).expandAnnotations(tree)
val trees1 = trees.map(super.transform)

// Find classes added to the top level from a package object
val (topClasses, trees2) =
if ctx.owner.isPackageObject then trees1.partition(_.symbol.owner == ctx.owner.owner)
else (Nil, trees1)
if topClasses.nonEmpty then
newTopClasses.getOrElseUpdate(ctx.owner.owner, new ListBuffer) ++= topClasses

flatTree(trees2)
else super.transform(tree)
// Fetch the latest tracked tree (It might have already been transformed by its companion)
transformMemberDef(getTracked(tree.symbol).getOrElse(tree))
case _: Typed | _: Block =>
super.transform(tree)
case _: PackageDef =>
Expand All @@ -113,7 +100,49 @@ class Inlining extends MacroTransform, IdentityDenotTransformer {
else Inlines.inlineCall(tree1)
else super.transform(tree)
}

private def transformMemberDef(tree: MemberDef)(using Context) : Tree =
if tree.symbol.is(Inline) then tree
else if tree.symbol.is(Param) then
super.transform(tree)
else if
!tree.symbol.isPrimaryConstructor
&& StagingLevel.level == 0
&& tree.symbol.hasMacroAnnotation
then
// Fetch the companion's tree
val companionSym =
if tree.symbol.is(ModuleClass) then tree.symbol.companionClass
else if tree.symbol.is(ModuleVal) then NoSymbol
else tree.symbol.companionModule.moduleClass

// Expand and process MacroAnnotations
val companion = getTracked(companionSym)
val (trees, newCompanion) = MacroAnnotations.expandAnnotations(tree, companion)

// Enter the new symbols & Update the tracked trees
(newCompanion.toList ::: trees).foreach: tree =>
MacroAnnotations.enterMissingSymbols(tree, self)
updateTracked(tree)

// Perform inlining on the expansion of the annotations
val trees1 = trees.map(super.transform)
trees1.foreach(updateTracked)
if newCompanion ne companion then
newCompanion.map(super.transform).foreach(updateTracked)
// Find classes added to the top level from a package object
val (topClasses, trees2) =
if ctx.owner.isPackageObject then trees1.partition(_.symbol.owner == ctx.owner.owner)
else (Nil, trees1)
if topClasses.nonEmpty then
newTopClasses.getOrElseUpdate(ctx.owner.owner, new mutable.ListBuffer) ++= topClasses
flatTree(trees2)
else
updateTracked(super.transform(tree))
end transformMemberDef

}

}

object Inlining:
Expand Down
Loading

0 comments on commit 58fe2ac

Please sign in to comment.