Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use LockedValueBoxes for state protection #35

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ let package = Package(
.library(name: "CassandraClient", targets: ["CassandraClient"]),
],
dependencies: [
.package(url: "https://github.com/apple/swift-nio", .upToNextMajor(from: "2.41.1")),
.package(url: "https://github.com/apple/swift-nio", .upToNextMajor(from: "2.42.0")),
.package(url: "https://github.com/apple/swift-nio-ssl", .upToNextMajor(from: "2.21.0")),
.package(url: "https://github.com/apple/swift-atomics", from: "1.0.2"),
.package(url: "https://github.com/apple/swift-log", .upToNextMajor(from: "1.0.0")),
Expand Down
4 changes: 2 additions & 2 deletions Sources/CassandraClient/CassandraClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import NIOConcurrencyHelpers
/// `CassandraClient` is a wrapper around the [Datastax Cassandra C++ Driver](https://github.com/datastax/cpp-driver)
/// and can be used to run queries against a Cassandra database.
public class CassandraClient: CassandraSession {
private let eventLoopGroupContainer: EventLoopGroupConainer
private let eventLoopGroupContainer: EventLoopGroupContainer
public var eventLoopGroup: EventLoopGroup {
self.eventLoopGroupContainer.value
}
Expand Down Expand Up @@ -247,4 +247,4 @@ extension CassandraClient {
}
#endif

internal typealias EventLoopGroupConainer = (value: EventLoopGroup, managed: Bool)
internal typealias EventLoopGroupContainer = (value: EventLoopGroup, managed: Bool)
88 changes: 48 additions & 40 deletions Sources/CassandraClient/Session.swift
Original file line number Diff line number Diff line change
Expand Up @@ -190,15 +190,14 @@ extension CassandraSession {

extension CassandraClient {
internal final class Session: CassandraSession {
private let eventLoopGroupContainer: EventLoopGroupConainer
private let eventLoopGroupContainer: EventLoopGroupContainer
public var eventLoopGroup: EventLoopGroup {
self.eventLoopGroupContainer.value
}

private let configuration: Configuration
private let logger: Logger
private var state = State.idle
private let lock = Lock()
private let stateStore = NIOLockedValueBox(State.idle)

private let rawPointer: OpaquePointer

Expand All @@ -212,58 +211,68 @@ extension CassandraClient {
case disconnected
}

internal init(configuration: Configuration, logger: Logger, eventLoopGroupContainer: EventLoopGroupConainer) {
internal init(configuration: Configuration, logger: Logger, eventLoopGroupContainer: EventLoopGroupContainer) {
self.configuration = configuration
self.logger = logger
self.eventLoopGroupContainer = eventLoopGroupContainer
self.rawPointer = cass_session_new()
}

deinit {
guard case .disconnected = (self.lock.withLock { self.state }) else {
let isDisconnected = self.stateStore.withLockedValue { state in
if case .disconnected = state {
return true
}
return false
}
guard isDisconnected else {
preconditionFailure("Session not shut down before the deinit. Please call session.shutdown() when no longer needed.")
}
cass_session_free(self.rawPointer)
}

func shutdown() throws {
self.lock.lock()
defer {
self.state = .disconnected
self.lock.unlock()
}
switch self.state {
case .connected:
try self.disconect()
default:
break
try self.stateStore.withLockedValue { (state: inout State) in
defer {
state = .disconnected
}
switch state {
case .connected:
try self.disconect()
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm pretty nervous about this function being called with the lock held.

default:
break
}
}
}

func execute(statement: Statement, on eventLoop: EventLoop?, logger: Logger? = .none) -> EventLoopFuture<Rows> {
let eventLoop = eventLoop ?? self.eventLoopGroup.next()
let logger = logger ?? self.logger

self.lock.lock()
switch self.state {
let (startingState, future) = self.stateStore.withLockedValue { (state: inout State) -> (State, EventLoopFuture<Void>?) in
if case .idle = state {
let future = self.connect(on: eventLoop, logger: logger)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same notes here.

state = .connectingFuture(future)
return (.idle, future)
} else {
return (state, nil)
}
}

switch startingState {
case .idle:
let future = self.connect(on: eventLoop, logger: logger)
self.state = .connectingFuture(future)
self.lock.unlock()
return future.flatMap { _ in
self.lock.withLock {
self.state = .connected
return future!.flatMap { _ in
self.stateStore.withLockedValue { (state: inout State) in
state = .connected
}
return self.execute(statement: statement, on: eventLoop, logger: logger)
}
case .connectingFuture(let future):
self.lock.unlock()
return future.flatMap { _ in
self.execute(statement: statement, on: eventLoop, logger: logger)
}
#if compiler(>=5.5) && canImport(_Concurrency)
case .connecting(let task):
self.lock.unlock()
let promise = eventLoop.makePromise(of: Rows.self)
if #available(macOS 12, iOS 15, tvOS 15, watchOS 8, *) {
promise.completeWithTask {
Expand All @@ -274,7 +283,6 @@ extension CassandraClient {
return promise.futureResult
#endif
case .connected:
self.lock.unlock()
logger.debug("executing: \(statement.query)")
logger.trace("\(statement.parameters)")
let promise = eventLoop.makePromise(of: Rows.self)
Expand All @@ -284,7 +292,6 @@ extension CassandraClient {
}
return promise.futureResult
case .disconnected:
self.lock.unlock()
if self.eventLoopGroupContainer.managed {
// eventloop *is* shutdown now
preconditionFailure("client is disconnected")
Expand Down Expand Up @@ -444,28 +451,30 @@ extension CassandraClient.Session {
func execute(statement: CassandraClient.Statement, logger: Logger? = .none) async throws -> CassandraClient.Rows {
let logger = logger ?? self.logger

lock.lock()
switch state {
case .idle:
let task = self.connect(logger: logger)
state = .connecting(ConnectionTask(task))
lock.unlock()
let (startingState, task) = self.stateStore.withLockedValue { (state: inout State) -> (State, Task<Void, Swift.Error>?) in
if case .idle = state {
let task = self.connect(logger: logger)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And here.

state = .connecting(ConnectionTask(task))
return (.idle, task)
} else {
return (state, nil)
}
}

try await task.value
lock.withLock {
self.state = .connected
switch startingState {
case .idle:
try await task!.value
self.stateStore.withLockedValue { (state: inout State) in
state = .connected
}
return try await self.execute(statement: statement, logger: logger)
case .connectingFuture(let future):
lock.unlock()
try await future.get()
return try await self.execute(statement: statement, logger: logger)
case .connecting(let task):
lock.unlock()
try await task.task.value
return try await self.execute(statement: statement, logger: logger)
case .connected:
lock.unlock()
logger.debug("executing: \(statement.query)")
logger.trace("\(statement.parameters)")
let future = cass_session_execute(rawPointer, statement.rawPointer)
Expand All @@ -475,7 +484,6 @@ extension CassandraClient.Session {
}
}
case .disconnected:
lock.unlock()
if eventLoopGroupContainer.managed {
// eventloop *is* shutdown now
preconditionFailure("client is disconnected")
Expand Down