From 131f2e371d9c1782721bc6228f5a2885c396be60 Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Thu, 18 Apr 2024 14:13:04 +0200 Subject: [PATCH] Fix mapping and pickling of annotated types MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `Annotation.mapWith` maps an `Annotation` with a type map `tm`. Before actually applying `tm` to the annotation’s `tree`, it first checks if `tm` would result in any change by applying it to the types of the annotation’s arguments, and checking if the mapped types are different. This optimization had two problems: it didn’t include type parameters, and used `frozen_=:=` to compare types, which failed to detected some changes. This commit changes `Annotation.arguments` to also include type parameters, and, and changes `Annotation.MapWith` to use `==` to compare types instead of `frozen_=:=`. Furthermore, in case of changes, the symbol in the annotation's tree should be copied to make sure that the same symbol is not used for different trees. This commit achieves this by using a custom `TreeTypeMap` with an overridden `withMappedSyms` method where `Symbols.mapSymbols` is called with the argument `mapAlways = true`. Finally, positons of trees that appear inside `AnnotatedType` only were not pickled. This commit also fixes this. --- .../src/dotty/tools/dotc/ast/TreeInfo.scala | 4 +-- .../dotty/tools/dotc/core/Annotations.scala | 19 ++++++++--- .../dotc/core/tasty/PositionPickler.scala | 4 +++ .../tools/dotc/core/tasty/TreePickler.scala | 7 ++++ .../tools/dotc/quoted/PickledQuotes.scala | 2 +- .../dotty/tools/dotc/transform/Pickler.scala | 2 +- tests/pos/annot-17939.scala | 7 ++++ tests/pos/annot-17939b.scala | 10 ++++++ tests/pos/annot-19846.scala | 8 +++++ tests/pos/annot-19846b.scala | 7 ++++ tests/pos/annot-5789.scala | 10 ++++++ tests/printing/annot-18064.check | 16 +++++++++ tests/printing/annot-18064.scala | 7 ++++ tests/printing/annot-19846b.check | 33 +++++++++++++++++++ tests/printing/annot-19846b.scala | 7 ++++ 15 files changed, 134 insertions(+), 9 deletions(-) create mode 100644 tests/pos/annot-17939.scala create mode 100644 tests/pos/annot-17939b.scala create mode 100644 tests/pos/annot-19846.scala create mode 100644 tests/pos/annot-19846b.scala create mode 100644 tests/pos/annot-5789.scala create mode 100644 tests/printing/annot-18064.check create mode 100644 tests/printing/annot-18064.scala create mode 100644 tests/printing/annot-19846b.check create mode 100644 tests/printing/annot-19846b.scala diff --git a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala index a1bba544cc06..1ba86a66aae0 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala @@ -134,10 +134,10 @@ trait TreeInfo[T <: Untyped] { self: Trees.Instance[T] => case _ => argss loop(tree, Nil) - /** All term arguments of an application in a single flattened list */ + /** All type and term arguments of an application in a single flattened list */ def allArguments(tree: Tree): List[Tree] = unsplice(tree) match { case Apply(fn, args) => allArguments(fn) ::: args - case TypeApply(fn, _) => allArguments(fn) + case TypeApply(fn, args) => allArguments(fn) ::: args case Block(_, expr) => allArguments(expr) case _ => Nil } diff --git a/compiler/src/dotty/tools/dotc/core/Annotations.scala b/compiler/src/dotty/tools/dotc/core/Annotations.scala index a5ef4c26eed1..2e7945d78059 100644 --- a/compiler/src/dotty/tools/dotc/core/Annotations.scala +++ b/compiler/src/dotty/tools/dotc/core/Annotations.scala @@ -3,8 +3,9 @@ package dotc package core import Symbols.*, Types.*, Contexts.*, Constants.*, Phases.* -import ast.tpd, tpd.* -import util.Spans.Span +import ast.{tpd, untpd, TreeTypeMap} +import tpd.* +import util.Spans.{Span, NoSpan} import printing.{Showable, Printer} import printing.Texts.Text @@ -30,7 +31,7 @@ object Annotations { def derivedAnnotation(tree: Tree)(using Context): Annotation = if (tree eq this.tree) this else Annotation(tree) - /** All arguments to this annotation in a single flat list */ + /** All type and term arguments to this annotation in a single flat list */ def arguments(using Context): List[Tree] = tpd.allArguments(tree) def argument(i: Int)(using Context): Option[Tree] = { @@ -57,15 +58,23 @@ object Annotations { val args = arguments if args.isEmpty then this else + // Checks if tm would result in any change by applying on the annotations's argument and checking if the resulting types are different. val findDiff = new TreeAccumulator[Type]: def apply(x: Type, tree: Tree)(using Context): Type = if tm.isRange(x) then x else val tp1 = tm(tree.tpe) - foldOver(if tp1 frozen_=:= tree.tpe then x else tp1, tree) + foldOver(if tp1 == tree.tpe then x else tp1, tree) val diff = findDiff(NoType, args) if tm.isRange(diff) then EmptyAnnotation - else if diff.exists then derivedAnnotation(tm.mapOver(tree)) + else if diff.exists then + // In case of changes, the symbol in the annotation's tree should be + // copied so that the same symbol is not used for different trees. + val ttm = + new TreeTypeMap(typeMap = tm): + final override def withMappedSyms(syms: List[Symbol]): TreeTypeMap = + withMappedSyms(syms, mapSymbols(syms, this, mapAlways = true)) + derivedAnnotation(ttm.transform(tree)) else this /** Does this annotation refer to a parameter of `tl`? */ diff --git a/compiler/src/dotty/tools/dotc/core/tasty/PositionPickler.scala b/compiler/src/dotty/tools/dotc/core/tasty/PositionPickler.scala index 86076517021a..3d8080e72a29 100644 --- a/compiler/src/dotty/tools/dotc/core/tasty/PositionPickler.scala +++ b/compiler/src/dotty/tools/dotc/core/tasty/PositionPickler.scala @@ -33,6 +33,7 @@ object PositionPickler: pickler: TastyPickler, addrOfTree: TreeToAddr, treeAnnots: untpd.MemberDef => List[tpd.Tree], + typeAnnots: List[tpd.Tree], relativePathReference: String, source: SourceFile, roots: List[Tree], @@ -136,6 +137,9 @@ object PositionPickler: } for (root <- roots) traverse(root, NoSource) + + for annotTree <- typeAnnots do + traverse(annotTree, NoSource) end picklePositions end PositionPickler diff --git a/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala b/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala index 0a8669292a74..1cede78c96f4 100644 --- a/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala +++ b/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala @@ -40,6 +40,10 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) { */ private val annotTrees = util.EqHashMap[untpd.MemberDef, mutable.ListBuffer[Tree]]() + /** A set of annotation trees appearing in annotated types. + */ + private val annotatedTypeTrees = mutable.ListBuffer[Tree]() + /** A map from member definitions to their doc comments, so that later * parallel comment pickling does not need to access symbols of trees (which * would involve accessing symbols of named types and possibly changing phases @@ -56,6 +60,8 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) { val ts = annotTrees.lookup(tree) if ts == null then Nil else ts.toList + def typeAnnots: List[Tree] = annotatedTypeTrees.toList + def docString(tree: untpd.MemberDef): Option[Comment] = Option(docStrings.lookup(tree)) @@ -266,6 +272,7 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) { case tpe: AnnotatedType => writeByte(ANNOTATEDtype) withLength { pickleType(tpe.parent, richTypes); pickleTree(tpe.annot.tree) } + annotatedTypeTrees += tpe.annot.tree case tpe: AndType => writeByte(ANDtype) withLength { pickleType(tpe.tp1, richTypes); pickleType(tpe.tp2, richTypes) } diff --git a/compiler/src/dotty/tools/dotc/quoted/PickledQuotes.scala b/compiler/src/dotty/tools/dotc/quoted/PickledQuotes.scala index 8ebd1f6973f2..db40283076aa 100644 --- a/compiler/src/dotty/tools/dotc/quoted/PickledQuotes.scala +++ b/compiler/src/dotty/tools/dotc/quoted/PickledQuotes.scala @@ -224,7 +224,7 @@ object PickledQuotes { if tree.span.exists then val positionWarnings = new mutable.ListBuffer[Message]() val reference = ctx.settings.sourceroot.value - PositionPickler.picklePositions(pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, reference, + PositionPickler.picklePositions(pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, treePkl.typeAnnots, reference, ctx.compilationUnit.source, tree :: Nil, positionWarnings) positionWarnings.foreach(report.warning(_)) diff --git a/compiler/src/dotty/tools/dotc/transform/Pickler.scala b/compiler/src/dotty/tools/dotc/transform/Pickler.scala index 3a4212547d16..6841b9b686a8 100644 --- a/compiler/src/dotty/tools/dotc/transform/Pickler.scala +++ b/compiler/src/dotty/tools/dotc/transform/Pickler.scala @@ -143,7 +143,7 @@ class Pickler extends Phase { if tree.span.exists then val reference = ctx.settings.sourceroot.value PositionPickler.picklePositions( - pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, reference, + pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, treePkl.typeAnnots, reference, unit.source, tree :: Nil, positionWarnings, scratch.positionBuffer, scratch.pickledIndices) diff --git a/tests/pos/annot-17939.scala b/tests/pos/annot-17939.scala new file mode 100644 index 000000000000..2b3adf0ac1cc --- /dev/null +++ b/tests/pos/annot-17939.scala @@ -0,0 +1,7 @@ +class qualified[T](f: T => Boolean) extends annotation.StaticAnnotation + +class Box[T](val x: T) +class Box2(val x: Int) + +class A(a: String @qualified((x: Int) => Box(3).x == 3)) // crash +class A2(a2: String @qualified((x: Int) => Box2(3).x == 3)) // works diff --git a/tests/pos/annot-17939b.scala b/tests/pos/annot-17939b.scala new file mode 100644 index 000000000000..a48f4690d0b2 --- /dev/null +++ b/tests/pos/annot-17939b.scala @@ -0,0 +1,10 @@ +import scala.annotation.Annotation +class myRefined(f: ? => Boolean) extends Annotation + +def test(axes: Int) = true + +trait Tensor: + def mean(axes: Int): Int @myRefined(_ => test(axes)) + +class TensorImpl() extends Tensor: + def mean(axes: Int) = ??? diff --git a/tests/pos/annot-19846.scala b/tests/pos/annot-19846.scala new file mode 100644 index 000000000000..09c24a5cf3cf --- /dev/null +++ b/tests/pos/annot-19846.scala @@ -0,0 +1,8 @@ +class qualified[T](predicate: T => Boolean) extends annotation.StaticAnnotation + +class EqualPair(val x: Int, val y: Int @qualified[Int](it => it == x)) + +@main def main = + val p = EqualPair(42, 42) + val y = p.y + println(42) diff --git a/tests/pos/annot-19846b.scala b/tests/pos/annot-19846b.scala new file mode 100644 index 000000000000..81f25065d980 --- /dev/null +++ b/tests/pos/annot-19846b.scala @@ -0,0 +1,7 @@ +class qualified[T](predicate: T => Boolean) extends annotation.StaticAnnotation + +def f(x: Int): Int @qualified[Int](it => it == x) = ??? + +@main def main = + val z = f(42) + () diff --git a/tests/pos/annot-5789.scala b/tests/pos/annot-5789.scala new file mode 100644 index 000000000000..bdf4438c9d5d --- /dev/null +++ b/tests/pos/annot-5789.scala @@ -0,0 +1,10 @@ +class Annot[T] extends scala.annotation.Annotation + +class D[T](val f: Int@Annot[T]) + +object A{ + def main(a:Array[String]) = { + val c = new D[Int](1) + c.f + } +} diff --git a/tests/printing/annot-18064.check b/tests/printing/annot-18064.check new file mode 100644 index 000000000000..d93ddb95afee --- /dev/null +++ b/tests/printing/annot-18064.check @@ -0,0 +1,16 @@ +[[syntax trees at end of typer]] // tests/printing/annot-18064.scala +package { + class myAnnot[T >: Nothing <: Any]() extends annotation.Annotation() { + T + } + trait Tensor[T >: Nothing <: Any]() extends Object { + T + def add: Tensor[Tensor.this.T] @myAnnot[T] + } + class TensorImpl[A >: Nothing <: Any]() extends Object(), Tensor[ + TensorImpl.this.A] { + A + def add: Tensor[A] @myAnnot[A] = this + } +} + diff --git a/tests/printing/annot-18064.scala b/tests/printing/annot-18064.scala new file mode 100644 index 000000000000..95554fd3a1b7 --- /dev/null +++ b/tests/printing/annot-18064.scala @@ -0,0 +1,7 @@ +class myAnnot[T]() extends annotation.Annotation + +trait Tensor[T]: + def add: Tensor[T] @myAnnot[T]() + +class TensorImpl[A]() extends Tensor[A]: + def add /* : Tensor[A] @myAnnot[A] */ = this diff --git a/tests/printing/annot-19846b.check b/tests/printing/annot-19846b.check new file mode 100644 index 000000000000..3f63a46c4286 --- /dev/null +++ b/tests/printing/annot-19846b.check @@ -0,0 +1,33 @@ +[[syntax trees at end of typer]] // tests/printing/annot-19846b.scala +package { + class lambdaAnnot(g: () => Int) extends scala.annotation.Annotation(), + annotation.StaticAnnotation { + private[this] val g: () => Int + } + final lazy module val Test: Test = new Test() + final module class Test() extends Object() { this: Test.type => + val y: Int = ??? + val z: + Int @lambdaAnnot( + { + def $anonfun(): Int = Test.y + closure($anonfun) + } + ) + = f(Test.y) + } + final lazy module val annot-19846b$package: annot-19846b$package = + new annot-19846b$package() + final module class annot-19846b$package() extends Object() { + this: annot-19846b$package.type => + def f(x: Int): + Int @lambdaAnnot( + { + def $anonfun(): Int = x + closure($anonfun) + } + ) + = x + } +} + diff --git a/tests/printing/annot-19846b.scala b/tests/printing/annot-19846b.scala new file mode 100644 index 000000000000..951a3c8116ff --- /dev/null +++ b/tests/printing/annot-19846b.scala @@ -0,0 +1,7 @@ +class lambdaAnnot(g: () => Int) extends annotation.StaticAnnotation + +def f(x: Int): Int @lambdaAnnot(() => x) = x + +object Test: + val y: Int = ??? + val z /* : Int @lambdaAnnot(() => y) */ = f(y)