diff --git a/.api-breakage/allowlist-branch-sendable-take-2.txt b/.api-breakage/allowlist-branch-sendable-take-2.txt new file mode 100644 index 00000000..a2985625 --- /dev/null +++ b/.api-breakage/allowlist-branch-sendable-take-2.txt @@ -0,0 +1,14 @@ +API breakage: func WebSocket.onText(_:) is now with @preconcurrency +API breakage: func WebSocket.onBinary(_:) is now with @preconcurrency +API breakage: func WebSocket.onPong(_:) is now with @preconcurrency +API breakage: func WebSocket.onPing(_:) is now with @preconcurrency +API breakage: func WebSocket.connect(to:headers:configuration:on:onUpgrade:) is now with @preconcurrency +API breakage: func WebSocket.connect(scheme:host:port:path:query:headers:configuration:on:onUpgrade:) is now with @preconcurrency +API breakage: func WebSocket.connect(scheme:host:port:path:query:headers:proxy:proxyPort:proxyHeaders:proxyConnectDeadline:configuration:on:onUpgrade:) is now with @preconcurrency +API breakage: func WebSocket.connect(to:headers:proxy:proxyPort:proxyHeaders:proxyConnectDeadline:configuration:on:onUpgrade:) is now with @preconcurrency +API breakage: func WebSocket.client(on:onUpgrade:) is now with @preconcurrency +API breakage: func WebSocket.client(on:config:onUpgrade:) is now with @preconcurrency +API breakage: func WebSocket.server(on:onUpgrade:) is now with @preconcurrency +API breakage: func WebSocket.server(on:config:onUpgrade:) is now with @preconcurrency +API breakage: func WebSocketClient.connect(scheme:host:port:path:query:headers:onUpgrade:) is now with @preconcurrency +API breakage: func WebSocketClient.connect(scheme:host:port:path:query:headers:proxy:proxyPort:proxyHeaders:proxyConnectDeadline:onUpgrade:) is now with @preconcurrency diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b2288dc2..7fe9308e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -5,9 +5,7 @@ concurrency: on: pull_request: { types: [opened, reopened, synchronize, ready_for_review] } push: { branches: [ main ] } - jobs: - vapor-integration: if: ${{ !(github.event.pull_request.draft || false) }} runs-on: ubuntu-latest diff --git a/.gitignore b/.gitignore index 68b8b308..630ed81f 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ DerivedData .swiftpm Package.resolved +.devcontainer/ diff --git a/Package.swift b/Package.swift index 4841d15c..86fea2af 100644 --- a/Package.swift +++ b/Package.swift @@ -16,8 +16,8 @@ let package = Package( .package(url: "https://github.com/apple/swift-nio.git", from: "2.53.0"), .package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.16.0"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.24.0"), - .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.11.4"), - .package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.2"), + .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.16.0"), + .package(url: "https://github.com/apple/swift-atomics.git", from: "1.1.0"), ], targets: [ .target(name: "WebSocketKit", dependencies: [ diff --git a/Sources/WebSocketKit/Concurrency/WebSocket+Concurrency.swift b/Sources/WebSocketKit/Concurrency/WebSocket+Concurrency.swift index d0e9c0d3..56085bef 100644 --- a/Sources/WebSocketKit/Concurrency/WebSocket+Concurrency.swift +++ b/Sources/WebSocketKit/Concurrency/WebSocket+Concurrency.swift @@ -40,44 +40,52 @@ extension WebSocket { try await close(code: code).get() } - public func onText(_ callback: @escaping (WebSocket, String) async -> ()) { - onText { socket, text in - Task { - await callback(socket, text) + @preconcurrency public func onText(_ callback: @Sendable @escaping (WebSocket, String) async -> ()) { + self.eventLoop.execute { + self.onText { socket, text in + Task { + await callback(socket, text) + } } } } - public func onBinary(_ callback: @escaping (WebSocket, ByteBuffer) async -> ()) { - onBinary { socket, binary in - Task { - await callback(socket, binary) + @preconcurrency public func onBinary(_ callback: @Sendable @escaping (WebSocket, ByteBuffer) async -> ()) { + self.eventLoop.execute { + self.onBinary { socket, binary in + Task { + await callback(socket, binary) + } } } } - public func onPong(_ callback: @escaping (WebSocket) async -> ()) { - onPong { socket in - Task { - await callback(socket) + @preconcurrency public func onPong(_ callback: @Sendable @escaping (WebSocket) async -> ()) { + self.eventLoop.execute { + self.onPong { socket in + Task { + await callback(socket) + } } } } - public func onPing(_ callback: @escaping (WebSocket) async -> ()) { - onPing { socket in - Task { - await callback(socket) + @preconcurrency public func onPing(_ callback: @Sendable @escaping (WebSocket) async -> ()) { + self.eventLoop.execute { + self.onPing { socket in + Task { + await callback(socket) + } } } } - public static func connect( + @preconcurrency public static func connect( to url: String, headers: HTTPHeaders = [:], configuration: WebSocketClient.Configuration = .init(), on eventLoopGroup: EventLoopGroup, - onUpgrade: @escaping (WebSocket) async -> () + onUpgrade: @Sendable @escaping (WebSocket) async -> () ) async throws { return try await self.connect( to: url, @@ -92,12 +100,12 @@ extension WebSocket { ).get() } - public static func connect( + @preconcurrency public static func connect( to url: URL, headers: HTTPHeaders = [:], configuration: WebSocketClient.Configuration = .init(), on eventLoopGroup: EventLoopGroup, - onUpgrade: @escaping (WebSocket) async -> () + onUpgrade: @Sendable @escaping (WebSocket) async -> () ) async throws { return try await self.connect( to: url, @@ -112,7 +120,7 @@ extension WebSocket { ).get() } - public static func connect( + @preconcurrency public static func connect( scheme: String = "ws", host: String, port: Int = 80, @@ -121,7 +129,7 @@ extension WebSocket { headers: HTTPHeaders = [:], configuration: WebSocketClient.Configuration = .init(), on eventLoopGroup: EventLoopGroup, - onUpgrade: @escaping (WebSocket) async -> () + onUpgrade: @Sendable @escaping (WebSocket) async -> () ) async throws { return try await self.connect( scheme: scheme, diff --git a/Sources/WebSocketKit/Exports.swift b/Sources/WebSocketKit/Exports.swift index 11ce0e49..b853992f 100644 --- a/Sources/WebSocketKit/Exports.swift +++ b/Sources/WebSocketKit/Exports.swift @@ -6,9 +6,7 @@ @_documentation(visibility: internal) @_exported import protocol NIOCore.EventLoopGroup @_documentation(visibility: internal) @_exported import struct NIOCore.EventLoopPromise @_documentation(visibility: internal) @_exported import class NIOCore.EventLoopFuture - @_documentation(visibility: internal) @_exported import struct NIOHTTP1.HTTPHeaders - @_documentation(visibility: internal) @_exported import struct Foundation.URL #else @@ -19,9 +17,7 @@ @_exported import protocol NIOCore.EventLoopGroup @_exported import struct NIOCore.EventLoopPromise @_exported import class NIOCore.EventLoopFuture - @_exported import struct NIOHTTP1.HTTPHeaders - @_exported import struct Foundation.URL #endif diff --git a/Sources/WebSocketKit/WebSocket+Connect.swift b/Sources/WebSocketKit/WebSocket+Connect.swift index ca94540b..546401a4 100644 --- a/Sources/WebSocketKit/WebSocket+Connect.swift +++ b/Sources/WebSocketKit/WebSocket+Connect.swift @@ -12,12 +12,13 @@ extension WebSocket { /// - eventLoopGroup: Event loop group to be used by the WebSocket client. /// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`. /// - Returns: An future which completes when the connection to the WebSocket server is established. + @preconcurrency public static func connect( to url: String, headers: HTTPHeaders = [:], configuration: WebSocketClient.Configuration = .init(), on eventLoopGroup: EventLoopGroup, - onUpgrade: @escaping (WebSocket) -> () + onUpgrade: @Sendable @escaping (WebSocket) -> () ) -> EventLoopFuture { guard let url = URL(string: url) else { return eventLoopGroup.any().makeFailedFuture(WebSocketClient.Error.invalidURL) @@ -40,12 +41,13 @@ extension WebSocket { /// - eventLoopGroup: Event loop group to be used by the WebSocket client. /// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`. /// - Returns: An future which completes when the connection to the WebSocket server is established. + @preconcurrency public static func connect( to url: URL, headers: HTTPHeaders = [:], configuration: WebSocketClient.Configuration = .init(), on eventLoopGroup: EventLoopGroup, - onUpgrade: @escaping (WebSocket) -> () + onUpgrade: @Sendable @escaping (WebSocket) -> () ) -> EventLoopFuture { let scheme = url.scheme ?? "ws" return self.connect( @@ -74,6 +76,7 @@ extension WebSocket { /// - eventLoopGroup: Event loop group to be used by the WebSocket client. /// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`. /// - Returns: An future which completes when the connection to the WebSocket server is established. + @preconcurrency public static func connect( scheme: String = "ws", host: String, @@ -83,7 +86,7 @@ extension WebSocket { headers: HTTPHeaders = [:], configuration: WebSocketClient.Configuration = .init(), on eventLoopGroup: EventLoopGroup, - onUpgrade: @escaping (WebSocket) -> () + onUpgrade: @Sendable @escaping (WebSocket) -> () ) -> EventLoopFuture { return WebSocketClient( eventLoopGroupProvider: .shared(eventLoopGroup), @@ -116,6 +119,7 @@ extension WebSocket { /// - eventLoopGroup: Event loop group to be used by the WebSocket client. /// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`. /// - Returns: An future which completes when the connection to the origin server is established. + @preconcurrency public static func connect( scheme: String = "ws", host: String, @@ -129,7 +133,7 @@ extension WebSocket { proxyConnectDeadline: NIODeadline = NIODeadline.distantFuture, configuration: WebSocketClient.Configuration = .init(), on eventLoopGroup: EventLoopGroup, - onUpgrade: @escaping (WebSocket) -> () + onUpgrade: @Sendable @escaping (WebSocket) -> () ) -> EventLoopFuture { return WebSocketClient( eventLoopGroupProvider: .shared(eventLoopGroup), @@ -162,6 +166,7 @@ extension WebSocket { /// - eventLoopGroup: Event loop group to be used by the WebSocket client. /// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`. /// - Returns: An future which completes when the connection to the origin server is established. + @preconcurrency public static func connect( to url: String, headers: HTTPHeaders = [:], @@ -171,7 +176,7 @@ extension WebSocket { proxyConnectDeadline: NIODeadline = NIODeadline.distantFuture, configuration: WebSocketClient.Configuration = .init(), on eventLoopGroup: EventLoopGroup, - onUpgrade: @escaping (WebSocket) -> () + onUpgrade: @Sendable @escaping (WebSocket) -> () ) -> EventLoopFuture { guard let url = URL(string: url) else { return eventLoopGroup.any().makeFailedFuture(WebSocketClient.Error.invalidURL) diff --git a/Sources/WebSocketKit/WebSocket.swift b/Sources/WebSocketKit/WebSocket.swift index 6ab6b6fd..245d97a2 100644 --- a/Sources/WebSocketKit/WebSocket.swift +++ b/Sources/WebSocketKit/WebSocket.swift @@ -4,9 +4,10 @@ import NIOHTTP1 import NIOSSL import Foundation import NIOFoundationCompat +import NIOConcurrencyHelpers -public final class WebSocket { - enum PeerType { +public final class WebSocket: Sendable { + enum PeerType: Sendable { case server case client } @@ -18,7 +19,11 @@ public final class WebSocket { public var isClosed: Bool { !self.channel.isActive } - public private(set) var closeCode: WebSocketErrorCode? + public var closeCode: WebSocketErrorCode? { + _closeCode.withLockedValue { $0 } + } + + private let _closeCode: NIOLockedValueBox public var onClose: EventLoopFuture { self.channel.closeFuture @@ -27,42 +32,46 @@ public final class WebSocket { @usableFromInline /* private but @usableFromInline */ internal let channel: Channel - private var onTextCallback: (WebSocket, String) -> () - private var onBinaryCallback: (WebSocket, ByteBuffer) -> () - private var onPongCallback: (WebSocket) -> () - private var onPingCallback: (WebSocket) -> () - private var frameSequence: WebSocketFrameSequence? + private let onTextCallback: NIOLoopBoundBox<@Sendable (WebSocket, String) -> ()> + private let onBinaryCallback: NIOLoopBoundBox<@Sendable (WebSocket, ByteBuffer) -> ()> + private let onPongCallback: NIOLoopBoundBox<@Sendable (WebSocket) -> ()> + private let onPingCallback: NIOLoopBoundBox<@Sendable (WebSocket) -> ()> private let type: PeerType - private var waitingForPong: Bool - private var waitingForClose: Bool - private var scheduledTimeoutTask: Scheduled? + private let waitingForPong: NIOLockedValueBox + private let waitingForClose: NIOLockedValueBox + private let scheduledTimeoutTask: NIOLockedValueBox?> + private let frameSequence: NIOLockedValueBox + private let _pingInterval: NIOLockedValueBox init(channel: Channel, type: PeerType) { self.channel = channel self.type = type - self.onTextCallback = { _, _ in } - self.onBinaryCallback = { _, _ in } - self.onPongCallback = { _ in } - self.onPingCallback = { _ in } - self.waitingForPong = false - self.waitingForClose = false - self.scheduledTimeoutTask = nil + self.onTextCallback = .init({ _, _ in }, eventLoop: channel.eventLoop) + self.onBinaryCallback = .init({ _, _ in }, eventLoop: channel.eventLoop) + self.onPongCallback = .init({ _ in }, eventLoop: channel.eventLoop) + self.onPingCallback = .init({ _ in }, eventLoop: channel.eventLoop) + self.waitingForPong = .init(false) + self.waitingForClose = .init(false) + self.scheduledTimeoutTask = .init(nil) + self._closeCode = .init(nil) + self.frameSequence = .init(nil) + self._pingInterval = .init(nil) } - public func onText(_ callback: @escaping (WebSocket, String) -> ()) { - self.onTextCallback = callback + @preconcurrency public func onText(_ callback: @Sendable @escaping (WebSocket, String) -> ()) { + self.onTextCallback.value = callback } - public func onBinary(_ callback: @escaping (WebSocket, ByteBuffer) -> ()) { - self.onBinaryCallback = callback + @preconcurrency public func onBinary(_ callback: @Sendable @escaping (WebSocket, ByteBuffer) -> ()) { + self.onBinaryCallback.value = callback } - public func onPong(_ callback: @escaping (WebSocket) -> ()) { - self.onPongCallback = callback + @preconcurrency public func onPong(_ callback: @Sendable @escaping (WebSocket) -> ()) { + self.onPongCallback.value = callback } - public func onPing(_ callback: @escaping (WebSocket) -> ()) { - self.onPingCallback = callback + @preconcurrency public func onPing(_ callback: @Sendable @escaping (WebSocket) -> ()) { + self.onPingCallback.value = callback } /// If set, this will trigger automatic pings on the connection. If ping is not answered before @@ -72,14 +81,18 @@ public final class WebSocket { /// mechanism shutting down inactive connections, such as a Load Balancer deployed in /// front of the server. public var pingInterval: TimeAmount? { - didSet { - if pingInterval != nil { - if scheduledTimeoutTask == nil { - waitingForPong = false + get { + return _pingInterval.withLockedValue { $0 } + } + set { + _pingInterval.withLockedValue { $0 = newValue } + if newValue != nil { + if scheduledTimeoutTask.withLockedValue({ $0 == nil }) { + waitingForPong.withLockedValue { $0 = false } self.pingAndScheduleNextTimeoutTask() } } else { - scheduledTimeoutTask?.cancel() + scheduledTimeoutTask.withLockedValue { $0?.cancel() } } } } @@ -160,12 +173,12 @@ public final class WebSocket { promise?.succeed(()) return } - guard !self.waitingForClose else { + guard !self.waitingForClose.withLockedValue({ $0 }) else { promise?.succeed(()) return } - self.waitingForClose = true - self.closeCode = code + self.waitingForClose.withLockedValue { $0 = true } + self._closeCode.withLockedValue { $0 = code } let codeAsInt = UInt16(webSocketErrorCode: code) let codeToSend: WebSocketErrorCode @@ -197,7 +210,7 @@ public final class WebSocket { func handle(incoming frame: WebSocketFrame) { switch frame.opcode { case .connectionClose: - if self.waitingForClose { + if self.waitingForClose.withLockedValue({ $0 }) { // peer confirmed close, time to close channel self.channel.close(mode: .all, promise: nil) } else { @@ -223,7 +236,7 @@ public final class WebSocket { if let maskingKey = maskingKey { frameData.webSocketUnmask(maskingKey) } - self.onPingCallback(self) + self.onPingCallback.value(self) self.send( raw: frameData.readableBytesView, opcode: .pong, @@ -240,22 +253,19 @@ public final class WebSocket { if let maskingKey = maskingKey { frameData.webSocketUnmask(maskingKey) } - self.waitingForPong = false - self.onPongCallback(self) + self.waitingForPong.withLockedValue { $0 = false } + self.onPongCallback.value(self) } else { self.close(code: .protocolError, promise: nil) } case .text, .binary: // create a new frame sequence or use existing - var frameSequence: WebSocketFrameSequence - if let existing = self.frameSequence { - frameSequence = existing - } else { - frameSequence = WebSocketFrameSequence(type: frame.opcode) + self.frameSequence.withLockedValue { currentFrameSequence in + var frameSequence = currentFrameSequence ?? .init(type: frame.opcode) + // append this frame and update the sequence + frameSequence.append(frame) + currentFrameSequence = frameSequence } - // append this frame and update the sequence - frameSequence.append(frame) - self.frameSequence = frameSequence case .continuation: /// continuations are filtered by ``NIOWebSocketFrameAggregator`` preconditionFailure("We will never receive a continuation frame") @@ -266,26 +276,29 @@ public final class WebSocket { // if this frame was final and we have a non-nil frame sequence, // output it to the websocket and clear storage - if let frameSequence = self.frameSequence, frame.fin { - switch frameSequence.type { - case .binary: - self.onBinaryCallback(self, frameSequence.binaryBuffer) - case .text: - self.onTextCallback(self, frameSequence.textBuffer) - case .ping, .pong: - assertionFailure("Control frames never have a frameSequence") - default: break + self.frameSequence.withLockedValue { currentFrameSequence in + if let frameSequence = currentFrameSequence, frame.fin { + switch frameSequence.type { + case .binary: + self.onBinaryCallback.value(self, frameSequence.binaryBuffer) + case .text: + self.onTextCallback.value(self, frameSequence.textBuffer) + case .ping, .pong: + assertionFailure("Control frames never have a frameSequence") + default: break + } + currentFrameSequence = nil } - self.frameSequence = nil } } + @Sendable private func pingAndScheduleNextTimeoutTask() { guard channel.isActive, let pingInterval = pingInterval else { return } - if waitingForPong { + if waitingForPong.withLockedValue({ $0 }) { // We never received a pong from our last ping, so the connection has timed out let promise = self.eventLoop.makePromise(of: Void.self) self.close(code: .unknown(1006), promise: promise) @@ -298,11 +311,13 @@ public final class WebSocket { } } else { self.sendPing() - self.waitingForPong = true - self.scheduledTimeoutTask = self.eventLoop.scheduleTask( - deadline: .now() + pingInterval, - self.pingAndScheduleNextTimeoutTask - ) + self.waitingForPong.withLockedValue { $0 = true } + self.scheduledTimeoutTask.withLockedValue { + $0 = self.eventLoop.scheduleTask( + deadline: .now() + pingInterval, + self.pingAndScheduleNextTimeoutTask + ) + } } } @@ -311,27 +326,31 @@ public final class WebSocket { } } -private struct WebSocketFrameSequence { +private struct WebSocketFrameSequence: Sendable { var binaryBuffer: ByteBuffer var textBuffer: String - var type: WebSocketOpcode + let type: WebSocketOpcode + let lock: NIOLock init(type: WebSocketOpcode) { self.binaryBuffer = ByteBufferAllocator().buffer(capacity: 0) self.textBuffer = .init() self.type = type + self.lock = .init() } mutating func append(_ frame: WebSocketFrame) { - var data = frame.unmaskedData - switch type { - case .binary: - self.binaryBuffer.writeBuffer(&data) - case .text: - if let string = data.readString(length: data.readableBytes) { - self.textBuffer += string + self.lock.withLockVoid { + var data = frame.unmaskedData + switch type { + case .binary: + self.binaryBuffer.writeBuffer(&data) + case .text: + if let string = data.readString(length: data.readableBytes) { + self.textBuffer += string + } + default: break } - default: break } } } diff --git a/Sources/WebSocketKit/WebSocketClient.swift b/Sources/WebSocketKit/WebSocketClient.swift index 08bf48e5..df96d6bb 100644 --- a/Sources/WebSocketKit/WebSocketClient.swift +++ b/Sources/WebSocketKit/WebSocketClient.swift @@ -9,7 +9,7 @@ import NIOSSL import NIOTransportServices import Atomics -public final class WebSocketClient { +public final class WebSocketClient: Sendable { public enum Error: Swift.Error, LocalizedError { case invalidURL case invalidResponseStatus(HTTPResponseHead) @@ -21,7 +21,7 @@ public final class WebSocketClient { public typealias EventLoopGroupProvider = NIOEventLoopGroupProvider - public struct Configuration { + public struct Configuration: Sendable { public var tlsConfiguration: TLSConfiguration? public var maxFrameSize: Int @@ -63,6 +63,7 @@ public final class WebSocketClient { self.configuration = configuration } + @preconcurrency public func connect( scheme: String, host: String, @@ -70,7 +71,7 @@ public final class WebSocketClient { path: String = "/", query: String? = nil, headers: HTTPHeaders = [:], - onUpgrade: @escaping (WebSocket) -> () + onUpgrade: @Sendable @escaping (WebSocket) -> () ) -> EventLoopFuture { self.connect(scheme: scheme, host: host, port: port, path: path, query: query, headers: headers, proxy: nil, onUpgrade: onUpgrade) } @@ -90,6 +91,7 @@ public final class WebSocketClient { /// - proxyConnectDeadline: Deadline for establishing the proxy connection. /// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`. /// - Returns: An future which completes when the connection to the origin server is established. + @preconcurrency public func connect( scheme: String, host: String, @@ -101,7 +103,7 @@ public final class WebSocketClient { proxyPort: Int? = nil, proxyHeaders: HTTPHeaders = [:], proxyConnectDeadline: NIODeadline = NIODeadline.distantFuture, - onUpgrade: @escaping (WebSocket) -> () + onUpgrade: @Sendable @escaping (WebSocket) -> () ) -> EventLoopFuture { assert(["ws", "wss"].contains(scheme)) let upgradePromise = self.group.any().makePromise(of: Void.self) @@ -130,6 +132,7 @@ public final class WebSocketClient { headers: upgradeRequestHeaders, upgradePromise: upgradePromise ) + let httpUpgradeRequestHandlerBox = NIOLoopBound(httpUpgradeRequestHandler, eventLoop: channel.eventLoop) let websocketUpgrader = NIOWebSocketClientUpgrader( maxFrameSize: self.configuration.maxFrameSize, @@ -143,9 +146,10 @@ public final class WebSocketClient { upgraders: [websocketUpgrader], completionHandler: { context in upgradePromise.succeed(()) - channel.pipeline.removeHandler(httpUpgradeRequestHandler, promise: nil) + channel.pipeline.removeHandler(httpUpgradeRequestHandlerBox.value, promise: nil) } ) + let configBox = NIOLoopBound(config, eventLoop: channel.eventLoop) if proxy == nil || scheme == "ws" { if scheme == "wss" { @@ -163,15 +167,15 @@ public final class WebSocketClient { leftOverBytesStrategy: .forwardBytes, withClientUpgrade: config ).flatMap { - channel.pipeline.addHandler(httpUpgradeRequestHandler) + channel.pipeline.addHandler(httpUpgradeRequestHandlerBox.value) } } // TLS + proxy // we need to handle connecting with an additional CONNECT request let proxyEstablishedPromise = channel.eventLoop.makePromise(of: Void.self) - let encoder = HTTPRequestEncoder() - let decoder = ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .dropBytes)) + let encoder = NIOLoopBound(HTTPRequestEncoder(), eventLoop: channel.eventLoop) + let decoder = NIOLoopBound(ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .dropBytes)), eventLoop: channel.eventLoop) var connectHeaders = proxyHeaders connectHeaders.add(name: "Host", value: host) @@ -188,17 +192,17 @@ public final class WebSocketClient { // They are then removed upon completion only to be re-added in `addHTTPClientHandlers`. // This is done because the HTTP decoder is not valid after an upgrade, the CONNECT request being counted as one. do { - try channel.pipeline.syncOperations.addHandler(encoder) - try channel.pipeline.syncOperations.addHandler(decoder) + try channel.pipeline.syncOperations.addHandler(encoder.value) + try channel.pipeline.syncOperations.addHandler(decoder.value) try channel.pipeline.syncOperations.addHandler(proxyRequestHandler) } catch { return channel.eventLoop.makeFailedFuture(error) } proxyEstablishedPromise.futureResult.flatMap { - channel.pipeline.removeHandler(decoder) + channel.pipeline.removeHandler(decoder.value) }.flatMap { - channel.pipeline.removeHandler(encoder) + channel.pipeline.removeHandler(encoder.value) }.whenComplete { result in switch result { case .success: @@ -209,9 +213,9 @@ public final class WebSocketClient { try channel.pipeline.syncOperations.addHandler(tlsHandler) try channel.pipeline.syncOperations.addHTTPClientHandlers( leftOverBytesStrategy: .forwardBytes, - withClientUpgrade: config + withClientUpgrade: configBox.value ) - try channel.pipeline.syncOperations.addHandler(httpUpgradeRequestHandler) + try channel.pipeline.syncOperations.addHandler(httpUpgradeRequestHandlerBox.value) } catch { channel.pipeline.close(mode: .all, promise: nil) } @@ -230,6 +234,7 @@ public final class WebSocketClient { } } + @Sendable private func makeTLSHandler(tlsConfiguration: TLSConfiguration?, host: String) throws -> NIOSSLClientHandler { let context = try NIOSSLContext( configuration: self.configuration.tlsConfiguration ?? .makeClientConfiguration() diff --git a/Sources/WebSocketKit/WebSocketHandler.swift b/Sources/WebSocketKit/WebSocketHandler.swift index 45f266ce..6e333dc3 100644 --- a/Sources/WebSocketKit/WebSocketHandler.swift +++ b/Sources/WebSocketKit/WebSocketHandler.swift @@ -4,7 +4,7 @@ import NIOWebSocket extension WebSocket { /// Stores configuration for a WebSocket client/server instance - public struct Configuration { + public struct Configuration: Sendable { /// Defends against small payloads in frame aggregation. /// See `NIOWebSocketFrameAggregator` for details. public var minNonFinalFragmentSize: Int @@ -33,9 +33,10 @@ extension WebSocket { /// - channel: NIO channel which the client will use to communicate. /// - onUpgrade: An escaping closure to be executed the channel is configured with the WebSocket handlers. /// - Returns: An future which completes when the WebSocket connection to the server is established. + @preconcurrency public static func client( on channel: Channel, - onUpgrade: @escaping (WebSocket) -> () + onUpgrade: @Sendable @escaping (WebSocket) -> () ) -> EventLoopFuture { return self.configure(on: channel, as: .client, with: Configuration(), onUpgrade: onUpgrade) } @@ -46,10 +47,11 @@ extension WebSocket { /// - config: Configuration for the client channel handlers. /// - onUpgrade: An escaping closure to be executed the channel is configured with the WebSocket handlers. /// - Returns: An future which completes when the WebSocket connection to the server is established. + @preconcurrency public static func client( on channel: Channel, config: Configuration, - onUpgrade: @escaping (WebSocket) -> () + onUpgrade: @Sendable @escaping (WebSocket) -> () ) -> EventLoopFuture { return self.configure(on: channel, as: .client, with: config, onUpgrade: onUpgrade) } @@ -59,9 +61,10 @@ extension WebSocket { /// - channel: NIO channel which the server will use to communicate. /// - onUpgrade: An escaping closure to be executed the channel is configured with the WebSocket handlers. /// - Returns: An future which completes when the WebSocket connection to the server is established. + @preconcurrency public static func server( on channel: Channel, - onUpgrade: @escaping (WebSocket) -> () + onUpgrade: @Sendable @escaping (WebSocket) -> () ) -> EventLoopFuture { return self.configure(on: channel, as: .server, with: Configuration(), onUpgrade: onUpgrade) } @@ -72,10 +75,11 @@ extension WebSocket { /// - config: Configuration for the server channel handlers. /// - onUpgrade: An escaping closure to be executed the channel is configured with the WebSocket handlers. /// - Returns: An future which completes when the WebSocket connection to the server is established. + @preconcurrency public static func server( on channel: Channel, config: Configuration, - onUpgrade: @escaping (WebSocket) -> () + onUpgrade: @Sendable @escaping (WebSocket) -> () ) -> EventLoopFuture { return self.configure(on: channel, as: .server, with: config, onUpgrade: onUpgrade) } @@ -84,7 +88,7 @@ extension WebSocket { on channel: Channel, as type: PeerType, with config: Configuration, - onUpgrade: @escaping (WebSocket) -> () + onUpgrade: @Sendable @escaping (WebSocket) -> () ) -> EventLoopFuture { let webSocket = WebSocket(channel: channel, type: type) diff --git a/Tests/WebSocketKitTests/AsyncWebSocketKitTests.swift b/Tests/WebSocketKitTests/AsyncWebSocketKitTests.swift index e20ffa93..e662b2c6 100644 --- a/Tests/WebSocketKitTests/AsyncWebSocketKitTests.swift +++ b/Tests/WebSocketKitTests/AsyncWebSocketKitTests.swift @@ -5,6 +5,12 @@ import NIOWebSocket @testable import WebSocketKit final class AsyncWebSocketKitTests: XCTestCase { + + override func setUp() async throws { + // Handy for catching hangs in the tests. See https://github.com/apple/swift-corelibs-xctest/issues/422#issuecomment-1310952437 + fflush(stdout) + } + func testWebSocketEcho() async throws { let server = try await ServerBootstrap.webSocket(on: self.elg) { req, ws in ws.onText { ws, text in @@ -21,15 +27,15 @@ final class AsyncWebSocketKitTests: XCTestCase { try await WebSocket.connect(to: "ws://localhost:\(port)", on: elg) { ws in do { - try await ws.send("hello") ws.onText { ws, string in - promise.succeed(string) do { try await ws.close() } catch { XCTFail("Failed to close websocket, error: \(error)") } + promise.succeed(string) } + try await ws.send("hello") } catch { promise.fail(error) } @@ -39,23 +45,6 @@ final class AsyncWebSocketKitTests: XCTestCase { XCTAssertEqual(result, "hello") try await server.close(mode: .all) } - - func testAlternateWebsocketConnectMethods() async throws { - let server = try await ServerBootstrap.webSocket(on: self.elg) { $1.onText { $0.send($1) } }.bind(host: "localhost", port: 0).get() - let promise = self.elg.any().makePromise(of: Void.self) - guard let port = server.localAddress?.port else { - return XCTFail("couldn't get port from \(String(reflecting: server.localAddress))") - } - try await WebSocket.connect(scheme: "ws", host: "localhost", port: port, on: self.elg) { (ws) async in - do { try await ws.send("hello") } catch { promise.fail(error); try? await ws.close() } - ws.onText { ws, _ in - promise.succeed(()) - do { try await ws.close() } catch { XCTFail("Failed to close websocket: \(String(reflecting: error))") } - } - } - try await promise.futureResult.get() - try await server.close(mode: .all) - } func testBadURLInWebsocketConnect() async throws { do { @@ -77,11 +66,17 @@ final class AsyncWebSocketKitTests: XCTestCase { return XCTFail("couldn't get port from \(String(reflecting: server.localAddress))") } try await WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { ws in - do { try await ws.send([0x01]) } catch { promise.fail(error); try? await ws.close() } ws.onBinary { ws, buf in - promise.succeed(.init(buf.readableBytesView)) do { try await ws.close() } catch { XCTFail("Failed to close websocket: \(String(reflecting: error))") } + promise.succeed(.init(buf.readableBytesView)) + } + + do { + try await ws.send([0x01]) + } catch { + try? await ws.close() + promise.fail(error); } } let result = try await promise.futureResult.get() @@ -96,10 +91,19 @@ final class AsyncWebSocketKitTests: XCTestCase { return XCTFail("couldn't get port from \(String(reflecting: server.localAddress))") } try await WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { (ws) async in - do { try await ws.sendPing() } catch { promise.fail(error); try? await ws.close() } ws.onPong { + do { + try await $0.close() + } catch { + XCTFail("Failed to close websocket: \(String(reflecting: error))") + } promise.succeed(()) - do { try await $0.close() } catch { XCTFail("Failed to close websocket: \(String(reflecting: error))") } + } + do { + try await ws.sendPing() + } catch { + try? await ws.close() + promise.fail(error) } } try await promise.futureResult.get() @@ -115,8 +119,8 @@ final class AsyncWebSocketKitTests: XCTestCase { try await WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { (ws) async in ws.pingInterval = .milliseconds(100) ws.onPong { - promise.succeed(()) do { try await $0.close() } catch { XCTFail("Failed to close websocket: \(String(reflecting: error))") } + promise.succeed(()) } } try await promise.futureResult.get() diff --git a/Tests/WebSocketKitTests/SSLTestHelpers.swift b/Tests/WebSocketKitTests/SSLTestHelpers.swift index d515fe3e..a6776aab 100644 --- a/Tests/WebSocketKitTests/SSLTestHelpers.swift +++ b/Tests/WebSocketKitTests/SSLTestHelpers.swift @@ -18,6 +18,7 @@ import Foundation import NIOCore @testable import NIOSSL + // This function generates a random number suitable for use in an X509 // serial field. This needs to be a positive number less than 2^159 // (such that it will fit into 20 ASN.1 bytes). diff --git a/Tests/WebSocketKitTests/WebSocketKitTests.swift b/Tests/WebSocketKitTests/WebSocketKitTests.swift index 985cb00b..af3d3cf7 100644 --- a/Tests/WebSocketKitTests/WebSocketKitTests.swift +++ b/Tests/WebSocketKitTests/WebSocketKitTests.swift @@ -8,6 +8,10 @@ import NIOWebSocket @testable import WebSocketKit final class WebSocketKitTests: XCTestCase { + override func setUp() async throws { + fflush(stdout) + } + func testWebSocketEcho() throws { let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in ws.onText { ws, text in @@ -23,11 +27,11 @@ final class WebSocketKitTests: XCTestCase { let promise = elg.any().makePromise(of: String.self) let closePromise = elg.any().makePromise(of: Void.self) WebSocket.connect(to: "ws://localhost:\(port)", on: elg) { ws in - ws.send("hello") ws.onText { ws, string in - promise.succeed(string) ws.close(promise: closePromise) + promise.succeed(string) } + ws.send("hello") }.cascadeFailure(to: promise) try XCTAssertEqual(promise.futureResult.wait(), "hello") XCTAssertNoThrow(try closePromise.futureResult.wait()) @@ -56,8 +60,8 @@ final class WebSocketKitTests: XCTestCase { } WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { ws in - ws.send("close", promise: sendPromise) ws.onClose.cascade(to: clientClose) + ws.send("close", promise: sendPromise) }.cascadeFailure(to: sendPromise) XCTAssertNoThrow(try sendPromise.futureResult.wait()) @@ -83,12 +87,12 @@ final class WebSocketKitTests: XCTestCase { } WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { ws in - ws.send("close", promise: sendPromise) ws.onText { ws, text in if text == "close" { ws.close(promise: clientClose) } } + ws.send("close", promise: sendPromise) }.cascadeFailure(to: sendPromise) XCTAssertNoThrow(try sendPromise.futureResult.wait()) @@ -100,11 +104,11 @@ final class WebSocketKitTests: XCTestCase { func testImmediateSend() throws { let promise = self.elg.any().makePromise(of: String.self) let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in - ws.send("hello") ws.onText { ws, string in promise.succeed(string) ws.close(promise: nil) } + ws.send("hello") }.bind(host: "localhost", port: 0).wait() guard let port = server.localAddress?.port else { @@ -140,11 +144,11 @@ final class WebSocketKitTests: XCTestCase { } WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { ws in - ws.send(raw: pingPongData.readableBytesView, opcode: .ping) ws.onPong { ws in pongPromise.succeed("pong") ws.close(promise: nil) } + ws.send(raw: pingPongData.readableBytesView, opcode: .ping) }.cascadeFailure(to: pongPromise) try XCTAssertEqual(pingPromise.futureResult.wait(), "ping") @@ -174,13 +178,13 @@ final class WebSocketKitTests: XCTestCase { let promise = elg.any().makePromise(of: String.self) let closePromise = elg.any().makePromise(of: Void.self) WebSocket.connect(to: "ws://localhost:\(port)", on: elg) { ws in - ws.send(.init(string: "Hel"), opcode: .text, fin: false) - ws.send(.init(string: "lo! Vapor r"), opcode: .continuation, fin: false) - ws.send(.init(string: "ules"), opcode: .continuation, fin: true) ws.onText { ws, string in - promise.succeed(string) ws.close(promise: closePromise) + promise.succeed(string) } + ws.send(.init(string: "Hel"), opcode: .text, fin: false) + ws.send(.init(string: "lo! Vapor r"), opcode: .continuation, fin: false) + ws.send(.init(string: "ules"), opcode: .continuation, fin: true) }.cascadeFailure(to: promise) try XCTAssertEqual(promise.futureResult.wait(), "Hello! Vapor rules the most") XCTAssertNoThrow(try closePromise.futureResult.wait()) @@ -204,8 +208,8 @@ final class WebSocketKitTests: XCTestCase { ws.send("goodbye") } ws.onClose.whenSuccess { - promise.succeed(ws.closeCode!) XCTAssertEqual(ws.closeCode, WebSocketErrorCode.normalClosure) + promise.succeed(ws.closeCode!) } }.cascadeFailure(to: promise) @@ -228,9 +232,8 @@ final class WebSocketKitTests: XCTestCase { headers.contains(name: "Content-Length") || headers.contains(name: "Content-Type") ) - promiseHasUnwantedHeaders.succeed(hasUnwantedHeaders) - ws.close(promise: nil) + promiseHasUnwantedHeaders.succeed(hasUnwantedHeaders) }.bind(host: "localhost", port: 0).wait() guard let port = server.localAddress?.port else { @@ -254,8 +257,8 @@ final class WebSocketKitTests: XCTestCase { let promise = self.elg.any().makePromise(of: String.self) let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in - promise.succeed(req.uri) ws.close(promise: nil) + promise.succeed(req.uri) }.bind(host: "localhost", port: 0).wait() guard let port = server.localAddress?.port else { @@ -281,8 +284,6 @@ final class WebSocketKitTests: XCTestCase { let shutdownPromise = self.elg.any().makePromise(of: Void.self) let server = try! ServerBootstrap.webSocket(on: self.elg) { req, ws in - ws.send("welcome!") - ws.onClose.whenComplete { print("ws.onClose done: \($0)") } @@ -299,6 +300,8 @@ final class WebSocketKitTests: XCTestCase { ws.send(text.reversed()) } } + + ws.send("welcome!") }.bind(host: "localhost", port: port).wait() print("Serving at ws://localhost:\(port)") @@ -379,8 +382,8 @@ final class WebSocketKitTests: XCTestCase { ) { ws in ws.send("hello") ws.onText { ws, string in - promise.succeed(string) ws.close(promise: closePromise) + promise.succeed(string) } }.cascadeFailure(to: promise) @@ -442,11 +445,11 @@ final class WebSocketKitTests: XCTestCase { proxyPort: localWebsocketBin.port, proxyHeaders: HTTPHeaders([("proxy-authorization", "token amFwcGxlc2VlZDpwYXNzMTIz")]) ) { ws in - ws.send("hello") ws.onText { ws, string in - promise.succeed(string) ws.close(promise: closePromise) + promise.succeed(string) } + ws.send("hello") }.cascadeFailure(to: promise) XCTAssertEqual(try promise.futureResult.wait(), "hello") @@ -488,11 +491,11 @@ final class WebSocketKitTests: XCTestCase { return XCTFail("couldn't get port from \(String(reflecting: server.localAddress))") } WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { ws in - ws.send([0x01]) ws.onBinary { ws, buf in - promise.succeed(.init(buf.readableBytesView)) ws.close(promise: closePromise) + promise.succeed(.init(buf.readableBytesView)) } + ws.send([0x01]) }.whenFailure { promise.fail($0) closePromise.fail($0) @@ -510,11 +513,11 @@ final class WebSocketKitTests: XCTestCase { return XCTFail("couldn't get port from \(String(reflecting: server.localAddress))") } WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { ws in - ws.sendPing() ws.onPong { - promise.succeed() $0.close(promise: closePromise) + promise.succeed() } + ws.sendPing() }.cascadeFailure(to: closePromise) XCTAssertNoThrow(try promise.futureResult.wait()) XCTAssertNoThrow(try closePromise.futureResult.wait()) @@ -531,8 +534,8 @@ final class WebSocketKitTests: XCTestCase { WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { ws in ws.pingInterval = .milliseconds(100) ws.onPong { - promise.succeed() $0.close(promise: closePromise) + promise.succeed() } }.cascadeFailure(to: closePromise) XCTAssertNoThrow(try promise.futureResult.wait())