Skip to content

Commit

Permalink
Refactor QuiescingHelper to exhaustively iterate state (#193)
Browse files Browse the repository at this point in the history
# Motivation
Currently the `QuiescingHelper` is crashing on a precondition if you call shutdown when it already was shutdown. However, that can totally happen and we should support it.

# Modification
Refactor the `QuiescingHelper` to exhaustively switch over its state in every method. Furthermore, I added a few more test cases to test realistic scenarios.

# Result
We are now reliable checking our state and making sure to allow most transitions.
  • Loading branch information
FranzBusch authored Feb 24, 2023
1 parent 6bd9bf5 commit d75ed70
Show file tree
Hide file tree
Showing 3 changed files with 264 additions and 71 deletions.
132 changes: 80 additions & 52 deletions Sources/NIOExtras/QuiescingHelper.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,25 @@ private enum ShutdownError: Error {
/// `channelAdded` method in the same event loop tick as the `Channel` is actually created.
private final class ChannelCollector {
enum LifecycleState {
case upAndRunning
case shuttingDown
case upAndRunning(
openChannels: [ObjectIdentifier: Channel],
serverChannel: Channel
)
case shuttingDown(
openChannels: [ObjectIdentifier: Channel],
fullyShutdownPromise: EventLoopPromise<Void>
)
case shutdownCompleted
}

private var openChannels: [ObjectIdentifier: Channel] = [:]
private let serverChannel: Channel
private var fullyShutdownPromise: EventLoopPromise<Void>? = nil
private var lifecycleState = LifecycleState.upAndRunning
private var lifecycleState: LifecycleState

private var eventLoop: EventLoop {
return self.serverChannel.eventLoop
}
private let eventLoop: EventLoop

/// Initializes a `ChannelCollector` for `Channel`s accepted by `serverChannel`.
init(serverChannel: Channel) {
self.serverChannel = serverChannel
self.eventLoop = serverChannel.eventLoop
self.lifecycleState = .upAndRunning(openChannels: [:], serverChannel: serverChannel)
}

/// Add a channel to the `ChannelCollector`.
Expand All @@ -51,30 +53,64 @@ private final class ChannelCollector {
func channelAdded(_ channel: Channel) throws {
self.eventLoop.assertInEventLoop()

guard self.lifecycleState != .shutdownCompleted else {
switch self.lifecycleState {
case .upAndRunning(var openChannels, let serverChannel):
openChannels[ObjectIdentifier(channel)] = channel
self.lifecycleState = .upAndRunning(openChannels: openChannels, serverChannel: serverChannel)

case .shuttingDown(var openChannels, let fullyShutdownPromise):
openChannels[ObjectIdentifier(channel)] = channel
channel.eventLoop.execute {
channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
}
self.lifecycleState = .shuttingDown(openChannels: openChannels, fullyShutdownPromise: fullyShutdownPromise)

case .shutdownCompleted:
channel.close(promise: nil)
throw ShutdownError.alreadyShutdown
}

self.openChannels[ObjectIdentifier(channel)] = channel
}

private func shutdownCompleted() {
self.eventLoop.assertInEventLoop()
assert(self.lifecycleState == .shuttingDown)

self.lifecycleState = .shutdownCompleted
self.fullyShutdownPromise?.succeed(())
switch self.lifecycleState {
case .upAndRunning:
preconditionFailure("This can never happen because we transition to shuttingDown first")

case .shuttingDown(_, let fullyShutdownPromise):
self.lifecycleState = .shutdownCompleted
fullyShutdownPromise.succeed(())

case .shutdownCompleted:
preconditionFailure("We should only complete the shutdown once")
}
}

private func channelRemoved0(_ channel: Channel) {
self.eventLoop.assertInEventLoop()
precondition(self.openChannels.keys.contains(ObjectIdentifier(channel)),
"channel \(channel) not in ChannelCollector \(self.openChannels)")

self.openChannels.removeValue(forKey: ObjectIdentifier(channel))
if self.lifecycleState != .upAndRunning && self.openChannels.isEmpty {
shutdownCompleted()
switch self.lifecycleState {
case .upAndRunning(var openChannels, let serverChannel):
let removedChannel = openChannels.removeValue(forKey: ObjectIdentifier(channel))

precondition(removedChannel != nil, "channel \(channel) not in ChannelCollector \(openChannels)")

self.lifecycleState = .upAndRunning(openChannels: openChannels, serverChannel: serverChannel)

case .shuttingDown(var openChannels, let fullyShutdownPromise):
let removedChannel = openChannels.removeValue(forKey: ObjectIdentifier(channel))

precondition(removedChannel != nil, "channel \(channel) not in ChannelCollector \(openChannels)")

if openChannels.isEmpty {
self.shutdownCompleted()
} else {
self.lifecycleState = .shuttingDown(openChannels: openChannels, fullyShutdownPromise: fullyShutdownPromise)
}

case .shutdownCompleted:
preconditionFailure("We should not have channels removed after transitioned to completed")
}
}

Expand All @@ -96,44 +132,39 @@ private final class ChannelCollector {

private func initiateShutdown0(promise: EventLoopPromise<Void>?) {
self.eventLoop.assertInEventLoop()
precondition(self.lifecycleState == .upAndRunning)

self.lifecycleState = .shuttingDown
switch self.lifecycleState {
case .upAndRunning(let openChannels, let serverChannel):
let fullyShutdownPromise = promise ?? serverChannel.eventLoop.makePromise(of: Void.self)

if let promise = promise {
if let alreadyExistingPromise = self.fullyShutdownPromise {
alreadyExistingPromise.futureResult.cascade(to: promise)
} else {
self.fullyShutdownPromise = promise
}
}
self.lifecycleState = .shuttingDown(openChannels: openChannels, fullyShutdownPromise: fullyShutdownPromise)

self.serverChannel.close().cascadeFailure(to: self.fullyShutdownPromise)
serverChannel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
serverChannel.close().cascadeFailure(to: fullyShutdownPromise)

for channel in self.openChannels.values {
channel.eventLoop.execute {
channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
for channel in openChannels.values {
channel.eventLoop.execute {
channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
}
}
}

if self.openChannels.isEmpty {
shutdownCompleted()
if openChannels.isEmpty {
self.shutdownCompleted()
}

case .shuttingDown(_, let fullyShutdownPromise):
fullyShutdownPromise.futureResult.cascade(to: promise)

case .shutdownCompleted:
promise?.succeed(())
}
}

/// Initiate the shutdown fulfilling `promise` when all the previously registered `Channel`s have been closed.
///
/// - parameters:
/// - promise: The `EventLoopPromise` to fulfill when the shutdown of all previously registered `Channel`s has been completed.
/// - promise: The `EventLoopPromise` to fulfil when the shutdown of all previously registered `Channel`s has been completed.
func initiateShutdown(promise: EventLoopPromise<Void>?) {
if self.serverChannel.eventLoop.inEventLoop {
self.serverChannel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
} else {
self.eventLoop.execute {
self.serverChannel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
}
}

if self.eventLoop.inEventLoop {
self.initiateShutdown0(promise: promise)
} else {
Expand All @@ -144,7 +175,6 @@ private final class ChannelCollector {
}
}


extension ChannelCollector: @unchecked Sendable {}

/// A `ChannelHandler` that adds all channels that it receives through the `ChannelPipeline` to a `ChannelCollector`.
Expand Down Expand Up @@ -173,7 +203,7 @@ private final class CollectAcceptedChannelsHandler: ChannelInboundHandler {
do {
try self.channelCollector.channelAdded(channel)
let closeFuture = channel.closeFuture
closeFuture.whenComplete { (_: Result<(), Error>) in
closeFuture.whenComplete { (_: Result<Void, Error>) in
self.channelCollector.channelRemoved(channel)
}
context.fireChannelRead(data)
Expand Down Expand Up @@ -231,7 +261,7 @@ public final class ServerQuiescingHelper {
deinit {
self.channelCollectorPromise.fail(UnusedQuiescingHelperError())
}

/// Create the `ChannelHandler` for the server `channel` to collect all accepted child `Channel`s.
///
/// - parameters:
Expand Down Expand Up @@ -262,6 +292,4 @@ public final class ServerQuiescingHelper {
}
}

extension ServerQuiescingHelper: Sendable {

}
extension ServerQuiescingHelper: Sendable {}
5 changes: 5 additions & 0 deletions Tests/NIOExtrasTests/QuiescingHelperTest+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ extension QuiescingHelperTest {
("testQuiesceUserEventReceivedOnShutdown", testQuiesceUserEventReceivedOnShutdown),
("testQuiescingDoesNotSwallowCloseErrorsFromAcceptHandler", testQuiescingDoesNotSwallowCloseErrorsFromAcceptHandler),
("testShutdownIsImmediateWhenPromiseDoesNotSucceed", testShutdownIsImmediateWhenPromiseDoesNotSucceed),
("testShutdown_whenAlreadyShutdown", testShutdown_whenAlreadyShutdown),
("testShutdown_whenNoOpenChild", testShutdown_whenNoOpenChild),
("testChannelClose_whenRunning", testChannelClose_whenRunning),
("testChannelAdded_whenShuttingDown", testChannelAdded_whenShuttingDown),
("testChannelAdded_whenShutdown", testChannelAdded_whenShutdown),
]
}
}
Expand Down
Loading

0 comments on commit d75ed70

Please sign in to comment.