diff --git a/compiler/src/dotty/tools/dotc/reporting/ErrorMessageID.scala b/compiler/src/dotty/tools/dotc/reporting/ErrorMessageID.scala index b832832590d9..6dcd5edd51a8 100644 --- a/compiler/src/dotty/tools/dotc/reporting/ErrorMessageID.scala +++ b/compiler/src/dotty/tools/dotc/reporting/ErrorMessageID.scala @@ -212,6 +212,7 @@ enum ErrorMessageID(val isActive: Boolean = true) extends java.lang.Enum[ErrorMe case ContextBoundCompanionNotValueID // errorNumber: 196 - unused in LTS case InlinedAnonClassWarningID // errorNumber: 197 case UnusedSymbolID // errorNumber: 198 + case TailrecNestedCallID //errorNumber: 199 def errorNumber = ordinal - 1 diff --git a/compiler/src/dotty/tools/dotc/reporting/messages.scala b/compiler/src/dotty/tools/dotc/reporting/messages.scala index 38c79f90a733..40f4fa36b475 100644 --- a/compiler/src/dotty/tools/dotc/reporting/messages.scala +++ b/compiler/src/dotty/tools/dotc/reporting/messages.scala @@ -1899,6 +1899,20 @@ class TailrecNotApplicable(symbol: Symbol)(using Context) def explain(using Context) = "" } +class TailrecNestedCall(definition: Symbol, innerDef: Symbol)(using Context) + extends SyntaxMsg(TailrecNestedCallID) { + def msg(using Context) = { + s"The tail recursive def ${definition.name} contains a recursive call inside the non-inlined inner def ${innerDef.name}" + } + + def explain(using Context) = + """Tail recursion is only validated and optimised directly in the definition. + |Any calls to the recursive method via an inner def cannot be validated as + |tail recursive, nor optimised if they are. To enable tail recursion from + |inner calls, mark the inner def as inline. + |""".stripMargin +} + class FailureToEliminateExistential(tp: Type, tp1: Type, tp2: Type, boundSyms: List[Symbol], classRoot: Symbol)(using Context) extends Message(FailureToEliminateExistentialID) { def kind = MessageKind.Compatibility diff --git a/compiler/src/dotty/tools/dotc/transform/TailRec.scala b/compiler/src/dotty/tools/dotc/transform/TailRec.scala index 158b72f7abdb..d0cec7865027 100644 --- a/compiler/src/dotty/tools/dotc/transform/TailRec.scala +++ b/compiler/src/dotty/tools/dotc/transform/TailRec.scala @@ -427,10 +427,23 @@ class TailRec extends MiniPhase { assert(false, "We should never have gotten inside a pattern") tree - case tree: ValOrDefDef => + case tree: ValDef => if (isMandatory) noTailTransform(tree.rhs) tree + case tree: DefDef => + if (isMandatory) + if (tree.symbol.is(Synthetic)) + noTailTransform(tree.rhs) + else + // We can't tail recurse through nested definitions, so don't want to propagate to child nodes + // We don't want to fail if there is a call that would recurse (as this would be a non self recurse), so don't + // want to call noTailTransform + // We can however warn in this case, as its likely in this situation that someone would expect a tail + // recursion optimization and enabling this to optimise would be a simple case of inlining the inner method + new NestedTailRecAlerter(method, tree.symbol).traverse(tree) + tree + case _: Super | _: This | _: Literal | _: TypeTree | _: TypeDef | EmptyTree => tree @@ -444,7 +457,8 @@ class TailRec extends MiniPhase { case Return(expr, from) => val fromSym = from.symbol - val inTailPosition = !fromSym.is(Label) || tailPositionLabeledSyms.contains(fromSym) + val inTailPosition = tailPositionLabeledSyms.contains(fromSym) // Label returns are only tail if the label is in tail position + || (fromSym eq method) // Method returns are only tail if we are looking at the original method cpy.Return(tree)(transform(expr, inTailPosition), from) case _ => @@ -452,6 +466,19 @@ class TailRec extends MiniPhase { } } } + + class NestedTailRecAlerter(method: Symbol, inner: Symbol) extends TreeTraverser { + override def traverse(tree: tpd.Tree)(using Context): Unit = + tree match { + case a: Apply => + if (a.fun.symbol eq method) { + report.warning(new TailrecNestedCall(method, inner), a.srcPos) + } + traverseChildren(tree) + case _ => + traverseChildren(tree) + } + } } object TailRec { diff --git a/tests/neg/i20105.check b/tests/neg/i20105.check new file mode 100644 index 000000000000..5fb33283387b --- /dev/null +++ b/tests/neg/i20105.check @@ -0,0 +1,10 @@ +-- [E199] Syntax Warning: tests/neg/i20105.scala:6:9 ------------------------------------------------------------------- +6 | foo() + | ^^^^^ + | The tail recursive def foo contains a recursive call inside the non-inlined inner def bar + | + | longer explanation available when compiling with `-explain` +-- [E097] Syntax Error: tests/neg/i20105.scala:3:4 --------------------------------------------------------------------- +3 |def foo(): Unit = // error + | ^ + | TailRec optimisation not applicable, method foo contains no recursive calls diff --git a/tests/neg/i20105.scala b/tests/neg/i20105.scala new file mode 100644 index 000000000000..08d54e895ec1 --- /dev/null +++ b/tests/neg/i20105.scala @@ -0,0 +1,9 @@ +import scala.annotation.tailrec +@tailrec +def foo(): Unit = // error + def bar(): Unit = + if (???) + foo() + else + bar() + bar() \ No newline at end of file diff --git a/tests/neg/i5397.scala b/tests/neg/i5397.scala index d38b0e67bff9..ebe89875b3df 100644 --- a/tests/neg/i5397.scala +++ b/tests/neg/i5397.scala @@ -16,8 +16,10 @@ object Test { rec3 // error: not in tail position }) - @tailrec def rec4: Unit = { - def local = rec4 // error: not in tail position + // This is technically not breaching tail recursion as rec4 does not call itself, local does + // This instead fails due to having no tail recursion at all + @tailrec def rec4: Unit = { // error: no recursive calls + def local = rec4 } @tailrec def rec5: Int = { diff --git a/tests/warn/i20105.check b/tests/warn/i20105.check new file mode 100644 index 000000000000..d291931748cf --- /dev/null +++ b/tests/warn/i20105.check @@ -0,0 +1,6 @@ +-- [E199] Syntax Warning: tests/warn/i20105.scala:6:9 ------------------------------------------------------------------ +6 | foo() // warn + | ^^^^^ + | The tail recursive def foo contains a recursive call inside the non-inlined inner def bar + | + | longer explanation available when compiling with `-explain` diff --git a/tests/warn/i20105.scala b/tests/warn/i20105.scala new file mode 100644 index 000000000000..6d691b7e6bfb --- /dev/null +++ b/tests/warn/i20105.scala @@ -0,0 +1,10 @@ +import scala.annotation.tailrec +@tailrec +def foo(): Unit = + def bar(): Unit = + if (???) + foo() // warn + else + bar() + bar() + foo() \ No newline at end of file