Skip to content

Commit

Permalink
Merge pull request #1556 from hylo-lang/fix-canonicalization
Browse files Browse the repository at this point in the history
Fix canonicalization of associated types
  • Loading branch information
kyouko-taiga authored Aug 17, 2024
2 parents 3a2840a + a43ac74 commit 61cf91c
Show file tree
Hide file tree
Showing 7 changed files with 168 additions and 89 deletions.
128 changes: 62 additions & 66 deletions Sources/CodeGen/LLVM/TypeLowering.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,33 @@ import Utils

extension IR.Program {

/// Returns the LLVM form of `val` in `module`.
/// Returns the LLVM form of `t` in `module`.
///
/// - Requires: `val` is representable in LLVM.
func llvm<T: TypeProtocol>(_ val: T, in module: inout SwiftyLLVM.Module) -> SwiftyLLVM.IRType {
switch val {
case let t as AnyType:
return llvm(t.base, in: &module)
case let t as ArrowType:
return llvm(arrowType: t, in: &module)
case let t as BufferType:
return llvm(bufferType: t, in: &module)
case let t as BuiltinType:
return llvm(builtinType: t, in: &module)
case let t as BoundGenericType:
return llvm(boundGenericType: t, in: &module)
/// - Requires: `t` is representable in LLVM.
func llvm<T: TypeProtocol>(_ t: T, in module: inout SwiftyLLVM.Module) -> SwiftyLLVM.IRType {
switch t {
case let u as AnyType:
return llvm(u.base, in: &module)
case let u as ArrowType:
return llvm(arrowType: u, in: &module)
case let u as BufferType:
return llvm(bufferType: u, in: &module)
case let u as BuiltinType:
return llvm(builtinType: u, in: &module)
case let u as BoundGenericType:
return llvm(boundGenericType: u, in: &module)
case is MetatypeType:
return module.ptr
case let t as ProductType:
return llvm(productType: t, in: &module)
case let u as ProductType:
return llvm(productType: u, in: &module)
case is RemoteType:
return module.ptr
case let t as TupleType:
return llvm(tupleType: t, in: &module)
case let t as UnionType:
return llvm(unionType: t, in: &module)
case let u as TupleType:
return llvm(tupleType: u, in: &module)
case let u as UnionType:
return llvm(unionType: u, in: &module)
default:
notLLVMRepresentable(val)
notLLVMRepresentable(t)
}
}

Expand All @@ -44,22 +44,22 @@ extension IR.Program {
return SwiftyLLVM.StructType([module.ptr, e], in: &module)
}

/// Returns the LLVM form of `val` in `module`.
/// Returns the LLVM form of `t` in `module`.
///
/// - Requires: `val` is representable in LLVM.
func llvm(bufferType val: BufferType, in module: inout SwiftyLLVM.Module) -> SwiftyLLVM.IRType {
let e = llvm(val.element, in: &module)
guard let n = val.count.asCompilerKnown(Int.self) else {
notLLVMRepresentable(val)
/// - Requires: `t` is representable in LLVM.
func llvm(bufferType t: BufferType, in module: inout SwiftyLLVM.Module) -> SwiftyLLVM.IRType {
let e = llvm(t.element, in: &module)
guard let n = t.count.asCompilerKnown(Int.self) else {
notLLVMRepresentable(t)
}
return SwiftyLLVM.ArrayType(n, e, in: &module)
}

/// Returns the LLVM form of `val` in `module`.
/// Returns the LLVM form of `t` in `module`.
///
/// - Requires: `val` is representable in LLVM.
func llvm(builtinType val: BuiltinType, in module: inout SwiftyLLVM.Module) -> SwiftyLLVM.IRType {
switch val {
/// - Requires: `t` is representable in LLVM.
func llvm(builtinType t: BuiltinType, in module: inout SwiftyLLVM.Module) -> SwiftyLLVM.IRType {
switch t {
case .i(let width):
return SwiftyLLVM.IntegerType(width, in: &module)
case .word:
Expand All @@ -75,25 +75,21 @@ extension IR.Program {
case .ptr:
return module.ptr
case .module:
notLLVMRepresentable(val)
notLLVMRepresentable(t)
}
}

/// Returns the LLVM form of `val` in `module`.
/// Returns the LLVM form of `t` in `module`.
///
/// - Requires: `val` is representable in LLVM.
/// - Requires: `t` is representable in LLVM.
func llvm(
boundGenericType val: BoundGenericType, in module: inout SwiftyLLVM.Module
boundGenericType t: BoundGenericType, in module: inout SwiftyLLVM.Module
) -> SwiftyLLVM.IRType {
precondition(val[.isCanonical])
precondition(t[.isCanonical])

let fields = base.storage(of: val.base).map { (part) in
let z = GenericArguments(val)
let u = base.specialize(part.type, for: z, in: AnyScopeID(base.ast.coreLibrary!))
return llvm(u, in: &module)
}
let fields = base.storage(of: t).map({ (p) in llvm(p.type, in: &module) })

switch val.base.base {
switch t.base.base {
case let u as ProductType:
return SwiftyLLVM.StructType(named: u.name.value, fields, in: &module)
case is TupleType:
Expand All @@ -103,19 +99,19 @@ extension IR.Program {
}
}

/// Returns the LLVM form of `val` in `module`.
/// Returns the LLVM form of `t` in `module`.
///
/// - Requires: `val` is representable in LLVM.
func llvm(productType val: ProductType, in module: inout SwiftyLLVM.Module) -> SwiftyLLVM.IRType {
precondition(val[.isCanonical])

let n = base.mangled(val)
if let t = module.type(named: n) {
assert(SwiftyLLVM.StructType(t) != nil)
return t
/// - Requires: `t` is representable in LLVM.
func llvm(productType t: ProductType, in module: inout SwiftyLLVM.Module) -> SwiftyLLVM.IRType {
precondition(t[.isCanonical])

let n = base.mangled(t)
if let u = module.type(named: n) {
assert(SwiftyLLVM.StructType(u) != nil)
return u
}

let l = AbstractTypeLayout(of: val, definedIn: base)
let l = AbstractTypeLayout(of: t, definedIn: base)
var fields: [SwiftyLLVM.IRType] = []
for p in l.properties {
fields.append(llvm(p.type, in: &module))
Expand All @@ -124,35 +120,35 @@ extension IR.Program {
return SwiftyLLVM.StructType(named: n, fields, in: &module)
}

/// Returns the LLVM form of `val` in `module`.
/// Returns the LLVM form of `t` in `module`.
///
/// - Requires: `val` is representable in LLVM.
func llvm(tupleType val: TupleType, in module: inout SwiftyLLVM.Module) -> SwiftyLLVM.IRType {
precondition(val[.isCanonical])
/// - Requires: `t` is representable in LLVM.
func llvm(tupleType t: TupleType, in module: inout SwiftyLLVM.Module) -> SwiftyLLVM.IRType {
precondition(t[.isCanonical])

var fields: [SwiftyLLVM.IRType] = []
for e in val.elements {
for e in t.elements {
fields.append(llvm(e.type, in: &module))
}

return SwiftyLLVM.StructType(fields, in: &module)
}

/// Returns the LLVM form of `val` in `module`.
/// Returns the LLVM form of `t` in `module`.
///
/// - Requires: `val` is representable in LLVM.
func llvm(unionType val: UnionType, in module: inout SwiftyLLVM.Module) -> SwiftyLLVM.IRType {
precondition(val[.isCanonical])
/// - Requires: `t` is representable in LLVM.
func llvm(unionType t: UnionType, in module: inout SwiftyLLVM.Module) -> SwiftyLLVM.IRType {
precondition(t[.isCanonical])

var payload: SwiftyLLVM.IRType = SwiftyLLVM.StructType([], in: &module)
if val.isNever {
if t.isNever {
return payload
}

for e in val.elements {
let t = llvm(e, in: &module)
if module.layout.storageSize(of: t) > module.layout.storageSize(of: payload) {
payload = t
for e in t.elements {
let u = llvm(e, in: &module)
if module.layout.storageSize(of: u) > module.layout.storageSize(of: payload) {
payload = u
}
}
return StructType([payload, module.word()], in: &module)
Expand Down
15 changes: 14 additions & 1 deletion Sources/FrontEnd/TypeChecking/ConstraintSystem.swift
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,10 @@ struct ConstraintSystem {
return nil
}

goals[g].modifyTypes({ typeAssumptions[$0] })
goals[g].modifyTypes { (t) in
typeAssumptions.reify(t, withVariables: .kept)
}

log("- solve: \"\(goals[g])\"")
indentation += 1
log("actions:")
Expand Down Expand Up @@ -443,10 +446,20 @@ struct ConstraintSystem {
return delegate(to: [s])
}

case (let l as AssociatedTypeType, _) where l.root.base is TypeVariable:
postpone(g)
return nil

case (_, let r as AssociatedTypeType) where r.root.base is TypeVariable:
postpone(g)
return nil

default:
if !goal.left[.isCanonical] || !goal.right[.isCanonical] {
let l = checker.canonical(goal.left, in: scope)
let r = checker.canonical(goal.right, in: scope)
assert(l[.isCanonical] && r[.isCanonical])

let s = schedule(
SubtypingConstraint(l, r, strictly: goal.isStrict, origin: goal.origin.subordinate()))
return delegate(to: [s])
Expand Down
24 changes: 11 additions & 13 deletions Sources/FrontEnd/TypeChecking/SubstitutionMap.swift
Original file line number Diff line number Diff line change
Expand Up @@ -81,44 +81,42 @@ struct SubstitutionMap {
return occurs
}

/// Substitutes each type variable occurring in `type` by its corresponding substitution in `self`,
/// apply `substitutionPolicy` to deal with free variables.
/// Returns a copy of `type` where type variable occurring in is replaced by its corresponding
/// substitution in `self`, applying `substitutionPolicy` to deal with free variables.
///
/// The default substitution policy is `substituteByError` because we typically use `reify` after
/// having built a complete solution and therefore don't expect its result to still contain open
/// type variables.
func reify(
_ type: AnyType, withVariables substitutionPolicy: SubstitutionPolicy = .substitutedByError
) -> AnyType {
return type.transform(transform(type:))

func transform(type: AnyType) -> TypeTransformAction {
if type.base is TypeVariable {
let walked = self[type]
type.transform { (t: AnyType) -> TypeTransformAction in
if t.base is TypeVariable {
let walked = self[t]

// Substitute `walked` for `type`.
if walked.base is TypeVariable {
switch substitutionPolicy {
case .substitutedByError:
return .stepOver(.error)
case .kept:
return .stepOver(type)
return .stepOver(walked)
}
} else {
return .stepInto(walked)
}
} else if !type[.hasVariable] {
} else if !t[.hasVariable] {
// Nothing to do if the type doesn't contain any variable.
return .stepOver(type)
return .stepOver(t)
} else {
// Recursively visit other types.
return .stepInto(type)
return .stepInto(t)
}
}
}

/// Returns `r` where each type variable occurring in its generic arguments of `r` are replaced by
/// their corresponding value in `self`, applying `substitutionPolicy` to handle free variables.
/// Returns a copy of `r` where each generic argument is replaced by the result of applying
/// `reify(withVariables:)` on it.
func reify(
_ r: DeclReference, withVariables substitutionPolicy: SubstitutionPolicy
) -> DeclReference {
Expand Down
26 changes: 20 additions & 6 deletions Sources/FrontEnd/TypeChecking/TypeChecker.swift
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ struct TypeChecker {
if t[.isCanonical] { return t }

switch t.base {
case let u as AssociatedTypeType:
return canonical(u, in: scopeOfUse)
case let u as BoundGenericType:
return canonical(u, in: scopeOfUse)
case let u as TypeAliasType:
Expand All @@ -90,6 +92,15 @@ struct TypeChecker {
}
}

/// Returns the canonical form of `t` in `scopeOfUse`.
private mutating func canonical(_ t: AssociatedTypeType, in scopeOfUse: AnyScopeID) -> AnyType {
if let u = demandImplementation(of: t.decl, for: t.domain, in: scopeOfUse) {
return canonical(u, in: scopeOfUse)
} else {
return .error
}
}

/// Returns the canonical form of `t` in `scopeOfUse`.
private mutating func canonical(_ t: BoundGenericType, in scopeOfUse: AnyScopeID) -> AnyType {
if t[.isCanonical] { return ^t }
Expand Down Expand Up @@ -595,15 +606,18 @@ struct TypeChecker {
return result
}

/// Returns the type implementing requirement `r` for the model `m` in `scopeOfUse`, or `nil` if
/// `m` does not implement `r`.
/// Returns the type implementing `requirement` for `model` in `scopeOfUse`, or `nil` if `model`
/// does not implement `requirement`.
private mutating func demandImplementation(
of r: AssociatedTypeDecl.ID, for m: AnyType, in scopeOfUse: AnyScopeID
of requirement: AssociatedTypeDecl.ID, for model: AnyType, in scopeOfUse: AnyScopeID
) -> AnyType? {
if let c = demandConformance(of: m, to: traitDeclaring(r)!, exposedTo: scopeOfUse) {
return demandImplementation(of: r, in: c)
let p = traitDeclaring(requirement)!
let m = canonical(model, in: scopeOfUse)

if let c = demandConformance(of: m, to: p, exposedTo: scopeOfUse) {
return demandImplementation(of: requirement, in: c)
} else if m.base is GenericTypeParameterType {
return ^AssociatedTypeType(r, domain: m, ast: program.ast)
return ^AssociatedTypeType(requirement, domain: m, ast: program.ast)
} else {
return nil
}
Expand Down
7 changes: 5 additions & 2 deletions Sources/FrontEnd/TypedProgram.swift
Original file line number Diff line number Diff line change
Expand Up @@ -296,9 +296,12 @@ public struct TypedProgram {
/// Returns the names and types of `t`'s stored properties.
public func storage(of t: BoundGenericType) -> [TupleType.Element] {
storage(of: t.base).map { (p) in
// FIXME: Probably wrong to specialize/canonicalize in any random scope.
let arbitraryScope = AnyScopeID(base.ast.coreLibrary!)
let z = GenericArguments(t)
let t = specialize(p.type, for: z, in: AnyScopeID(base.ast.coreLibrary!))
return .init(label: p.label, type: t)
let u = specialize(p.type, for: z, in: arbitraryScope)
let v = canonical(u, in: arbitraryScope)
return .init(label: p.label, type: v)
}
}

Expand Down
9 changes: 8 additions & 1 deletion Sources/FrontEnd/Types/AssociatedTypeType.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ public struct AssociatedTypeType: TypeProtocol {
/// Creates an instance with the given properties.
init(decl: AssociatedTypeDecl.ID, domain: AnyType, name: String) {
var fs = domain.flags
if !domain.isSkolem && !(domain.base is TypeVariable) {

let d = AssociatedTypeType(domain)?.root ?? domain
if !(d.isSkolem || d.base is TypeVariable) {
fs.remove(.isCanonical)
}

Expand All @@ -35,6 +37,11 @@ public struct AssociatedTypeType: TypeProtocol {
self.flags = fs
}

/// Returns the root of `self`'s qualification.
public var root: AnyType {
AssociatedTypeType(domain)?.root ?? domain
}

public func transformParts<M>(
mutating m: inout M, _ transformer: (inout M, AnyType) -> TypeTransformAction
) -> Self {
Expand Down
Loading

0 comments on commit 61cf91c

Please sign in to comment.