From 4f597f9108f1fadc2af56a97e1192104be991df6 Mon Sep 17 00:00:00 2001 From: Dimi Racordon Date: Thu, 22 Aug 2024 12:32:11 +0200 Subject: [PATCH 1/3] Add an IR instruction to switch over the contents of a union `union_switch` replaces `switch` so that discriminators can remain abstract until monomorphization. --- Sources/CodeGen/LLVM/Transpilation.swift | 35 +++++-- .../Module+NormalizeObjectStates.swift | 2 + Sources/IR/Emitter.swift | 92 ++++++++----------- Sources/IR/InstructionTransformer.swift | 9 ++ .../IR/Operands/Instruction/UnionSwitch.swift | 82 +++++++++++++++++ 5 files changed, 158 insertions(+), 62 deletions(-) create mode 100644 Sources/IR/Operands/Instruction/UnionSwitch.swift diff --git a/Sources/CodeGen/LLVM/Transpilation.swift b/Sources/CodeGen/LLVM/Transpilation.swift index b31383170..579c06326 100644 --- a/Sources/CodeGen/LLVM/Transpilation.swift +++ b/Sources/CodeGen/LLVM/Transpilation.swift @@ -674,6 +674,8 @@ extension SwiftyLLVM.Module { insert(switch: i) case is IR.UnionDiscriminator: insert(unionDiscriminator: i) + case is IR.UnionSwitch: + insert(unionSwitch: i) case is IR.Unreachable: insert(unreachable: i) case is IR.WrapExistentialAddr: @@ -1199,14 +1201,22 @@ extension SwiftyLLVM.Module { /// Inserts the transpilation of `i` at `insertionPoint`. func insert(unionDiscriminator i: IR.InstructionID) { let s = m[i] as! UnionDiscriminator - let t = UnionType(m.type(of: s.container).ast)! + register[.register(i)] = discriminator(s.container) + } - let baseType = ir.llvm(unionType: t, in: &self) - let container = llvm(s.container) - let indices = [i32.constant(0), i32.constant(1)] - let discriminator = insertGetElementPointerInBounds( - of: container, typed: baseType, indices: indices, at: insertionPoint) - register[.register(i)] = insertLoad(word(), from: discriminator, at: insertionPoint) + /// Inserts the transpilation of `i` at `insertionPoint`. + func insert(unionSwitch i: IR.InstructionID) { + let s = m[i] as! UnionSwitch + let d = discriminator(s.scrutinee) + let e = m.program.discriminatorToElement(in: UnionType(m.type(of: s.scrutinee).ast)!) + let branches = s.targets.map { (t, b) in + (word().constant(e.firstIndex(of: t)!), block[b]!) + } + + // The last branch is the "default". + insertSwitch( + on: d, cases: branches.dropLast(), default: branches.last!.1, + at: insertionPoint) } /// Inserts the transpilation of `i` at `insertionPoint`. @@ -1292,6 +1302,17 @@ extension SwiftyLLVM.Module { v = insertInsertValue(llvm(table), at: 1, into: v, at: insertionPoint) return v } + + /// Returns the value of `container`'s discriminator. + func discriminator(_ container: IR.Operand) -> SwiftyLLVM.Instruction { + let union = UnionType(m.type(of: container).ast)! + let baseType = ir.llvm(unionType: union, in: &self) + let container = llvm(container) + let indices = [i32.constant(0), i32.constant(1)] + let discriminator = insertGetElementPointerInBounds( + of: container, typed: baseType, indices: indices, at: insertionPoint) + return insertLoad(word(), from: discriminator, at: insertionPoint) + } } /// Inserts the prologue of the subscript `transpilation` at the end of its entry and returns diff --git a/Sources/IR/Analysis/Module+NormalizeObjectStates.swift b/Sources/IR/Analysis/Module+NormalizeObjectStates.swift index 2a9d418d8..f57d04775 100644 --- a/Sources/IR/Analysis/Module+NormalizeObjectStates.swift +++ b/Sources/IR/Analysis/Module+NormalizeObjectStates.swift @@ -90,6 +90,8 @@ extension Module { pc = interpret(subfieldView: user, in: &context) case is UnionDiscriminator: pc = interpret(unionDiscriminator: user, in: &context) + case is UnionSwitch: + pc = successor(of: user) case is Unreachable: pc = successor(of: user) case is WrapExistentialAddr: diff --git a/Sources/IR/Emitter.swift b/Sources/IR/Emitter.swift index ff17742a3..41671d31e 100644 --- a/Sources/IR/Emitter.swift +++ b/Sources/IR/Emitter.swift @@ -737,20 +737,15 @@ struct Emitter { } // Otherwise, use a switch to select the correct move-initialization. - let elements = program.discriminatorToElement(in: t) - var successors: [Block.ID] = [] - for _ in t.elements { - successors.append(appendBlock()) - } - - let n = emitUnionDiscriminator(argument, at: site) - insert(module.makeSwitch(on: n, toOneOf: successors, at: site)) + let targets = UnionSwitch.Targets( + t.elements.map({ (e) in (key: e, value: appendBlock()) }), + uniquingKeysWith: { (a, _) in a }) + insert(module.makeUnionSwitch(on: receiver, toOneOf: targets, at: site)) let tail = appendBlock() - for i in 0 ..< elements.count { - insertionPoint = .end(of: successors[i]) - emitMoveInitUnionPayload( - of: receiver, consuming: argument, containing: elements[i], at: site) + for (u, b) in targets { + insertionPoint = .end(of: b) + emitMoveInitUnionPayload(of: receiver, consuming: argument, containing: u, at: site) insert(module.makeBranch(to: tail, at: site)) } @@ -897,20 +892,16 @@ struct Emitter { return } - // Otherwise, use a switch to select the correct move-initialization. - let elements = program.discriminatorToElement(in: t) - var successors: [Block.ID] = [] - for _ in t.elements { - successors.append(appendBlock()) - } - - let n = emitUnionDiscriminator(source, at: site) - insert(module.makeSwitch(on: n, toOneOf: successors, at: site)) + // Otherwise, use a switch to select the correct copy method. + let targets = UnionSwitch.Targets( + t.elements.map({ (e) in (key: e, value: appendBlock()) }), + uniquingKeysWith: { (a, _) in a }) + insert(module.makeUnionSwitch(on: source, toOneOf: targets, at: site)) let tail = appendBlock() - for i in 0 ..< elements.count { - insertionPoint = .end(of: successors[i]) - emitCopyUnionPayload(from: source, containing: elements[i], to: target, at: site) + for (u, b) in targets { + insertionPoint = .end(of: b) + emitCopyUnionPayload(from: source, containing: u, to: target, at: site) insert(module.makeBranch(to: tail, at: site)) } @@ -2391,24 +2382,22 @@ struct Emitter { /// /// This method method implements conditional narrowing for union types. private mutating func emitConditionalNarrowing( - _ subject: Operand, typed subjectType: UnionType, + _ subject: Operand, typed union: UnionType, as pattern: BindingPattern.ID, typed patternType: AnyType, to storage: Operand, else failure: Block.ID, in scope: AnyScopeID ) -> Block.ID { // TODO: Implement narrowing to an arbitrary subtype. - guard subjectType.elements.contains(patternType) else { UNIMPLEMENTED() } + guard union.elements.contains(patternType) else { UNIMPLEMENTED() } let site = ast[pattern].site - let i = program.discriminatorToElement(in: subjectType).firstIndex(of: patternType)! - let expected = IntegerConstant(i, bitWidth: 64) // FIXME: should be width of 'word' - let actual = emitUnionDiscriminator(subject, at: site) - - let test = insert( - module.makeLLVM(applying: .icmp(.eq, .word), to: [.constant(expected), actual], at: site))! let next = appendBlock(in: scope) - insert(module.makeCondBranch(if: test, then: next, else: failure, at: site)) + var targets = UnionSwitch.Targets( + union.elements.map({ (e) in (key: e, value: failure) }), + uniquingKeysWith: { (a, _) in a }) + targets[patternType] = next + insert(module.makeUnionSwitch(on: subject, toOneOf: targets, at: site)) insertionPoint = .end(of: next) let x0 = insert(module.makeOpenUnion(subject, as: patternType, at: site))! pushing(Frame()) { (this) in @@ -3085,19 +3074,15 @@ struct Emitter { } // One successor per member in the union, ordered by their mangled representation. - let elements = program.discriminatorToElement(in: t) - var successors: [Block.ID] = [] - for _ in t.elements { - successors.append(appendBlock()) - } - - let n = emitUnionDiscriminator(storage, at: site) - insert(module.makeSwitch(on: n, toOneOf: successors, at: site)) + let targets = UnionSwitch.Targets( + t.elements.map({ (e) in (key: e, value: appendBlock()) }), + uniquingKeysWith: { (a, _) in a }) + insert(module.makeUnionSwitch(on: storage, toOneOf: targets, at: site)) let tail = appendBlock() - for i in 0 ..< elements.count { - insertionPoint = .end(of: successors[i]) - emitDeinitUnionPayload(of: storage, containing: elements[i], at: site) + for (u, b) in targets { + insertionPoint = .end(of: b) + emitDeinitUnionPayload(of: storage, containing: u, at: site) insert(module.makeBranch(to: tail, at: site)) } @@ -3188,13 +3173,10 @@ struct Emitter { } // Otherwise, compare their payloads. - let elements = program.discriminatorToElement(in: union) - let same = appendBlock() - var successors: [Block.ID] = [] - for _ in elements { - successors.append(appendBlock()) - } + let targets = UnionSwitch.Targets( + union.elements.map({ (e) in (key: e, value: appendBlock()) }), + uniquingKeysWith: { (a, _) in a }) let fail = appendBlock() let tail = appendBlock() @@ -3205,11 +3187,11 @@ struct Emitter { insert(module.makeCondBranch(if: x0, then: same, else: fail, at: site)) insertionPoint = .end(of: same) - insert(module.makeSwitch(on: dl, toOneOf: successors, at: site)) - for i in 0 ..< elements.count { - insertionPoint = .end(of: successors[i]) - let y0 = insert(module.makeOpenUnion(lhs, as: elements[i], at: site))! - let y1 = insert(module.makeOpenUnion(rhs, as: elements[i], at: site))! + insert(module.makeUnionSwitch(on: lhs, toOneOf: targets, at: site)) + for (u, b) in targets { + insertionPoint = .end(of: b) + let y0 = insert(module.makeOpenUnion(lhs, as: u, at: site))! + let y1 = insert(module.makeOpenUnion(rhs, as: u, at: site))! emitStoreEquality(y0, y1, to: target, at: site) insert(module.makeCloseUnion(y1, at: site)) insert(module.makeCloseUnion(y0, at: site)) diff --git a/Sources/IR/InstructionTransformer.swift b/Sources/IR/InstructionTransformer.swift index dd2eb16df..c5c7bd74f 100644 --- a/Sources/IR/InstructionTransformer.swift +++ b/Sources/IR/InstructionTransformer.swift @@ -239,6 +239,15 @@ extension IR.Program { target.makeUnionDiscriminator(x0, at: s.site) } + case let s as UnionSwitch: + let x0 = t.transform(s.scrutinee, in: &self) + let x1 = s.targets.reduce(into: UnionSwitch.Targets()) { (d, kv) in + _ = d[t.transform(kv.key, in: &self)].setIfNil(t.transform(kv.value, in: &self)) + } + return insert(at: p, in:n) { (target) in + target.makeUnionSwitch(on: x0, toOneOf: x1, at: s.site) + } + case let s as Unreachable: return modules[n]!.insert(s, at: p) diff --git a/Sources/IR/Operands/Instruction/UnionSwitch.swift b/Sources/IR/Operands/Instruction/UnionSwitch.swift new file mode 100644 index 000000000..c4ce077eb --- /dev/null +++ b/Sources/IR/Operands/Instruction/UnionSwitch.swift @@ -0,0 +1,82 @@ +import FrontEnd +import OrderedCollections + +/// Branches to one of several basic blocks based on the discriminator of a union. +public struct UnionSwitch: Terminator { + + /// The type of a map from payload type to its target. + public typealias Targets = OrderedDictionary + + /// The union container whose discriminator is read. + public private(set) var scrutinee: Operand + + /// A map from payload type to its target. + public private(set) var targets: Targets + + /// The site of the code corresponding to that instruction. + public let site: SourceRange + + /// Creates an instance with the given properties. + fileprivate init(scrutinee: Operand, targets: Targets, site: SourceRange) { + self.scrutinee = scrutinee + self.targets = targets + self.site = site + } + + + + public var operands: [Operand] { + [scrutinee] + } + + public var successors: [Block.ID] { + Array(targets.values) + } + + public mutating func replaceOperand(at i: Int, with new: Operand) { + precondition(i == 0) + scrutinee = new + } + + mutating func replaceSuccessor(_ old: Block.ID, with new: Block.ID) -> Bool { + precondition(new.function == successors[0].function) + for (t, b) in targets { + if b == old { targets[t] = b; return true } + } + return false + } + +} + +extension UnionSwitch: CustomStringConvertible { + + public var description: String { + var s = "union_switch \(scrutinee)" + for (t, b) in targets { + s.write(", \(t) => \(b)") + } + return s + } + +} + +extension Module { + + /// Creates a `union_switch` anchored at `site` that jumps to the block assigned to the type of + /// `scrutinee`'s payload in `targets`. + /// + /// - Requires: `scrutinee` is a union container and `targets` has a key defined for each of the + /// elements in scrutinee's type. + func makeUnionSwitch( + on scrutinee: Operand, toOneOf targets: UnionSwitch.Targets, at site: SourceRange + ) -> UnionSwitch { + let t = type(of: scrutinee) + guard t.isAddress, let u = UnionType(t.ast) else { + preconditionFailure("invalid type '\(t)'") + } + precondition(u.elements.allSatisfy({ (e) in targets[e] != nil })) + + return .init(scrutinee: scrutinee, targets: targets, site: site) + } + +} From 80ed53060626b536a246f97b8cde97b5a909c5cf Mon Sep 17 00:00:00 2001 From: Dimi Racordon Date: Thu, 22 Aug 2024 13:08:34 +0200 Subject: [PATCH 2/3] Avoid generating switch instructions for singleton unions --- Sources/CodeGen/LLVM/Transpilation.swift | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/Sources/CodeGen/LLVM/Transpilation.swift b/Sources/CodeGen/LLVM/Transpilation.swift index 579c06326..73fce8678 100644 --- a/Sources/CodeGen/LLVM/Transpilation.swift +++ b/Sources/CodeGen/LLVM/Transpilation.swift @@ -1207,16 +1207,22 @@ extension SwiftyLLVM.Module { /// Inserts the transpilation of `i` at `insertionPoint`. func insert(unionSwitch i: IR.InstructionID) { let s = m[i] as! UnionSwitch - let d = discriminator(s.scrutinee) - let e = m.program.discriminatorToElement(in: UnionType(m.type(of: s.scrutinee).ast)!) - let branches = s.targets.map { (t, b) in - (word().constant(e.firstIndex(of: t)!), block[b]!) - } - // The last branch is the "default". - insertSwitch( - on: d, cases: branches.dropLast(), default: branches.last!.1, - at: insertionPoint) + if let (_, b) = s.targets.elements.uniqueElement { + insertBr(to: block[b]!, at: insertionPoint) + } else { + let d = discriminator(s.scrutinee) + let t = UnionType(m.type(of: s.scrutinee).ast)! + let e = m.program.discriminatorToElement(in: t) + let branches = s.targets.map { (t, b) in + (word().constant(e.firstIndex(of: t)!), block[b]!) + } + + // The last branch is the "default". + insertSwitch( + on: d, cases: branches.dropLast(), default: branches.last!.1, + at: insertionPoint) + } } /// Inserts the transpilation of `i` at `insertionPoint`. From 29bf165494d63f954b8e41c68f866c5a62436f16 Mon Sep 17 00:00:00 2001 From: Dimi Racordon Date: Thu, 22 Aug 2024 13:11:05 +0200 Subject: [PATCH 3/3] Test union narrowing in generic function --- Tests/EndToEndTests/TestCases/UnionNarrowing.hylo | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/Tests/EndToEndTests/TestCases/UnionNarrowing.hylo b/Tests/EndToEndTests/TestCases/UnionNarrowing.hylo index 24facf842..1d8b280d2 100644 --- a/Tests/EndToEndTests/TestCases/UnionNarrowing.hylo +++ b/Tests/EndToEndTests/TestCases/UnionNarrowing.hylo @@ -1,5 +1,9 @@ //- compileAndRun expecting: .success +fun f(_ u: sink Union) -> Bool { + if let _: T = u { true } else { false } +} + public fun main() { var x: Union<{a: Bool}, {b: Int}> = (b: 42) if let y: {b: _} = x { @@ -7,4 +11,8 @@ public fun main() { } else { fatal_error() } + + precondition(f(42 as _)) + precondition(f(42 as _)) + precondition(!f(true as _)) }