Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup retains annotations in all inferred type trees #20305

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1135,6 +1135,13 @@ class CheckCaptures extends Recheck, SymTransformer:
else actual
end adaptBoxed

/** Replace all variable capture sets with constants */
class MakeCapturesConstant(using Context) extends TypeMap with IdempotentCaptRefMap:
def apply(tp: Type): Type = tp match
case CapturingType(parent, refs: CaptureSet.Var) =>
tp.derivedCapturingType(mapOver(parent), CaptureSet(refs.elems))
case _ => mapOver(tp)

/** Check overrides again, taking capture sets into account.
* TODO: Can we avoid doing overrides checks twice?
* We need to do them here since only at this phase CaptureTypes are relevant
Expand Down Expand Up @@ -1162,7 +1169,13 @@ class CheckCaptures extends Recheck, SymTransformer:
adapted.stripCapturing
case _ => adapted
finally curEnv = saved
actual1 frozen_<:< expected1

// Make variable capture sets constant before performing the check
val makeConst = new MakeCapturesConstant
val actual2 = makeConst(actual1)
val expected2 = makeConst(expected1)

actual2 frozen_<:< expected2

override def needsCheck(overriding: Symbol, overridden: Symbol)(using Context): Boolean =
!setup.isPreCC(overriding) && !setup.isPreCC(overridden)
Expand Down
10 changes: 4 additions & 6 deletions compiler/src/dotty/tools/dotc/cc/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -329,10 +329,10 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
end transformExplicitType

/** Transform type of type tree, and remember the transformed type as the type the tree */
private def transformTT(tree: TypeTree, boxed: Boolean, exact: Boolean)(using Context): Unit =
private def transformTT(tree: TypeTree, boxed: Boolean)(using Context): Unit =
if !tree.hasRememberedType then
val transformed =
if tree.isInstanceOf[InferredTypeTree] && !exact
if tree.isInstanceOf[InferredTypeTree]
then transformInferredType(tree.tpe)
else transformExplicitType(tree.tpe, tptToCheck = Some(tree))
tree.rememberType(if boxed then box(transformed) else transformed)
Expand Down Expand Up @@ -394,8 +394,6 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
transformTT(tpt,
boxed = !ccConfig.allowUniversalInBoxed && sym.is(Mutable, butNot = Method),
// types of mutable variables are boxed in pre 3.3 codee
exact = sym.allOverriddenSymbols.hasNext,
// types of symbols that override a parent don't get a capture set TODO drop
)
catch case ex: IllegalCaptureRef =>
capt.println(i"fail while transforming result type $tpt of $sym")
Expand Down Expand Up @@ -437,7 +435,7 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
// No need to box type arguments of an asInstanceOf call. See #20224.
case _ =>
for case arg: TypeTree <- args do
transformTT(arg, boxed = true, exact = false) // type arguments in type applications are boxed
transformTT(arg, boxed = true) // type arguments in type applications are boxed

case tree: TypeDef if tree.symbol.isClass =>
inContext(ctx.withOwner(tree.symbol)):
Expand All @@ -454,7 +452,7 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:

def postProcess(tree: Tree)(using Context): Unit = tree match
case tree: TypeTree =>
transformTT(tree, boxed = false, exact = false)
transformTT(tree, boxed = false)
case tree: ValOrDefDef =>
val sym = tree.symbol

Expand Down
16 changes: 7 additions & 9 deletions compiler/src/dotty/tools/dotc/transform/PostTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -301,12 +301,9 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
// during the typer, it is infeasible to correctly infer the capture sets in most
// cases, resulting ill-formed capture sets that could crash the pickler later on.
// See #20035.
private def cleanupRetainsAnnot(symbol: Symbol, tpt: Tree)(using Context): Tree =
private def cleanupRetainsAnnot(tpt: Tree)(using Context): Tree =
tpt match
case tpt: InferredTypeTree
if !symbol.allOverriddenSymbols.hasNext =>
// if there are overridden symbols, the annotation comes from an explicit type of the overridden symbol
// and should be retained.
case tpt: InferredTypeTree =>
val tm = new CleanupRetains
val tpe1 = tm(tpt.tpe)
tpt.withType(tpe1)
Expand Down Expand Up @@ -421,7 +418,7 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
registerIfHasMacroAnnotations(tree)
checkErasedDef(tree)
Checking.checkPolyFunctionType(tree.tpt)
val tree1 = cpy.ValDef(tree)(tpt = cleanupRetainsAnnot(tree.symbol, tree.tpt), rhs = normalizeErasedRhs(tree.rhs, tree.symbol))
val tree1 = cpy.ValDef(tree)(rhs = normalizeErasedRhs(tree.rhs, tree.symbol))
if tree1.removeAttachment(desugar.UntupledParam).isDefined then
checkStableSelection(tree.rhs)
processValOrDefDef(super.transform(tree1))
Expand All @@ -431,7 +428,7 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
checkErasedDef(tree)
Checking.checkPolyFunctionType(tree.tpt)
annotateContextResults(tree)
val tree1 = cpy.DefDef(tree)(tpt = cleanupRetainsAnnot(tree.symbol, tree.tpt), rhs = normalizeErasedRhs(tree.rhs, tree.symbol))
val tree1 = cpy.DefDef(tree)(rhs = normalizeErasedRhs(tree.rhs, tree.symbol))
processValOrDefDef(superAcc.wrapDefDef(tree1)(super.transform(tree1).asInstanceOf[DefDef]))
case tree: TypeDef =>
registerIfHasMacroAnnotations(tree)
Expand Down Expand Up @@ -504,8 +501,9 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
report.error(em"type ${alias.tpe} outside bounds $bounds", tree.srcPos)
super.transform(tree)
case tree: TypeTree =>
tree.withType(
tree.tpe match {
val tree1 = cleanupRetainsAnnot(tree)
tree1.withType(
tree1.tpe match {
case AnnotatedType(tpe, annot) => AnnotatedType(tpe, transformAnnot(annot))
case tpe => tpe
}
Expand Down
27 changes: 20 additions & 7 deletions tests/neg-custom-args/captures/lazylist.check
Original file line number Diff line number Diff line change
@@ -1,10 +1,3 @@
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylist.scala:17:15 -------------------------------------
17 | def tail = xs() // error
| ^^^^
| Found: lazylists.LazyList[T]^{LazyCons.this.xs}
| Required: lazylists.LazyList[T]
|
| longer explanation available when compiling with `-explain`
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylist.scala:35:29 -------------------------------------
35 | val ref1c: LazyList[Int] = ref1 // error
| ^^^^
Expand Down Expand Up @@ -33,10 +26,30 @@
| Required: lazylists.LazyList[Int]^{cap1, ref3, cap3}
|
| longer explanation available when compiling with `-explain`
-- [E164] Declaration Error: tests/neg-custom-args/captures/lazylist.scala:17:6 ----------------------------------------
17 | def tail = xs() // error // error
| ^
| error overriding method tail in class LazyList of type -> lazylists.LazyList[T];
| method tail of type -> lazylists.LazyList[box T^?]^{LazyCons.this.xs} has incompatible type
|
| longer explanation available when compiling with `-explain`
-- [E164] Declaration Error: tests/neg-custom-args/captures/lazylist.scala:22:6 ----------------------------------------
22 | def tail: LazyList[Nothing]^ = ??? // error overriding
| ^
| error overriding method tail in class LazyList of type -> lazylists.LazyList[Nothing];
| method tail of type -> lazylists.LazyList[Nothing]^ has incompatible type
|
| longer explanation available when compiling with `-explain`
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylist.scala:17:15 -------------------------------------
17 | def tail = xs() // error // error
| ^^^^
| Found: lazylists.LazyList[T]^{LazyCons.this.xs}
| Required: lazylists.LazyList[T]
|
| Note that the expected type lazylists.LazyList[T]
| is the previously inferred result type of method tail
| which is also the type seen in separately compiled sources.
| The new inferred type lazylists.LazyList[T]^{LazyCons.this.xs}
| must conform to this type.
|
| longer explanation available when compiling with `-explain`
2 changes: 1 addition & 1 deletion tests/neg-custom-args/captures/lazylist.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ abstract class LazyList[+T]:
class LazyCons[+T](val x: T, val xs: () => LazyList[T]^) extends LazyList[T]:
def isEmpty = false
def head = x
def tail = xs() // error
def tail = xs() // error // error

object LazyNil extends LazyList[Nothing]:
def isEmpty = true
Expand Down
20 changes: 20 additions & 0 deletions tests/pos-custom-args/captures/i20272a.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import language.experimental.captureChecking

trait Iterable[T] { self: Iterable[T]^ =>
def map[U](f: T => U): Iterable[U]^{this, f}
}

object Test {
def assertEquals[A, B](a: A, b: B): Boolean = ???

def foo[T](level: Int, lines: Iterable[T]) =
lines.map(x => x)

def bar(messages: Iterable[String]) =
foo(1, messages)

val it: Iterable[String] = ???
val msgs = bar(it)

assertEquals(msgs, msgs)
}
16 changes: 16 additions & 0 deletions tests/pos-custom-args/captures/i20272b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import language.experimental.captureChecking

trait Iterable[T] { self: Iterable[T]^ =>
def map[U](f: T => U): Iterable[U]^{this, f}
}

object Test {
def foo[T](level: Int, lines: Iterable[T]) =
lines.map(x => x)

class Bar:
def bar(messages: Iterable[String]) =
foo(1, messages)
class Baz extends Bar:
override def bar(messages: Iterable[String]) = ???
}
Loading