Skip to content

Commit

Permalink
Named tuples second implementation (#19174)
Browse files Browse the repository at this point in the history
This implementation follows the alternative representation scheme, where
a named tuple type is represented as a
pair of two tuples: one for the names, the other for the values. 

Compare with #19075, where named tupes were regular types, with special
element types that combine name and value.

In both cases, we use an opaque type alias so that named tuples are
represented at runtime by just their values - the names are forgotten.

The new implementation has some advantages

- we can control in the type that named and unnamed elements are not
mixed,
 - no element types are leaked to user code,
- non-sensical operations such as concatenating a named and an unnamed
tuple are statically excluded,
- it's generally easier to enforce well-formedness constraints on the
type level.

The main disadvantage compared to #19075 is that there is a certain
amount of duplication in types and methods between `Tuple` and
`NamedTuple`. On the other hand, we can make sure by this that no
non-sensical tuple operations are accidentally applied to named tuples.
  • Loading branch information
odersky authored May 7, 2024
2 parents 5854959 + f80a8dd commit 7e27c4b
Show file tree
Hide file tree
Showing 61 changed files with 2,365 additions and 292 deletions.
110 changes: 94 additions & 16 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ import Decorators.*
import Annotations.Annotation
import NameKinds.{UniqueName, ContextBoundParamName, ContextFunctionParamName, DefaultGetterName, WildcardParamName}
import typer.{Namer, Checking}
import util.{Property, SourceFile, SourcePosition, Chars}
import util.{Property, SourceFile, SourcePosition, SrcPos, Chars}
import config.Feature.{sourceVersion, migrateTo3, enabled}
import config.SourceVersion.*
import collection.mutable.ListBuffer
import collection.mutable
import reporting.*
import annotation.constructorOnly
import printing.Formatting.hl
Expand Down Expand Up @@ -234,7 +234,7 @@ object desugar {

private def elimContextBounds(meth: DefDef, isPrimaryConstructor: Boolean)(using Context): DefDef =
val DefDef(_, paramss, tpt, rhs) = meth
val evidenceParamBuf = ListBuffer[ValDef]()
val evidenceParamBuf = mutable.ListBuffer[ValDef]()

var seenContextBounds: Int = 0
def desugarContextBounds(rhs: Tree): Tree = rhs match
Expand Down Expand Up @@ -1254,8 +1254,9 @@ object desugar {
pats.forall(isVarPattern)
case _ => false
}

val isMatchingTuple: Tree => Boolean = {
case Tuple(es) => isTuplePattern(es.length)
case Tuple(es) => isTuplePattern(es.length) && !hasNamedArg(es)
case _ => false
}

Expand Down Expand Up @@ -1441,22 +1442,99 @@ object desugar {
AppliedTypeTree(
TypeTree(defn.throwsAlias.typeRef).withSpan(op.span), tpt :: excepts :: Nil)

/** Translate tuple expressions of arity <= 22
private def checkWellFormedTupleElems(elems: List[Tree])(using Context): List[Tree] =
val seen = mutable.Set[Name]()
for case arg @ NamedArg(name, _) <- elems do
if seen.contains(name) then
report.error(em"Duplicate tuple element name", arg.srcPos)
seen += name
if name.startsWith("_") && name.toString.tail.toIntOption.isDefined then
report.error(
em"$name cannot be used as the name of a tuple element because it is a regular tuple selector",
arg.srcPos)

elems match
case elem :: elems1 =>
val mismatchOpt =
if elem.isInstanceOf[NamedArg]
then elems1.find(!_.isInstanceOf[NamedArg])
else elems1.find(_.isInstanceOf[NamedArg])
mismatchOpt match
case Some(misMatch) =>
report.error(em"Illegal combination of named and unnamed tuple elements", misMatch.srcPos)
elems.mapConserve(stripNamedArg)
case None => elems
case _ => elems
end checkWellFormedTupleElems

/** Translate tuple expressions
*
* () ==> ()
* (t) ==> t
* (t1, ..., tN) ==> TupleN(t1, ..., tN)
*/
def smallTuple(tree: Tuple)(using Context): Tree = {
val ts = tree.trees
val arity = ts.length
assert(arity <= Definitions.MaxTupleArity)
def tupleTypeRef = defn.TupleType(arity).nn
if (arity == 0)
if (ctx.mode is Mode.Type) TypeTree(defn.UnitType) else unitLiteral
else if (ctx.mode is Mode.Type) AppliedTypeTree(ref(tupleTypeRef), ts)
else Apply(ref(tupleTypeRef.classSymbol.companionModule.termRef), ts)
}
def tuple(tree: Tuple, pt: Type)(using Context): Tree =
var elems = checkWellFormedTupleElems(tree.trees)
if ctx.mode.is(Mode.Pattern) then elems = adaptPatternArgs(elems, pt)
val elemValues = elems.mapConserve(stripNamedArg)
val tup =
val arity = elems.length
if arity <= Definitions.MaxTupleArity then
def tupleTypeRef = defn.TupleType(arity).nn
val tree1 =
if arity == 0 then
if ctx.mode is Mode.Type then TypeTree(defn.UnitType) else unitLiteral
else if ctx.mode is Mode.Type then AppliedTypeTree(ref(tupleTypeRef), elemValues)
else Apply(ref(tupleTypeRef.classSymbol.companionModule.termRef), elemValues)
tree1.withSpan(tree.span)
else
cpy.Tuple(tree)(elemValues)
val names = elems.collect:
case NamedArg(name, arg) => name
if names.isEmpty || ctx.mode.is(Mode.Pattern) then
tup
else
def namesTuple = withModeBits(ctx.mode &~ Mode.Pattern | Mode.Type):
tuple(Tuple(
names.map: name =>
SingletonTypeTree(Literal(Constant(name.toString))).withSpan(tree.span)),
WildcardType)
if ctx.mode.is(Mode.Type) then
AppliedTypeTree(ref(defn.NamedTupleTypeRef), namesTuple :: tup :: Nil)
else
TypeApply(
Apply(Select(ref(defn.NamedTupleModule), nme.withNames), tup),
namesTuple :: Nil)

/** When desugaring a list pattern arguments `elems` adapt them and the
* expected type `pt` to each other. This means:
* - If `elems` are named pattern elements, rearrange them to match `pt`.
* This requires all names in `elems` to be also present in `pt`.
*/
def adaptPatternArgs(elems: List[Tree], pt: Type)(using Context): List[Tree] =

def reorderedNamedArgs(wildcardSpan: Span): List[untpd.Tree] =
var selNames = pt.namedTupleElementTypes.map(_(0))
if selNames.isEmpty && pt.classSymbol.is(CaseClass) then
selNames = pt.classSymbol.caseAccessors.map(_.name.asTermName)
val nameToIdx = selNames.zipWithIndex.toMap
val reordered = Array.fill[untpd.Tree](selNames.length):
untpd.Ident(nme.WILDCARD).withSpan(wildcardSpan)
for case arg @ NamedArg(name: TermName, _) <- elems do
nameToIdx.get(name) match
case Some(idx) =>
if reordered(idx).isInstanceOf[Ident] then
reordered(idx) = arg
else
report.error(em"Duplicate named pattern", arg.srcPos)
case _ =>
report.error(em"No element named `$name` is defined in selector type $pt", arg.srcPos)
reordered.toList

elems match
case (first @ NamedArg(_, _)) :: _ => reorderedNamedArgs(first.span.startPos)
case _ => elems
end adaptPatternArgs

private def isTopLevelDef(stat: Tree)(using Context): Boolean = stat match
case _: ValDef | _: PatDef | _: DefDef | _: Export | _: ExtMethods => true
Expand Down Expand Up @@ -1990,7 +2068,7 @@ object desugar {
* without duplicates
*/
private def getVariables(tree: Tree, shouldAddGiven: Context ?=> Bind => Boolean)(using Context): List[VarInfo] = {
val buf = ListBuffer[VarInfo]()
val buf = mutable.ListBuffer[VarInfo]()
def seenName(name: Name) = buf exists (_._1.name == name)
def add(named: NameTree, t: Tree): Unit =
if (!seenName(named.name) && named.name.isTermName) buf += ((named, t))
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ trait TreeInfo[T <: Untyped] { self: Trees.Instance[T] =>
case _ =>
tree

def stripNamedArg(tree: Tree) = tree match
case NamedArg(_, arg) => arg
case _ => tree

/** The number of arguments in an application */
def numArgs(tree: Tree): Int = unsplice(tree) match {
case Apply(fn, args) => numArgs(fn) + args.length
Expand Down
10 changes: 5 additions & 5 deletions compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
def forwardTo: Tree = t
}
case class Tuple(trees: List[Tree])(implicit @constructorOnly src: SourceFile) extends Tree {
override def isTerm: Boolean = trees.isEmpty || trees.head.isTerm
override def isTerm: Boolean = trees.isEmpty || stripNamedArg(trees.head).isTerm
override def isType: Boolean = !isTerm
}
case class Throw(expr: Tree)(implicit @constructorOnly src: SourceFile) extends TermTree
Expand Down Expand Up @@ -528,15 +528,15 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
def makeSelfDef(name: TermName, tpt: Tree)(using Context): ValDef =
ValDef(name, tpt, EmptyTree).withFlags(PrivateLocal)

def makeTupleOrParens(ts: List[Tree])(using Context): Tree = ts match {
def makeTupleOrParens(ts: List[Tree])(using Context): Tree = ts match
case (t: NamedArg) :: Nil => Tuple(t :: Nil)
case t :: Nil => Parens(t)
case _ => Tuple(ts)
}

def makeTuple(ts: List[Tree])(using Context): Tree = ts match {
def makeTuple(ts: List[Tree])(using Context): Tree = ts match
case (t: NamedArg) :: Nil => Tuple(t :: Nil)
case t :: Nil => t
case _ => Tuple(ts)
}

def makeAndType(left: Tree, right: Tree)(using Context): AppliedTypeTree =
AppliedTypeTree(ref(defn.andType.typeRef), left :: right :: Nil)
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/config/Feature.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ object Feature:
val pureFunctions = experimental("pureFunctions")
val captureChecking = experimental("captureChecking")
val into = experimental("into")
val namedTuples = experimental("namedTuples")

def experimentalAutoEnableFeatures(using Context): List[TermName] =
defn.languageExperimentalFeatures
Expand Down
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/core/Contexts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,14 @@ object Contexts {
inline def atPhaseNoEarlier[T](limit: Phase)(inline op: Context ?=> T)(using Context): T =
op(using if !limit.exists || limit <= ctx.phase then ctx else ctx.withPhase(limit))

inline private def inMode[T](mode: Mode)(inline op: Context ?=> T)(using ctx: Context): T =
inline def withModeBits[T](mode: Mode)(inline op: Context ?=> T)(using ctx: Context): T =
op(using if mode != ctx.mode then ctx.fresh.setMode(mode) else ctx)

inline def withMode[T](mode: Mode)(inline op: Context ?=> T)(using ctx: Context): T =
inMode(ctx.mode | mode)(op)
withModeBits(ctx.mode | mode)(op)

inline def withoutMode[T](mode: Mode)(inline op: Context ?=> T)(using ctx: Context): T =
inMode(ctx.mode &~ mode)(op)
withModeBits(ctx.mode &~ mode)(op)

/** A context is passed basically everywhere in dotc.
* This is convenient but carries the risk of captured contexts in
Expand Down
17 changes: 16 additions & 1 deletion compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -949,6 +949,9 @@ class Definitions {
def TupleXXL_fromIterator(using Context): Symbol = TupleXXLModule.requiredMethod("fromIterator")
def TupleXXL_unapplySeq(using Context): Symbol = TupleXXLModule.requiredMethod(nme.unapplySeq)

@tu lazy val NamedTupleModule = requiredModule("scala.NamedTuple")
@tu lazy val NamedTupleTypeRef: TypeRef = NamedTupleModule.termRef.select(tpnme.NamedTuple).asInstanceOf

@tu lazy val RuntimeTupleMirrorTypeRef: TypeRef = requiredClassRef("scala.runtime.TupleMirror")

@tu lazy val RuntimeTuplesModule: Symbol = requiredModule("scala.runtime.Tuples")
Expand Down Expand Up @@ -1304,9 +1307,20 @@ class Definitions {
case ByNameFunction(_) => true
case _ => false

object NamedTuple:
def apply(nmes: Type, vals: Type)(using Context): Type =
AppliedType(NamedTupleTypeRef, nmes :: vals :: Nil)
def unapply(t: Type)(using Context): Option[(Type, Type)] = t match
case AppliedType(tycon, nmes :: vals :: Nil) if tycon.typeSymbol == NamedTupleTypeRef.symbol =>
Some((nmes, vals))
case _ => None

final def isCompiletime_S(sym: Symbol)(using Context): Boolean =
sym.name == tpnme.S && sym.owner == CompiletimeOpsIntModuleClass

final def isNamedTuple_From(sym: Symbol)(using Context): Boolean =
sym.name == tpnme.From && sym.owner == NamedTupleModule.moduleClass

private val compiletimePackageAnyTypes: Set[Name] = Set(
tpnme.Equals, tpnme.NotEquals, tpnme.IsConst, tpnme.ToString
)
Expand Down Expand Up @@ -1335,7 +1349,7 @@ class Definitions {
tpnme.Plus, tpnme.Length, tpnme.Substring, tpnme.Matches, tpnme.CharAt
)
private val compiletimePackageOpTypes: Set[Name] =
Set(tpnme.S)
Set(tpnme.S, tpnme.From)
++ compiletimePackageAnyTypes
++ compiletimePackageIntTypes
++ compiletimePackageLongTypes
Expand All @@ -1348,6 +1362,7 @@ class Definitions {
compiletimePackageOpTypes.contains(sym.name)
&& (
isCompiletime_S(sym)
|| isNamedTuple_From(sym)
|| sym.owner == CompiletimeOpsAnyModuleClass && compiletimePackageAnyTypes.contains(sym.name)
|| sym.owner == CompiletimeOpsIntModuleClass && compiletimePackageIntTypes.contains(sym.name)
|| sym.owner == CompiletimeOpsLongModuleClass && compiletimePackageLongTypes.contains(sym.name)
Expand Down
5 changes: 5 additions & 0 deletions compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,8 @@ object StdNames {
val EnumValue: N = "EnumValue"
val ExistentialTypeTree: N = "ExistentialTypeTree"
val Flag : N = "Flag"
val Fields: N = "Fields"
val From: N = "From"
val Ident: N = "Ident"
val Import: N = "Import"
val Literal: N = "Literal"
Expand All @@ -374,6 +376,7 @@ object StdNames {
val MirroredMonoType: N = "MirroredMonoType"
val MirroredType: N = "MirroredType"
val Modifiers: N = "Modifiers"
val NamedTuple: N = "NamedTuple"
val NestedAnnotArg: N = "NestedAnnotArg"
val NoFlags: N = "NoFlags"
val NoPrefix: N = "NoPrefix"
Expand Down Expand Up @@ -620,6 +623,7 @@ object StdNames {
val throws: N = "throws"
val toArray: N = "toArray"
val toList: N = "toList"
val toTuple: N = "toTuple"
val toObjectArray : N = "toObjectArray"
val toSeq: N = "toSeq"
val toString_ : N = "toString"
Expand Down Expand Up @@ -649,6 +653,7 @@ object StdNames {
val wildcardType: N = "wildcardType"
val withFilter: N = "withFilter"
val withFilterIfRefutable: N = "withFilterIfRefutable$"
val withNames: N = "withNames"
val WorksheetWrapper: N = "WorksheetWrapper"
val wrap: N = "wrap"
val writeReplace: N = "writeReplace"
Expand Down
28 changes: 25 additions & 3 deletions compiler/src/dotty/tools/dotc/core/TypeEval.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@ import Types.*, Contexts.*, Symbols.*, Constants.*, Decorators.*
import config.Printers.typr
import reporting.trace
import StdNames.tpnme
import Flags.CaseClass
import TypeOps.nestedPairs

object TypeEval:

def tryCompiletimeConstantFold(tp: AppliedType)(using Context): Type = tp.tycon match
case tycon: TypeRef if defn.isCompiletimeAppliedType(tycon.symbol) =>

extension (tp: Type) def fixForEvaluation: Type =
tp.normalized.dealias match
// enable operations for constant singleton terms. E.g.:
Expand Down Expand Up @@ -94,6 +97,22 @@ object TypeEval:
throw TypeError(em"${e.getMessage.nn}")
ConstantType(Constant(result))

def fieldsOf: Option[Type] =
expectArgsNum(1)
val arg = tp.args.head
val cls = arg.classSymbol
if cls.is(CaseClass) then
val fields = cls.caseAccessors
val fieldLabels = fields.map: field =>
ConstantType(Constant(field.name.toString))
val fieldTypes = fields.map(arg.memberInfo)
Some:
defn.NamedTupleTypeRef.appliedTo:
nestedPairs(fieldLabels) :: nestedPairs(fieldTypes) :: Nil
else arg.widenDealias match
case arg @ defn.NamedTuple(_, _) => Some(arg)
case _ => None

def constantFold1[T](extractor: Type => Option[T], op: T => Any): Option[Type] =
expectArgsNum(1)
extractor(tp.args.head).map(a => runConstantOp(op(a)))
Expand Down Expand Up @@ -122,11 +141,14 @@ object TypeEval:
yield runConstantOp(op(a, b, c))

trace(i"compiletime constant fold $tp", typr, show = true) {
val name = tycon.symbol.name
val owner = tycon.symbol.owner
val sym = tycon.symbol
val name = sym.name
val owner = sym.owner
val constantType =
if defn.isCompiletime_S(tycon.symbol) then
if defn.isCompiletime_S(sym) then
constantFold1(natValue, _ + 1)
else if defn.isNamedTuple_From(sym) then
fieldsOf
else if owner == defn.CompiletimeOpsAnyModuleClass then name match
case tpnme.Equals => constantFold2(constValue, _ == _)
case tpnme.NotEquals => constantFold2(constValue, _ != _)
Expand Down
7 changes: 6 additions & 1 deletion compiler/src/dotty/tools/dotc/core/TypeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,12 @@ object TypeOps:
(tp.tp1.dealias, tp.tp2.dealias) match
case (tp1 @ AppliedType(tycon1, args1), tp2 @ AppliedType(tycon2, args2))
if tycon1.typeSymbol == tycon2.typeSymbol && (tycon1 =:= tycon2) =>
mergeRefinedOrApplied(tp1, tp2)
mergeRefinedOrApplied(tp1, tp2) match
case tp: AppliedType if tp.isUnreducibleWild =>
// fall back to or-dominators rather than inferring a type that would
// cause an unreducible type error later.
approximateOr(tp1, tp2)
case tp => tp
case (tp1, tp2) =>
approximateOr(tp1, tp2)
case _ =>
Expand Down
Loading

0 comments on commit 7e27c4b

Please sign in to comment.