diff --git a/Sources/NIOExtras/QuiescingHelper.swift b/Sources/NIOExtras/QuiescingHelper.swift index cff0cac9..e66de895 100644 --- a/Sources/NIOExtras/QuiescingHelper.swift +++ b/Sources/NIOExtras/QuiescingHelper.swift @@ -23,23 +23,25 @@ private enum ShutdownError: Error { /// `channelAdded` method in the same event loop tick as the `Channel` is actually created. private final class ChannelCollector { enum LifecycleState { - case upAndRunning - case shuttingDown + case upAndRunning( + openChannels: [ObjectIdentifier: Channel], + serverChannel: Channel + ) + case shuttingDown( + openChannels: [ObjectIdentifier: Channel], + fullyShutdownPromise: EventLoopPromise + ) case shutdownCompleted } - private var openChannels: [ObjectIdentifier: Channel] = [:] - private let serverChannel: Channel - private var fullyShutdownPromise: EventLoopPromise? = nil - private var lifecycleState = LifecycleState.upAndRunning + private var lifecycleState: LifecycleState - private var eventLoop: EventLoop { - return self.serverChannel.eventLoop - } + private let eventLoop: EventLoop /// Initializes a `ChannelCollector` for `Channel`s accepted by `serverChannel`. init(serverChannel: Channel) { - self.serverChannel = serverChannel + self.eventLoop = serverChannel.eventLoop + self.lifecycleState = .upAndRunning(openChannels: [:], serverChannel: serverChannel) } /// Add a channel to the `ChannelCollector`. @@ -51,30 +53,64 @@ private final class ChannelCollector { func channelAdded(_ channel: Channel) throws { self.eventLoop.assertInEventLoop() - guard self.lifecycleState != .shutdownCompleted else { + switch self.lifecycleState { + case .upAndRunning(var openChannels, let serverChannel): + openChannels[ObjectIdentifier(channel)] = channel + self.lifecycleState = .upAndRunning(openChannels: openChannels, serverChannel: serverChannel) + + case .shuttingDown(var openChannels, let fullyShutdownPromise): + openChannels[ObjectIdentifier(channel)] = channel + channel.eventLoop.execute { + channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent()) + } + self.lifecycleState = .shuttingDown(openChannels: openChannels, fullyShutdownPromise: fullyShutdownPromise) + + case .shutdownCompleted: channel.close(promise: nil) throw ShutdownError.alreadyShutdown } - - self.openChannels[ObjectIdentifier(channel)] = channel } private func shutdownCompleted() { self.eventLoop.assertInEventLoop() - assert(self.lifecycleState == .shuttingDown) - self.lifecycleState = .shutdownCompleted - self.fullyShutdownPromise?.succeed(()) + switch self.lifecycleState { + case .upAndRunning: + preconditionFailure("This can never happen because we transition to shuttingDown first") + + case .shuttingDown(_, let fullyShutdownPromise): + self.lifecycleState = .shutdownCompleted + fullyShutdownPromise.succeed(()) + + case .shutdownCompleted: + preconditionFailure("We should only complete the shutdown once") + } } private func channelRemoved0(_ channel: Channel) { self.eventLoop.assertInEventLoop() - precondition(self.openChannels.keys.contains(ObjectIdentifier(channel)), - "channel \(channel) not in ChannelCollector \(self.openChannels)") - self.openChannels.removeValue(forKey: ObjectIdentifier(channel)) - if self.lifecycleState != .upAndRunning && self.openChannels.isEmpty { - shutdownCompleted() + switch self.lifecycleState { + case .upAndRunning(var openChannels, let serverChannel): + let removedChannel = openChannels.removeValue(forKey: ObjectIdentifier(channel)) + + precondition(removedChannel != nil, "channel \(channel) not in ChannelCollector \(openChannels)") + + self.lifecycleState = .upAndRunning(openChannels: openChannels, serverChannel: serverChannel) + + case .shuttingDown(var openChannels, let fullyShutdownPromise): + let removedChannel = openChannels.removeValue(forKey: ObjectIdentifier(channel)) + + precondition(removedChannel != nil, "channel \(channel) not in ChannelCollector \(openChannels)") + + if openChannels.isEmpty { + self.shutdownCompleted() + } else { + self.lifecycleState = .shuttingDown(openChannels: openChannels, fullyShutdownPromise: fullyShutdownPromise) + } + + case .shutdownCompleted: + preconditionFailure("We should not have channels removed after transitioned to completed") } } @@ -96,44 +132,39 @@ private final class ChannelCollector { private func initiateShutdown0(promise: EventLoopPromise?) { self.eventLoop.assertInEventLoop() - precondition(self.lifecycleState == .upAndRunning) - self.lifecycleState = .shuttingDown + switch self.lifecycleState { + case .upAndRunning(let openChannels, let serverChannel): + let fullyShutdownPromise = promise ?? serverChannel.eventLoop.makePromise(of: Void.self) - if let promise = promise { - if let alreadyExistingPromise = self.fullyShutdownPromise { - alreadyExistingPromise.futureResult.cascade(to: promise) - } else { - self.fullyShutdownPromise = promise - } - } + self.lifecycleState = .shuttingDown(openChannels: openChannels, fullyShutdownPromise: fullyShutdownPromise) - self.serverChannel.close().cascadeFailure(to: self.fullyShutdownPromise) + serverChannel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent()) + serverChannel.close().cascadeFailure(to: fullyShutdownPromise) - for channel in self.openChannels.values { - channel.eventLoop.execute { - channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent()) + for channel in openChannels.values { + channel.eventLoop.execute { + channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent()) + } } - } - if self.openChannels.isEmpty { - shutdownCompleted() + if openChannels.isEmpty { + self.shutdownCompleted() + } + + case .shuttingDown(_, let fullyShutdownPromise): + fullyShutdownPromise.futureResult.cascade(to: promise) + + case .shutdownCompleted: + promise?.succeed(()) } } /// Initiate the shutdown fulfilling `promise` when all the previously registered `Channel`s have been closed. /// /// - parameters: - /// - promise: The `EventLoopPromise` to fulfill when the shutdown of all previously registered `Channel`s has been completed. + /// - promise: The `EventLoopPromise` to fulfil when the shutdown of all previously registered `Channel`s has been completed. func initiateShutdown(promise: EventLoopPromise?) { - if self.serverChannel.eventLoop.inEventLoop { - self.serverChannel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent()) - } else { - self.eventLoop.execute { - self.serverChannel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent()) - } - } - if self.eventLoop.inEventLoop { self.initiateShutdown0(promise: promise) } else { @@ -144,7 +175,6 @@ private final class ChannelCollector { } } - extension ChannelCollector: @unchecked Sendable {} /// A `ChannelHandler` that adds all channels that it receives through the `ChannelPipeline` to a `ChannelCollector`. @@ -173,7 +203,7 @@ private final class CollectAcceptedChannelsHandler: ChannelInboundHandler { do { try self.channelCollector.channelAdded(channel) let closeFuture = channel.closeFuture - closeFuture.whenComplete { (_: Result<(), Error>) in + closeFuture.whenComplete { (_: Result) in self.channelCollector.channelRemoved(channel) } context.fireChannelRead(data) @@ -231,7 +261,7 @@ public final class ServerQuiescingHelper { deinit { self.channelCollectorPromise.fail(UnusedQuiescingHelperError()) } - + /// Create the `ChannelHandler` for the server `channel` to collect all accepted child `Channel`s. /// /// - parameters: @@ -262,6 +292,4 @@ public final class ServerQuiescingHelper { } } -extension ServerQuiescingHelper: Sendable { - -} +extension ServerQuiescingHelper: Sendable {} diff --git a/Tests/NIOExtrasTests/QuiescingHelperTest+XCTest.swift b/Tests/NIOExtrasTests/QuiescingHelperTest+XCTest.swift index 7ca56760..de40aef4 100644 --- a/Tests/NIOExtrasTests/QuiescingHelperTest+XCTest.swift +++ b/Tests/NIOExtrasTests/QuiescingHelperTest+XCTest.swift @@ -31,6 +31,11 @@ extension QuiescingHelperTest { ("testQuiesceUserEventReceivedOnShutdown", testQuiesceUserEventReceivedOnShutdown), ("testQuiescingDoesNotSwallowCloseErrorsFromAcceptHandler", testQuiescingDoesNotSwallowCloseErrorsFromAcceptHandler), ("testShutdownIsImmediateWhenPromiseDoesNotSucceed", testShutdownIsImmediateWhenPromiseDoesNotSucceed), + ("testShutdown_whenAlreadyShutdown", testShutdown_whenAlreadyShutdown), + ("testShutdown_whenNoOpenChild", testShutdown_whenNoOpenChild), + ("testChannelClose_whenRunning", testChannelClose_whenRunning), + ("testChannelAdded_whenShuttingDown", testChannelAdded_whenShuttingDown), + ("testChannelAdded_whenShutdown", testChannelAdded_whenShutdown), ] } } diff --git a/Tests/NIOExtrasTests/QuiescingHelperTest.swift b/Tests/NIOExtrasTests/QuiescingHelperTest.swift index c4a611dd..ac62e78b 100644 --- a/Tests/NIOExtrasTests/QuiescingHelperTest.swift +++ b/Tests/NIOExtrasTests/QuiescingHelperTest.swift @@ -12,12 +12,27 @@ // //===----------------------------------------------------------------------===// -import XCTest import NIOCore import NIOEmbedded +@testable import NIOExtras import NIOPosix import NIOTestUtils -@testable import NIOExtras +import XCTest + +private final class WaitForQuiesceUserEvent: ChannelInboundHandler { + typealias InboundIn = Never + private let promise: EventLoopPromise + + init(promise: EventLoopPromise) { + self.promise = promise + } + + func userInboundEventTriggered(context _: ChannelHandlerContext, event: Any) { + if event is ChannelShouldQuiesceEvent { + self.promise.succeed(()) + } + } +} public class QuiescingHelperTest: XCTestCase { func testShutdownIsImmediateWhenNoChannelsCollected() throws { @@ -35,21 +50,6 @@ public class QuiescingHelperTest: XCTestCase { } func testQuiesceUserEventReceivedOnShutdown() throws { - class WaitForQuiesceUserEvent: ChannelInboundHandler { - typealias InboundIn = Never - private let promise: EventLoopPromise - - init(promise: EventLoopPromise) { - self.promise = promise - } - - func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { - if event is ChannelShouldQuiesceEvent { - self.promise.succeed(()) - } - } - } - let el = EmbeddedEventLoop() let allShutdownPromise: EventLoopPromise = el.makePromise() let serverChannel = EmbeddedChannel(handler: nil, loop: el) @@ -63,7 +63,7 @@ public class QuiescingHelperTest: XCTestCase { // add a bunch of channels for pretendPort in 1...128 { - let waitForPromise: EventLoopPromise<()> = el.makePromise() + let waitForPromise: EventLoopPromise = el.makePromise() let channel = EmbeddedChannel(handler: WaitForQuiesceUserEvent(promise: waitForPromise), loop: el) // activate the child chan XCTAssertNoThrow(try channel.connect(to: .init(ipAddress: "1.2.3.4", port: pretendPort)).wait()) @@ -137,7 +137,7 @@ public class QuiescingHelperTest: XCTestCase { } } - ///verifying that the promise fails when goes out of scope for shutdown + /// verifying that the promise fails when goes out of scope for shutdown func testShutdownIsImmediateWhenPromiseDoesNotSucceed() throws { let el = EmbeddedEventLoop() @@ -151,4 +151,164 @@ public class QuiescingHelperTest: XCTestCase { XCTAssertTrue(error is ServerQuiescingHelper.UnusedQuiescingHelperError) } } + + func testShutdown_whenAlreadyShutdown() throws { + let el = EmbeddedEventLoop() + let channel = EmbeddedChannel(handler: nil, loop: el) + // let's activate the server channel, nothing actually happens as this is an EmbeddedChannel + XCTAssertNoThrow(try channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 1)).wait()) + XCTAssertTrue(channel.isActive) + let quiesce = ServerQuiescingHelper(group: el) + _ = quiesce.makeServerChannelHandler(channel: channel) + let p1: EventLoopPromise = el.makePromise() + quiesce.initiateShutdown(promise: p1) + XCTAssertNoThrow(try p1.futureResult.wait()) + XCTAssertFalse(channel.isActive) + + let p2: EventLoopPromise = el.makePromise() + quiesce.initiateShutdown(promise: p2) + XCTAssertNoThrow(try p2.futureResult.wait()) + } + + func testShutdown_whenNoOpenChild() throws { + let el = EmbeddedEventLoop() + let channel = EmbeddedChannel(handler: nil, loop: el) + // let's activate the server channel, nothing actually happens as this is an EmbeddedChannel + XCTAssertNoThrow(try channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 1)).wait()) + XCTAssertTrue(channel.isActive) + let quiesce = ServerQuiescingHelper(group: el) + _ = quiesce.makeServerChannelHandler(channel: channel) + let p1: EventLoopPromise = el.makePromise() + quiesce.initiateShutdown(promise: p1) + el.run() + XCTAssertNoThrow(try p1.futureResult.wait()) + XCTAssertFalse(channel.isActive) + } + + func testChannelClose_whenRunning() throws { + let el = EmbeddedEventLoop() + let allShutdownPromise: EventLoopPromise = el.makePromise() + let serverChannel = EmbeddedChannel(handler: nil, loop: el) + // let's activate the server channel, nothing actually happens as this is an EmbeddedChannel + XCTAssertNoThrow(try serverChannel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 1)).wait()) + let quiesce = ServerQuiescingHelper(group: el) + let collectionHandler = quiesce.makeServerChannelHandler(channel: serverChannel) + XCTAssertNoThrow(try serverChannel.pipeline.addHandler(collectionHandler).wait()) + + // let's one channels + let eventCounterHandler = EventCounterHandler() + let childChannel1 = EmbeddedChannel(handler: eventCounterHandler, loop: el) + // activate the child channel + XCTAssertNoThrow(try childChannel1.connect(to: .init(ipAddress: "1.2.3.4", port: 1)).wait()) + serverChannel.pipeline.fireChannelRead(NIOAny(childChannel1)) + + // check that the server channel and channel are active before initiating the shutdown + XCTAssertTrue(serverChannel.isActive) + XCTAssertTrue(childChannel1.isActive) + + XCTAssertEqual(eventCounterHandler.userInboundEventTriggeredCalls, 0) + + // now close the first child channel + childChannel1.close(promise: nil) + el.run() + + // check that the server is active and child is not + XCTAssertTrue(serverChannel.isActive) + XCTAssertFalse(childChannel1.isActive) + + quiesce.initiateShutdown(promise: allShutdownPromise) + el.run() + + // check that the server channel is closed as the first thing + XCTAssertFalse(serverChannel.isActive) + + el.run() + + // check that the shutdown has completed + XCTAssertNoThrow(try allShutdownPromise.futureResult.wait()) + } + + func testChannelAdded_whenShuttingDown() throws { + let el = EmbeddedEventLoop() + let allShutdownPromise: EventLoopPromise = el.makePromise() + let serverChannel = EmbeddedChannel(handler: nil, loop: el) + // let's activate the server channel, nothing actually happens as this is an EmbeddedChannel + XCTAssertNoThrow(try serverChannel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 1)).wait()) + let quiesce = ServerQuiescingHelper(group: el) + let collectionHandler = quiesce.makeServerChannelHandler(channel: serverChannel) + XCTAssertNoThrow(try serverChannel.pipeline.addHandler(collectionHandler).wait()) + + // let's add one channel + let waitForPromise1: EventLoopPromise = el.makePromise() + let childChannel1 = EmbeddedChannel(handler: WaitForQuiesceUserEvent(promise: waitForPromise1), loop: el) + // activate the child channel + XCTAssertNoThrow(try childChannel1.connect(to: .init(ipAddress: "1.2.3.4", port: 1)).wait()) + serverChannel.pipeline.fireChannelRead(NIOAny(childChannel1)) + + el.run() + + // check that the server and channel are running + XCTAssertTrue(serverChannel.isActive) + XCTAssertTrue(childChannel1.isActive) + + // let's shut down + quiesce.initiateShutdown(promise: allShutdownPromise) + + // let's add one more channel + let waitForPromise2: EventLoopPromise = el.makePromise() + let childChannel2 = EmbeddedChannel(handler: WaitForQuiesceUserEvent(promise: waitForPromise2), loop: el) + // activate the child channel + XCTAssertNoThrow(try childChannel2.connect(to: .init(ipAddress: "1.2.3.4", port: 2)).wait()) + serverChannel.pipeline.fireChannelRead(NIOAny(childChannel2)) + el.run() + + // Check that we got all quiescing events + XCTAssertNoThrow(try waitForPromise1.futureResult.wait() as Void) + XCTAssertNoThrow(try waitForPromise2.futureResult.wait() as Void) + + // check that the server is closed and the children are running + XCTAssertFalse(serverChannel.isActive) + XCTAssertTrue(childChannel1.isActive) + XCTAssertTrue(childChannel2.isActive) + + // let's close the children + childChannel1.close(promise: nil) + childChannel2.close(promise: nil) + el.run() + + // check that everything is closed + XCTAssertFalse(serverChannel.isActive) + XCTAssertFalse(childChannel1.isActive) + XCTAssertFalse(childChannel2.isActive) + + XCTAssertNoThrow(try allShutdownPromise.futureResult.wait() as Void) + } + + func testChannelAdded_whenShutdown() throws { + let el = EmbeddedEventLoop() + let allShutdownPromise: EventLoopPromise = el.makePromise() + let serverChannel = EmbeddedChannel(handler: nil, loop: el) + // let's activate the server channel, nothing actually happens as this is an EmbeddedChannel + XCTAssertNoThrow(try serverChannel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 1)).wait()) + let quiesce = ServerQuiescingHelper(group: el) + let collectionHandler = quiesce.makeServerChannelHandler(channel: serverChannel) + XCTAssertNoThrow(try serverChannel.pipeline.addHandler(collectionHandler).wait()) + + // check that the server is running + XCTAssertTrue(serverChannel.isActive) + + // let's shut down + quiesce.initiateShutdown(promise: allShutdownPromise) + + // check that the shutdown has completed + XCTAssertNoThrow(try allShutdownPromise.futureResult.wait()) + + // let's add one channel + let childChannel1 = EmbeddedChannel(loop: el) + // activate the child channel + XCTAssertNoThrow(try childChannel1.connect(to: .init(ipAddress: "1.2.3.4", port: 1)).wait()) + serverChannel.pipeline.fireChannelRead(NIOAny(childChannel1)) + + el.run() + } }