Skip to content

Commit

Permalink
Fix pattern matching for get matches
Browse files Browse the repository at this point in the history
  • Loading branch information
odersky committed Dec 13, 2023
1 parent 9076944 commit 11c65aa
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 88 deletions.
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ object PatternMatcher {
assert(isGetMatch(unappType))
val argsPlan = {
val get = ref(unappResult).select(nme.get, _.info.isParameterless)
val arity = productArity(get.tpe, unapp.srcPos)
val arity = productArity(get.tpe.stripNamedTuple, unapp.srcPos)
if (isUnapplySeq)
letAbstract(get) { getResult =>
if (arity > 0) unapplyProductSeqPlan(getResult, args, arity)
Expand All @@ -389,7 +389,7 @@ object PatternMatcher {
letAbstract(get) { getResult =>
val selectors =
if (args.tail.isEmpty) ref(getResult) :: Nil
else productSelectors(get.tpe).map(ref(getResult).select(_))
else productSelectors(getResult.info).map(ref(getResult).select(_))
matchArgsPlan(selectors, args, onSuccess)
}
}
Expand Down
151 changes: 96 additions & 55 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import Names.*
import StdNames.*
import ContextOps.*
import NameKinds.DefaultGetterName
import Typer.tryEither
import ProtoTypes.*
import Inferencing.*
import reporting.*
Expand Down Expand Up @@ -134,37 +135,37 @@ object Applications {
sels.takeWhile(_.exists).toList
}

def getUnapplySelectors(tp: Type, args: List[untpd.Tree], pos: SrcPos)(using Context): List[Type] =
if (args.length > 1 && !(tp.derivesFrom(defn.SeqClass))) {
val sels = productSelectorTypes(tp, pos)
if (sels.length == args.length) sels
else tp :: Nil
}
else tp :: Nil

def productSeqSelectors(tp: Type, argsNum: Int, pos: SrcPos)(using Context): List[Type] = {
val selTps = productSelectorTypes(tp, pos)
val arity = selTps.length
val elemTp = unapplySeqTypeElemTp(selTps.last)
(0 until argsNum).map(i => if (i < arity - 1) selTps(i) else elemTp).toList
}

def unapplyArgs(unapplyResult: Type, unapplyFn: Tree, args: List[untpd.Tree], pos: SrcPos)(using Context): List[Type] =
def getName(fn: Tree): Name =
/** A utility class that matches results of unapplys with patterns. Two queriable members:
* val argTypes: List[Type]
* def typedPatterns(qual: untpd.Tree, typer: Typer): List[Tree]
* TODO: Move into Applications trait. No need to keep it outside. But it's a large
* refactor, so do this when the rest is merged.
*/
class UnapplyArgs(unapplyResult: Type, unapplyFn: Tree, unadaptedArgs: List[untpd.Tree], pos: SrcPos)(using Context):
private var args = unadaptedArgs

private def getName(fn: Tree): Name =
fn match
case TypeApply(fn, _) => getName(fn)
case Apply(fn, _) => getName(fn)
case fn: RefTree => fn.name
val unapplyName = getName(unapplyFn) // tolerate structural `unapply`, which does not have a symbol
private val unapplyName = getName(unapplyFn) // tolerate structural `unapply`, which does not have a symbol

def getTp = extractorMemberType(unapplyResult, nme.get, pos)
private def getTp = extractorMemberType(unapplyResult, nme.get, pos)

def fail = {
private def fail = {
report.error(UnapplyInvalidReturnType(unapplyResult, unapplyName), pos)
Nil
}

def unapplySeq(tp: Type)(fallback: => List[Type]): List[Type] =
private def unapplySeq(tp: Type)(fallback: => List[Type]): List[Type] =
val elemTp = unapplySeqTypeElemTp(tp)
if elemTp.exists then
args.map(Function.const(elemTp))
Expand All @@ -174,26 +175,84 @@ object Applications {
tp.tupleElementTypes.getOrElse(Nil)
else fallback

if unapplyName == nme.unapplySeq then
unapplySeq(unapplyResult):
if (isGetMatch(unapplyResult, pos)) unapplySeq(getTp)(fail)
else fail
else
assert(unapplyName == nme.unapply)
if isProductMatch(unapplyResult, args.length, pos) then
productSelectorTypes(unapplyResult, pos)
else if isGetMatch(unapplyResult, pos) then
getUnapplySelectors(getTp, args, pos)
else if unapplyResult.derivesFrom(defn.BooleanClass) then
Nil
else if defn.isProductSubType(unapplyResult) && productArity(unapplyResult, pos) != 0 then
productSelectorTypes(unapplyResult, pos)
// this will cause a "wrong number of arguments in pattern" error later on,
// which is better than the message in `fail`.
else if unapplyResult.derivesFrom(defn.NonEmptyTupleClass) then
unapplyResult.tupleElementTypes.getOrElse(Nil)
else fail
end unapplyArgs
private def tryAdaptPatternArgs(elems: List[untpd.Tree], pt: Type)(using Context): Option[List[untpd.Tree]] =
tryEither[Option[List[untpd.Tree]]]
(Some(desugar.adaptPatternArgs(elems, pt)))
((_, _) => None)

private def getUnapplySelectors(tp: Type)(using Context): List[Type] =
if args.length > 1 && !(tp.derivesFrom(defn.SeqClass)) then
productUnapplySelectors(tp).getOrElse:
// There are unapplys with return types which have `get` and `_1, ..., _n`
// as members, but which are not subtypes of Product. So `productUnapplySelectors`
// would return None for these, but they are still valid types
// for a get match. A test case is pos/extractors.scala.
val sels = productSelectorTypes(tp, pos)
if (sels.length == args.length) sels
else tp :: Nil
else tp :: Nil

private def productUnapplySelectors(tp: Type)(using Context): Option[List[Type]] =
if defn.isProductSubType(tp) then
tryAdaptPatternArgs(args, tp) match
case Some(args1) if isProductMatch(tp, args1.length, pos) =>
args = args1
Some(productSelectorTypes(tp, pos))
case _ => None
else tp.widen.normalized.dealias match
case tp @ defn.NamedTuple(_, tt) =>
tryAdaptPatternArgs(args, tp) match
case Some(args1) =>
args = args1
tt.tupleElementTypes
case _ => None
case _ => None

/** The computed argument types which will be the scutinees of the sub-patterns. */
val argTypes: List[Type] =
if unapplyName == nme.unapplySeq then
unapplySeq(unapplyResult):
if (isGetMatch(unapplyResult, pos)) unapplySeq(getTp)(fail)
else fail
else
assert(unapplyName == nme.unapply)
productUnapplySelectors(unapplyResult).getOrElse:
if isGetMatch(unapplyResult, pos) then
getUnapplySelectors(getTp)
else if unapplyResult.derivesFrom(defn.BooleanClass) then
Nil
else if unapplyResult.derivesFrom(defn.NonEmptyTupleClass) then
unapplyResult.tupleElementTypes.getOrElse(Nil)
else if defn.isProductSubType(unapplyResult) && productArity(unapplyResult, pos) != 0 then
productSelectorTypes(unapplyResult, pos)
// this will cause a "wrong number of arguments in pattern" error later on,
// which is better than the message in `fail`.
else fail

/** The typed pattens of this unapply */
def typedPatterns(qual: untpd.Tree, typer: Typer): List[Tree] =
unapp.println(i"unapplyQual = $qual, unapplyArgs = ${unapplyResult} with $argTypes / $args")
for argType <- argTypes do
assert(!isBounds(argType), unapplyResult.show)
val alignedArgs = argTypes match
case argType :: Nil
if args.lengthCompare(1) > 0
&& Feature.autoTuplingEnabled
&& defn.isTupleNType(argType) =>
untpd.Tuple(args) :: Nil
case _ =>
args
val alignedArgTypes =
if argTypes.length == alignedArgs.length then
argTypes
else
report.error(UnapplyInvalidNumberOfArguments(qual, argTypes), pos)
argTypes.take(args.length) ++
List.fill(argTypes.length - args.length)(WildcardType)
alignedArgs.lazyZip(alignedArgTypes).map(typer.typed(_, _))
.showing(i"unapply patterns = $result", unapp)

end UnapplyArgs

def wrapDefs(defs: mutable.ListBuffer[Tree] | Null, tree: Tree)(using Context): Tree =
if (defs != null && defs.nonEmpty) tpd.Block(defs.toList, tree) else tree
Expand Down Expand Up @@ -1452,28 +1511,10 @@ trait Applications extends Compatibility {
loop(unapp)
res.result()
}
val args = desugar.adaptPatternArgs(unadaptedArgs, unapplyApp.tpe)

var argTypes = unapplyArgs(unapplyApp.tpe.stripNamedTuple, unapplyFn, args, tree.srcPos)
unapp.println(i"unapplyArgs = ${unapplyApp.tpe} with $argTypes / $args")
for (argType <- argTypes) assert(!isBounds(argType), unapplyApp.tpe.show)
val bunchedArgs = argTypes match {
case argType :: Nil =>
if args.lengthCompare(1) > 0
&& Feature.autoTuplingEnabled
&& defn.isTupleNType(argType)
then untpd.Tuple(args) :: Nil
else args
case _ => args
}
if (argTypes.length != bunchedArgs.length) {
report.error(UnapplyInvalidNumberOfArguments(qual, argTypes), tree.srcPos)
argTypes = argTypes.take(args.length) ++
List.fill(argTypes.length - args.length)(WildcardType)
}
val unapplyPatterns = bunchedArgs.lazyZip(argTypes) map (typed(_, _))

val unapplyPatterns = UnapplyArgs(unapplyApp.tpe, unapplyFn, unadaptedArgs, tree.srcPos)
.typedPatterns(qual, this)
val result = assignType(cpy.UnApply(tree)(unapplyFn, unapplyImplicits(unapplyApp), unapplyPatterns), ownType)
unapp.println(s"unapply patterns = $unapplyPatterns")
if (ownType.stripped eq selType.stripped) || ownType.isError then result
else tryWithTypeTest(Typed(result, TypeTree(ownType)), selType)
case tp =>
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/Checking.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import config.Printers.{typr, patmatch}
import NameKinds.DefaultGetterName
import NameOps.*
import SymDenotations.{NoCompleter, NoDenotation}
import Applications.unapplyArgs
import Applications.UnapplyArgs
import Inferencing.isFullyDefined
import transform.patmat.SpaceEngine.{isIrrefutable, isIrrefutableQuotePattern}
import transform.ValueClasses.underlyingOfValueClass
Expand Down Expand Up @@ -952,7 +952,7 @@ trait Checking {
case UnApply(fn, implicits, pats) =>
check(pat, pt) &&
(isIrrefutable(fn, pats.length) || fail(pat, pt, Reason.RefutableExtractor)) && {
val argPts = unapplyArgs(fn.tpe.widen.finalResultType, fn, pats, pat.srcPos)
val argPts = UnapplyArgs(fn.tpe.widen.finalResultType, fn, pats, pat.srcPos).argTypes
pats.corresponds(argPts)(recur)
}
case Alternative(pats) =>
Expand Down
50 changes: 25 additions & 25 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,31 @@ object Typer {
def rememberSearchFailure(tree: tpd.Tree, fail: SearchFailure) =
tree.putAttachment(HiddenSearchFailure,
fail :: tree.attachmentOrElse(HiddenSearchFailure, Nil))

def tryEither[T](op: Context ?=> T)(fallBack: (T, TyperState) => T)(using Context): T = {
val nestedCtx = ctx.fresh.setNewTyperState()
val result = op(using nestedCtx)
if (nestedCtx.reporter.hasErrors && !nestedCtx.reporter.hasStickyErrors) {
record("tryEither.fallBack")
fallBack(result, nestedCtx.typerState)
}
else {
record("tryEither.commit")
nestedCtx.typerState.commit()
result
}
}

/** Try `op1`, if there are errors, try `op2`, if `op2` also causes errors, fall back
* to errors and result of `op1`.
*/
def tryAlternatively[T](op1: Context ?=> T)(op2: Context ?=> T)(using Context): T =
tryEither(op1) { (failedVal, failedState) =>
tryEither(op2) { (_, _) =>
failedState.commit()
failedVal
}
}
}
/** Typecheck trees, the main entry point is `typed`.
*
Expand Down Expand Up @@ -3461,31 +3486,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
def typedPattern(tree: untpd.Tree, selType: Type = WildcardType)(using Context): Tree =
withMode(Mode.Pattern)(typed(tree, selType))

def tryEither[T](op: Context ?=> T)(fallBack: (T, TyperState) => T)(using Context): T = {
val nestedCtx = ctx.fresh.setNewTyperState()
val result = op(using nestedCtx)
if (nestedCtx.reporter.hasErrors && !nestedCtx.reporter.hasStickyErrors) {
record("tryEither.fallBack")
fallBack(result, nestedCtx.typerState)
}
else {
record("tryEither.commit")
nestedCtx.typerState.commit()
result
}
}

/** Try `op1`, if there are errors, try `op2`, if `op2` also causes errors, fall back
* to errors and result of `op1`.
*/
def tryAlternatively[T](op1: Context ?=> T)(op2: Context ?=> T)(using Context): T =
tryEither(op1) { (failedVal, failedState) =>
tryEither(op2) { (_, _) =>
failedState.commit()
failedVal
}
}

/** Is `pt` a prototype of an `apply` selection, or a parameterless function yielding one? */
def isApplyProto(pt: Type)(using Context): Boolean = pt.revealIgnored match {
case pt: SelectionProto => pt.name == nme.apply
Expand Down
5 changes: 5 additions & 0 deletions tests/run/named-patterns.check
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ name Bob
age 22
age 22, name Bob
Bob, 22
name Bob, age 22
name (Bob,22)
age (Bob,22)
age 22, name Bob
Bob, 22
1003 Lausanne, Rue de la Gare 44
1003 Lausanne
Rue de la Gare in Lausanne
Expand Down
20 changes: 16 additions & 4 deletions tests/run/named-patterns.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ object Test1:
object Person:
def unapply(p: Person): (name: String, age: Int) = (p.name, p.age)

class Person2(val name: String, val age: Int)
object Person2:
def unapply(p: Person2): Option[(name: String, age: Int)] = Some((p.name, p.age))

case class Address(city: String, zip: Int, street: String, number: Int)

@main def Test =
Expand All @@ -21,6 +25,18 @@ object Test1:
bob match
case Person(age, name) => println(s"$age, $name")

val bob2 = Person2("Bob", 22)
bob2 match
case Person2(name = n, age = a) => println(s"name $n, age $a")
bob2 match
case Person2(name = n) => println(s"name $n")
bob2 match
case Person2(age = a) => println(s"age $a")
bob2 match
case Person2(age = a, name = n) => println(s"age $a, name $n")
bob2 match
case Person2(age, name) => println(s"$age, $name")

val addr = Address("Lausanne", 1003, "Rue de la Gare", 44)
addr match
case Address(city = c, zip = z, street = s, number = n) =>
Expand All @@ -37,7 +53,3 @@ object Test1:
addr match
case Address(c, z, s, number) =>
println(s"$z $c, $s $number")




0 comments on commit 11c65aa

Please sign in to comment.