Skip to content

Commit

Permalink
Merge pull request #1543 from hylo-lang/fix-ownership-analysis
Browse files Browse the repository at this point in the history
Fix ownership analysis
  • Loading branch information
kyouko-taiga authored Aug 3, 2024
2 parents 0db569f + a0e4d40 commit 4362ffc
Show file tree
Hide file tree
Showing 12 changed files with 222 additions and 80 deletions.
9 changes: 7 additions & 2 deletions Sources/FrontEnd/AccessEffectSet.swift
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ public struct AccessEffectSet: OptionSet, Hashable {
}

public mutating func update(with newMember: AccessEffect) -> AccessEffect? {
insert(newMember).memberAfterInsert
let (i, k) = insert(newMember)
return i ? nil : k
}

/// A set with `set` and `inout`.
Expand Down Expand Up @@ -103,8 +104,12 @@ extension AccessEffectSet {
self.base = s
}

public var isEmpty: Bool {
base.rawValue == 0
}

public var startIndex: UInt8 {
base.rawValue & (~base.rawValue + 1)
isEmpty ? endIndex : (base.rawValue & (~base.rawValue + 1))
}

public var endIndex: UInt8 {
Expand Down
18 changes: 10 additions & 8 deletions Sources/IR/Analysis/Module+AccessReification.swift
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,20 @@ extension Module {
while let i = work.popFirst() {
let s = self[i] as! ReifiableAccess
let available = s.capabilities
var requested: AccessEffectSet = [available.weakest!]
assert(!available.isSingleton, "access already reified")

var lower = AccessEffect.let
var upper = AccessEffect.let

forEachClient(of: i) { (u) in
let r = requests(u).intersection(available)
if let k = r.uniqueElement {
requested = [requested.strongest(including: k)]
} else {
requested.formUnion(r)
}
let rs = requests(u)
lower = max(rs.weakest!, lower)
upper = rs.strongest(including: upper)
}

if let k = requested.uniqueElement {
if lower == upper {
// We have to "promote" a request if it can be satisfied by a stronger capability.
let k = available.elements.first(where: { (a) in a >= lower }) ?? available.weakest!
reify(i, as: k)
} else {
work.append(i)
Expand Down
11 changes: 8 additions & 3 deletions Sources/IR/Analysis/Module+Ownership.swift
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,14 @@ extension Module {
}

// The access must be immutable if the source of the access is a let-parameter.
if let c = passingConvention(of: s.source), (c == .let) && (request != .let) {
diagnostics.insert(.error(illegalMutableAccessAt: s.site))
return
if (request != .let) && isBoundImmutably(s.source) {
// Built-in values are never consumed.
if self.type(of: s.source).ast.isBuiltin {
assert(request != .inout, "unexpected inout access on built-in value")
} else {
diagnostics.insert(.error(illegalMutableAccessAt: s.site))
return
}
}

let former = reborrowedSource(s)
Expand Down
64 changes: 52 additions & 12 deletions Sources/IR/Module.swift
Original file line number Diff line number Diff line change
Expand Up @@ -99,19 +99,61 @@ public struct Module {
}
}

/// If `p` is a parameter, returns its passing convention. Otherwise, returns `nil`.
/// Returns `true` iff cannot be used to modify or update a value.
public func isBoundImmutably(_ p: Operand) -> Bool {
switch p {
case .parameter(let e, let i):
let f = e.function
return (entry(of: f) == e) && (passingConvention(parameter: i, of: f) == .let)
case .constant:
return false
case .register(let i):
return isBoundImmutably(register: i)
}
}

/// Returns `true` iff the result of `i` cannot be used to modify or update a value.
public func isBoundImmutably(register i: InstructionID) -> Bool {
switch self[i] {
case is AllocStack:
return false
case let s as AdvancedByBytes:
return isBoundImmutably(s.base)
case let s as Access:
return isBoundImmutably(s.source)
case let s as OpenCapture:
return s.isAccess(.let)
case is OpenUnion:
return false
case let s as PointerToAddress:
return s.isAccess(.let)
case let s as Project:
return s.projection.access == .let
case let s as SubfieldView:
return isBoundImmutably(s.recordAddress)
case let s as WrapExistentialAddr:
return isBoundImmutably(s.witness)
default:
return true
}
}

/// If `p` is a function parameter, returns its passing convention. Otherwise, returns `nil`.
public func passingConvention(of p: Operand) -> AccessEffect? {
if case .parameter(let e, let i) = p {
assert(entry(of: e.function) == e)
return read(self[e.function].inputs) { (ps) in
// The last parameter of a function denotes its return value.
(i == ps.count) ? .set : ps[i].type.access
}
if case .parameter(let e, let i) = p, (entry(of: e.function) == e) {
return passingConvention(parameter: i, of: e.function)
} else {
return nil
}
}

/// Returns the passing convention of the `i`-th parameter of `f`.
public func passingConvention(parameter i: Int, of f: Function.ID) -> AccessEffect {
// The last parameter of a function denotes its return value.
let ps = self[f].inputs
return (i == ps.count) ? .set : ps[i].type.access
}

/// Returns the scope in which `i` is used.
public func scope(containing i: InstructionID) -> AnyScopeID {
functions[i.function]![i.block].scope
Expand Down Expand Up @@ -905,11 +947,9 @@ public struct Module {
case let s as Access:
return provenances(s.source)
case let s as Project:
return s.operands.reduce(
into: [],
{ (p, o) in
if type(of: o).isAddress { p.formUnion(provenances(o)) }
})
return s.operands.reduce(into: []) { (p, o) in
if type(of: o).isAddress { p.formUnion(provenances(o)) }
}
case let s as SubfieldView:
return provenances(s.recordAddress)
case let s as WrapExistentialAddr:
Expand Down
35 changes: 19 additions & 16 deletions StandardLibrary/Sources/Array.hylo
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ public type Array<Element: SemiRegular>: SemiRegular {
/// The projected pointer is valid only for the duration of the projection and can be advanced up
/// to `count()`. It may be null if `self` is empty.
public property contiguous_storage: Pointer<Element> {
yield if capacity() == 0 { .null() } else { .new(pointer_to_element[at: 0]) }
yield if capacity() == 0 { .null() } else { .new(pointer_to_element(at: 0)) }
}

/// Calls `action` with a pointer to the start of the array's mutable contiguous storage.
Expand All @@ -96,13 +96,13 @@ public type Array<Element: SemiRegular>: SemiRegular {
public fun with_mutable_contiguous_storage<E, T: Movable>(
_ action: inout [E](PointerToMutable<Element>) inout -> T
) inout -> T {
if capacity() == 0 { &action(.null()) } else { &action(pointer_to_element[at: 0]) }
if capacity() == 0 { &action(.null()) } else { &action(pointer_to_element(at: 0)) }
}

/// Adds a new element at the end of the array.
public fun append(_ source: sink Element) inout {
&reserve_capacity(count() + 1)
pointer_to_element[at: count()].unsafe_initialize_pointee(source)
pointer_to_element(at: count()).unsafe_initialize_pointee(source)
&storage.header += 1
}

Expand All @@ -117,11 +117,11 @@ public type Array<Element: SemiRegular>: SemiRegular {
var c = count()
&reserve_capacity(c + 1)
while c > index {
pointer_to_element[at: c].unsafe_initialize_pointee(
pointer_to_element[at: c - 1].unsafe_pointee())
pointer_to_element(at: c).unsafe_initialize_pointee(
pointer_to_element(at: c - 1).unsafe_pointee())
&c -= 1
}
pointer_to_element[at: c].unsafe_initialize_pointee(source)
pointer_to_element(at: c).unsafe_initialize_pointee(source)
&storage.header += 1
}

Expand All @@ -130,11 +130,11 @@ public type Array<Element: SemiRegular>: SemiRegular {
/// - Requires: `index` is in the range `0 ..< count()`.
/// - Complexity: O(n), where n is the number of elements in `self`.
public fun remove(at index: Int) inout -> Element {
let result = pointer_to_element[at: index].unsafe_pointee()
let result = pointer_to_element(at: index).unsafe_pointee()
var i = index + 1
while i < count() {
pointer_to_element[at: i - 1].unsafe_initialize_pointee(
pointer_to_element[at: i].unsafe_pointee())
pointer_to_element(at: i - 1).unsafe_initialize_pointee(
pointer_to_element(at: i).unsafe_pointee())
&i += 1
}
&storage.header -= 1
Expand All @@ -148,7 +148,7 @@ public type Array<Element: SemiRegular>: SemiRegular {
let c = count()
var i = 0
while i < c {
&pointer_to_element[at: i].unsafe_pointee().deinit()
&pointer_to_element(at: i).unsafe_pointee().deinit()
&i += 1
}
&storage.header = 0
Expand All @@ -161,7 +161,7 @@ public type Array<Element: SemiRegular>: SemiRegular {
public fun pop_last() inout -> Optional<Element> {
let c = count()
if c > 0 {
let result = pointer_to_element[at: c - 1].unsafe_pointee()
let result = pointer_to_element(at: c - 1).unsafe_pointee()
&storage.header -= 1
return result as _
} else {
Expand All @@ -182,10 +182,10 @@ public type Array<Element: SemiRegular>: SemiRegular {
}
}

/// Projects the address of the element at `position`.
/// Returns the address of the element at `position`.
///
/// - Requires: `position` is in the range `0 ..< capacity()`.
subscript pointer_to_element(at position: Int): PointerToMutable<Element> {
fun pointer_to_element(at position: Int) -> PointerToMutable<Element> {
storage.first_element_address().advance(by: position)
}

Expand Down Expand Up @@ -245,11 +245,11 @@ public conformance Array: Collection {
public subscript(_ position: Int): Element {
let {
precondition((position >= 0) && (position < count()), "position is out of bounds")
yield pointer_to_element[at: position].unsafe[]
yield pointer_to_element(at: position).unsafe[]
}
inout {
precondition((position >= 0) && (position < count()), "position is out of bounds")
yield &pointer_to_element[at: position].unsafe[]
yield &(pointer_to_element(at: position).unsafe[])
}
}

Expand All @@ -261,7 +261,10 @@ public conformance Array: MutableCollection {
precondition((i >= 0) && (i < count()), "position is out of bounds")
precondition((j >= 0) && (j < count()), "position is out of bounds")
if i == j { return }
&pointer_to_element[at: i].unsafe[].exchange(with: &pointer_to_element[at: j].unsafe[])

var p = pointer_to_element(at: i)
var q = pointer_to_element(at: j)
&p.unsafe[].exchange(with: &q.unsafe[])
}

}
44 changes: 14 additions & 30 deletions Tests/EndToEndTests/TestCases/Autoclosure.hylo
Original file line number Diff line number Diff line change
@@ -1,48 +1,32 @@
//- compileAndRun expecting: .success

let counter: Int = 0
let log: PointerToMutable<Int> = .allocate(count: 1)

fun prime() -> Int {
&counter += 1
&(log.copy()).unsafe[] += 1
return 17
}

fun logic(
with_delayed value: @autoclosure []() -> Int,
counter_starting_at s: Int,
fun run<E>(
lazily_evaluating value: @autoclosure [E]() -> Int,
starting_at s: Int,
ending_at e: Int
) {
precondition(counter == s)
precondition(log.unsafe[] == s)
let x = value()
precondition(counter == e)
precondition(x == 17)
}

fun logic2<E>(with_delayed value: @autoclosure [E]() -> Int) {
precondition(counter == 0)
let x = value()
precondition(counter == 1)
precondition(log.unsafe[] == e)
precondition(x == 17)
}

public fun main() {
&counter = 0
_ = log.unsafe_initialize_pointee(fun (_ a: set Int) -> Void { &a = 0 })

// Test with function call.
logic(with_delayed: prime(), counter_starting_at: 0, ending_at: 1)

// Test with constant.
logic(with_delayed: 17, counter_starting_at: 1, ending_at: 1)

&counter = 0
// Test with generics.
logic2(with_delayed: prime())
// Lazy parameter without captures.
run(lazily_evaluating: prime(), starting_at: 0, ending_at: 1)
run(lazily_evaluating: 17, starting_at: 1, ending_at: 1)
run(lazily_evaluating: prime(), starting_at: 1, ending_at: 2)

// TODO: Test with generics and non-void environments.
// &counter = 0
// let r = 17
// logic2(with_delayed: fun() -> Int {
// &counter += 1
// return r
// })

log.copy().deallocate()
}
13 changes: 7 additions & 6 deletions Tests/EndToEndTests/TestCases/Concurrency/concurrent_sort.hylo
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ fun my_concurrent_sort<Element: Regular & Comparable>(_ a: inout ArraySlice<Elem

// Spawn work to sort the right-hand side.
let future = spawn_inplace(fun[sink let q=mutable_pointer[to: &rhs].copy()]() -> Int {
inout rhs = &(q.copy().unsafe[])
inout rhs = &(q.copy()).unsafe[]
return my_concurrent_sort(&rhs)
})

Expand Down Expand Up @@ -112,6 +112,7 @@ type ArraySlice<Element: Regular & Comparable> : Deinitializable, Movable {
&self.start_index = full_array.start_position()
&self.end_index = full_array.end_position()
}

/// Initializes `self` to represent elements [`start`, `end`) of `full_array`.
public init(source: Self, from start: Int, to end: Int) {
precondition(start >= 0 && start <= end && end <= source.count())
Expand All @@ -126,9 +127,9 @@ type ArraySlice<Element: Regular & Comparable> : Deinitializable, Movable {
}

/// Sorts the elements in `self`.
public fun sort() {
public fun sort() inout {
// Use bubble sort for simplicity.
inout elements = origin.unsafe[]
inout elements = &origin.unsafe[]
do {
var swapped = false
var i = start_index.copy()
Expand All @@ -149,8 +150,8 @@ type ArraySlice<Element: Regular & Comparable> : Deinitializable, Movable {
}

/// Partitions the slice in 3 parts: one with elements lower than `mid_value`, one with elements equal to `mid_value` and one with elements greater than `mid_value`, returning the indices that separates these parts.
public fun partition(on mid_value: Element) -> {Int, Int} {
inout elements = origin.unsafe[]
public fun partition(on mid_value: Element) inout -> {Int, Int} {
inout elements = &origin.unsafe[]
// First pass to move elements smaller than mid_value to the left.
var i = start_index.copy()
var j = end_index.copy()
Expand Down Expand Up @@ -182,7 +183,7 @@ type ArraySlice<Element: Regular & Comparable> : Deinitializable, Movable {
/// Returns a slice of `self` from `start` to `end` (relative indices).
public fun slice(from start: Int, to end: Int) -> Self {
precondition(0 <= start && start <= end && end <= count())
let r: Self = .new(full_array: &origin.unsafe[])
let r: Self = .new(full_array: &(origin.copy()).unsafe[])
&r.start_index = start_index + start
&r.end_index = start_index + end
return r
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ fun test_mutating_spawn() -> Int {
let p = mutable_pointer[to: &local_variable]

var future = spawn_(fun[sink let q=p.copy()] () -> Int {
&(q.copy().unsafe[]) = 19
&(q.copy()).unsafe[] = 19
return 1
})
let y = future.await()
Expand Down
2 changes: 1 addition & 1 deletion Tests/EndToEndTests/TestCases/ExplicitCaptures.hylo
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ public fun main() {
var local_variable = 0
let p = mutable_pointer[to: &local_variable]
let n = apply(fun[sink let q = p.copy()]() -> Int {
&(q.unsafe[]) = 19
&(q.copy()).unsafe[] = 19
return 19
})
precondition(n == local_variable)
Expand Down
Loading

0 comments on commit 4362ffc

Please sign in to comment.