Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #21619: Refactor NotNullInfo to record every reference which is retracted once. #21624

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions compiler/src/dotty/tools/dotc/typer/Nullables.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,10 @@ object Nullables:
val hiTree = if(hiTpe eq hi.typeOpt) hi else TypeTree(hiTpe)
TypeBoundsTree(lo, hiTree, alias)

/** A set of val or var references that are known to be not null, plus a set of
* variable references that are not known (anymore) to be not null
/** A set of val or var references that are known to be not null,
* plus a set of variable references that are once assigned to null.
*/
case class NotNullInfo(asserted: Set[TermRef], retracted: Set[TermRef]):
assert((asserted & retracted).isEmpty)

def isEmpty = this eq NotNullInfo.empty

def retractedInfo = NotNullInfo(Set(), retracted)
Expand All @@ -67,15 +65,18 @@ object Nullables:
if this.isEmpty then that
else if that.isEmpty then this
else NotNullInfo(
this.asserted.union(that.asserted).diff(that.retracted),
this.retracted.union(that.retracted).diff(that.asserted))
this.asserted.diff(that.retracted).union(that.asserted),
this.retracted.union(that.retracted))

/** The alternative path combination with another not-null info. Used to merge
* the nullability info of the two branches of an if.
*/
def alt(that: NotNullInfo): NotNullInfo =
NotNullInfo(this.asserted.intersect(that.asserted), this.retracted.union(that.retracted))

def withRetracted(that: NotNullInfo): NotNullInfo =
NotNullInfo(this.asserted, this.retracted.union(that.retracted))

object NotNullInfo:
val empty = new NotNullInfo(Set(), Set())
def apply(asserted: Set[TermRef], retracted: Set[TermRef]): NotNullInfo =
Expand Down Expand Up @@ -233,16 +234,13 @@ object Nullables:
* or retractions in `info` supersede infos in existing entries of `infos`.
*/
def extendWith(info: NotNullInfo) =
if info.isEmpty
|| info.asserted.forall(infos.impliesNotNull(_))
&& !info.retracted.exists(infos.impliesNotNull(_))
then infos
if info.isEmpty then infos
else info :: infos

/** Retract all references to mutable variables */
def retractMutables(using Context) =
val mutables = infos.foldLeft(Set[TermRef]())((ms, info) =>
ms.union(info.asserted.filter(_.symbol.is(Mutable))))
val mutables = infos.foldLeft(Set[TermRef]()):
(ms, info) => ms.union(info.asserted.filter(_.symbol.is(Mutable)))
infos.extendWith(NotNullInfo(Set(), mutables))

end extension
Expand Down
37 changes: 24 additions & 13 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1543,8 +1543,10 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
def thenPathInfo = cond1.notNullInfoIf(true).seq(result.thenp.notNullInfo)
def elsePathInfo = cond1.notNullInfoIf(false).seq(result.elsep.notNullInfo)
result.withNotNullInfo(
if result.thenp.tpe.isRef(defn.NothingClass) then elsePathInfo
else if result.elsep.tpe.isRef(defn.NothingClass) then thenPathInfo
if result.thenp.tpe.isRef(defn.NothingClass) then
elsePathInfo.withRetracted(thenPathInfo)
else if result.elsep.tpe.isRef(defn.NothingClass) then
thenPathInfo.withRetracted(elsePathInfo)
else thenPathInfo.alt(elsePathInfo)
)
end typedIf
Expand Down Expand Up @@ -2151,9 +2153,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
def typedMatchFinish(tree: untpd.Match, sel: Tree, wideSelType: Type, cases: List[untpd.CaseDef], pt: Type)(using Context): Tree = {
val cases1 = harmonic(harmonize, pt)(typedCases(cases, sel, wideSelType, pt.dropIfProto))
.asInstanceOf[List[CaseDef]]
var nni = sel.notNullInfo
if cases1.nonEmpty then nni = nni.seq(cases1.map(_.notNullInfo).reduce(_.alt(_)))
assignType(cpy.Match(tree)(sel, cases1), sel, cases1).withNotNullInfo(nni)
var nnInfo = sel.notNullInfo
if cases1.nonEmpty then nnInfo = nnInfo.seq(cases1.map(_.notNullInfo).reduce(_.alt(_)))
assignType(cpy.Match(tree)(sel, cases1), sel, cases1).withNotNullInfo(nnInfo)
}

def typedCases(cases: List[untpd.CaseDef], sel: Tree, wideSelType0: Type, pt: Type)(using Context): List[CaseDef] =
Expand Down Expand Up @@ -2335,7 +2337,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val capabilityProof = caughtExceptions.reduce(OrType(_, _, true))
untpd.Block(makeCanThrow(capabilityProof), expr)

def typedTry(tree: untpd.Try, pt: Type)(using Context): Try = {
def typedTry(tree: untpd.Try, pt: Type)(using Context): Try =
var nnInfo = NotNullInfo.empty
val expr2 :: cases2x = harmonic(harmonize, pt) {
// We want to type check tree.expr first to comput NotNullInfo, but `addCanThrowCapabilities`
// uses the types of patterns in `tree.cases` to determine the capabilities.
Expand All @@ -2347,18 +2350,26 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val casesEmptyBody1 = tree.cases.mapconserve(cpy.CaseDef(_)(body = EmptyTree))
val casesEmptyBody2 = typedCases(casesEmptyBody1, EmptyTree, defn.ThrowableType, WildcardType)
val expr1 = typed(addCanThrowCapabilities(tree.expr, casesEmptyBody2), pt.dropIfProto)
val casesCtx = ctx.addNotNullInfo(expr1.notNullInfo.retractedInfo)

// Since we don't know at which point the the exception is thrown in the body,
// we have to collect any reference that is once retracted.
nnInfo = expr1.notNullInfo.retractedInfo

val casesCtx = ctx.addNotNullInfo(nnInfo)
val cases1 = typedCases(tree.cases, EmptyTree, defn.ThrowableType, pt.dropIfProto)(using casesCtx)
expr1 :: cases1
}: @unchecked
val cases2 = cases2x.asInstanceOf[List[CaseDef]]

var nni = expr2.notNullInfo.retractedInfo
if cases2.nonEmpty then nni = nni.seq(cases2.map(_.notNullInfo.retractedInfo).reduce(_.alt(_)))
val finalizer1 = typed(tree.finalizer, defn.UnitType)(using ctx.addNotNullInfo(nni))
nni = nni.seq(finalizer1.notNullInfo)
assignType(cpy.Try(tree)(expr2, cases2, finalizer1), expr2, cases2).withNotNullInfo(nni)
}
// It is possible to have non-exhaustive cases, and some exceptions are thrown and not caught.
// Therefore, the code in the finallizer and after the try block can only rely on the retracted
// info from the cases' body.
if cases2.nonEmpty then
nnInfo = nnInfo.seq(cases2.map(_.notNullInfo.retractedInfo).reduce(_.alt(_)))

val finalizer1 = typed(tree.finalizer, defn.UnitType)(using ctx.addNotNullInfo(nnInfo))
nnInfo = nnInfo.seq(finalizer1.notNullInfo)
assignType(cpy.Try(tree)(expr2, cases2, finalizer1), expr2, cases2).withNotNullInfo(nnInfo)

def typedTry(tree: untpd.ParsedTry, pt: Type)(using Context): Try =
val cases: List[untpd.CaseDef] = tree.handler match
Expand Down
6 changes: 3 additions & 3 deletions tests/explicit-nulls/neg/i21380c.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ def test4: Int =
case npe: NullPointerException => x = ""
case _ => x = ""
x.length // error
// Although the catch block here is exhaustive,
// it is possible that the exception is thrown and not caught.
// Therefore, the code after the try block can only rely on the retracted info.
// Although the catch block here is exhaustive, it is possible to have non-exhaustive cases,
// and some exceptions are thrown and not caught. Therefore, the code in the finallizer and
// after the try block can only rely on the retracted info from the cases' body.

def test5: Int =
var x: String | Null = null
Expand Down
79 changes: 79 additions & 0 deletions tests/explicit-nulls/neg/i21619.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
def test1: String =
var x: String | Null = null
x = ""
var i: Int = 1
try
i match
case _ =>
x = null
throw new Exception()
x = ""
catch
case e: Exception =>
x.replace("", "") // error

def test2: String =
var x: String | Null = null
x = ""
var i: Int = 1
try
i match
case _ =>
x = null
throw new Exception()
x = ""
catch
case e: Exception =>
x = "e"
x.replace("", "") // error

def test3: String =
var x: String | Null = null
x = ""
var i: Int = 1
try
i match
case _ =>
x = null
throw new Exception()
x = ""
catch
case e: Exception =>
finally
x = "f"
x.replace("", "") // ok

def test4: String =
var x: String | Null = null
x = ""
var i: Int = 1
try
try
if i == 1 then
x = null
throw new Exception()
else
x = ""
catch
case _ =>
x = ""
catch
case _ =>
x.replace("", "") // error

def test5: Unit =
var x: String | Null = null
var y: String | Null = null
x = ""
y = ""
var i: Int = 1
try
i match
case _ =>
x = null
throw new Exception()
x = ""
catch
case _ =>
val z1: String = x.replace("", "") // error
val z2: String = y.replace("", "")
Loading