Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix WebSocketProtocolErrorHandler sending the close frame with appropriate masking key #3040

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
4 changes: 3 additions & 1 deletion Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,9 @@ private func _upgrade<UpgradeResult>(
ByteToMessageHandler(WebSocketFrameDecoder(maxFrameSize: maxFrameSize))
)
if enableAutomaticErrorHandling {
try channel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())
let errorHandler = WebSocketProtocolErrorHandler()
errorHandler.setIsServer(false)
try channel.pipeline.syncOperations.addHandler(errorHandler)
}
}
.flatMap {
Expand Down
5 changes: 1 addition & 4 deletions Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,8 @@ extension HTTPHeaders {
///
/// This upgrader assumes that the `HTTPServerUpgradeHandler` will appropriately mutate the pipeline to
/// remove the HTTP `ChannelHandler`s.
public final class NIOWebSocketServerUpgrader: HTTPServerProtocolUpgrader, @unchecked Sendable {
// This type *is* Sendable but we can't express that properly until Swift 5.7. In the meantime
// the conformance is `@unchecked`.
public final class NIOWebSocketServerUpgrader: HTTPServerProtocolUpgrader, Sendable {

// FIXME: remove @unchecked when 5.7 is the minimum supported version.
private typealias ShouldUpgrade = @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture<HTTPHeaders?>
private typealias UpgradePipelineHandler = @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture<Void>
/// RFC 6455 specs this as the required entry in the Upgrade header.
Expand Down
22 changes: 21 additions & 1 deletion Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@ public final class WebSocketProtocolErrorHandler: ChannelInboundHandler {
public typealias InboundIn = Never
public typealias OutboundOut = WebSocketFrame

public init() {}
/// Indicate that this `ChannelHandeler` is used by a WebSocket server or client. Default is true.
public private(set) var isServer: Bool
johnnzhou marked this conversation as resolved.
Show resolved Hide resolved

public init() {
self.isServer = true
}

public func errorCaught(context: ChannelHandlerContext, error: Error) {
let loopBoundContext = context.loopBound
Expand All @@ -32,6 +37,7 @@ public final class WebSocketProtocolErrorHandler: ChannelInboundHandler {
let frame = WebSocketFrame(
fin: true,
opcode: .connectionClose,
maskKey: self.makeMaskingKey(),
data: data
)
context.writeAndFlush(Self.wrapOutboundOut(frame)).whenComplete { (_: Result<Void, Error>) in
Expand All @@ -44,6 +50,20 @@ public final class WebSocketProtocolErrorHandler: ChannelInboundHandler {
// forward the error on to let others see it.
context.fireErrorCaught(error)
}

private func makeMaskingKey() -> WebSocketMaskingKey? {
// According to RFC 6455 Section 5, a client *must* mask all frames that it sends to the server.
// A server *must not* mask any frames that it sends to the client
self.isServer ? nil : .random()
}

/// Configure this `ChannelHandler` to be used by a WebSocket server or client.
///
/// - Parameters:
/// - isServer: indicate whether this `ChannelHandler` is used by a WebSocket server or client.
public func setIsServer(_ isServer: Bool) {
self.isServer = isServer
}
}

@available(*, unavailable)
Expand Down
28 changes: 28 additions & 0 deletions Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,34 @@ class WebSocketClientEndToEndTests: XCTestCase {
// Close the pipeline.
XCTAssertNoThrow(try clientChannel.close().wait())
}

func testErrorHandlerMaskFrameForClient() throws {

let (clientChannel, _) = try self.runSuccessfulUpgrade()
let maskBitMask: UInt8 = 0x80

var data = clientChannel.allocator.buffer(capacity: 4)
// A fake frame header that claims that the length of the frame is 16385 bytes,
// larger than the frame max.
data.writeBytes([0x81, 0xFE, 0x40, 0x01])

XCTAssertThrowsError(try clientChannel.writeInbound(data)) { error in
XCTAssertEqual(.invalidFrameLength, error as? NIOWebSocketError)
}

clientChannel.embeddedEventLoop.run()
var buffer = try clientChannel.readAllOutboundBuffers()

guard let (_, secondByte) = buffer.readMultipleIntegers(as: (UInt8, UInt8).self) else {
XCTFail("Insufficient bytes from WebSocket frame")
return
}

let maskedBit = (secondByte & maskBitMask)
XCTAssertEqual(0x80, maskedBit)

XCTAssertNoThrow(!clientChannel.isActive)
}
}

#if !canImport(Darwin) || swift(>=5.10)
Expand Down