diff --git a/Sources/NIOSOCKS/Channel Handlers/SOCKSServerHandshakeHandler.swift b/Sources/NIOSOCKS/Channel Handlers/SOCKSServerHandshakeHandler.swift index 03afe783..4d1b18cd 100644 --- a/Sources/NIOSOCKS/Channel Handlers/SOCKSServerHandshakeHandler.swift +++ b/Sources/NIOSOCKS/Channel Handlers/SOCKSServerHandshakeHandler.swift @@ -71,14 +71,17 @@ public final class SOCKSServerHandshakeHandler: ChannelDuplexHandler, RemovableC public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { do { let message = self.unwrapOutboundIn(data) + let outboundBuffer: ByteBuffer switch message { case .selectedAuthenticationMethod(let method): - try self.handleWriteSelectedAuthenticationMethod(method, context: context, promise: promise) + outboundBuffer = try self.handleWriteSelectedAuthenticationMethod(method, context: context) case .response(let response): - try self.handleWriteResponse(response, context: context, promise: promise) + outboundBuffer = try self.handleWriteResponse(response, context: context) case .authenticationData(let data, let complete): - try self.handleWriteAuthenticationData(data, complete: complete, context: context, promise: promise) + outboundBuffer = try self.handleWriteAuthenticationData(data, complete: complete, context: context) } + context.write(self.wrapOutboundOut(outboundBuffer), promise: promise) + } catch { context.fireErrorCaught(error) promise?.fail(error) @@ -86,31 +89,25 @@ public final class SOCKSServerHandshakeHandler: ChannelDuplexHandler, RemovableC } private func handleWriteSelectedAuthenticationMethod( - _ method: SelectedAuthenticationMethod, context: ChannelHandlerContext, promise: EventLoopPromise?) throws { + _ method: SelectedAuthenticationMethod, context: ChannelHandlerContext) throws -> ByteBuffer { try stateMachine.sendAuthenticationMethod(method) var buffer = context.channel.allocator.buffer(capacity: 16) buffer.writeMethodSelection(method) - context.write(self.wrapOutboundOut(buffer), promise: promise) + return buffer } private func handleWriteResponse( - _ response: SOCKSResponse, context: ChannelHandlerContext, promise: EventLoopPromise?) throws { + _ response: SOCKSResponse, context: ChannelHandlerContext) throws -> ByteBuffer { try stateMachine.sendServerResponse(response) var buffer = context.channel.allocator.buffer(capacity: 16) buffer.writeServerResponse(response) - context.write(self.wrapOutboundOut(buffer), promise: promise) + return buffer } - private func handleWriteAuthenticationData(_ data: ByteBuffer, complete: Bool, context: ChannelHandlerContext, promise: EventLoopPromise?) throws { - do { - try self.stateMachine.sendData() - if complete { - try self.stateMachine.authenticationComplete() - } - context.write(self.wrapOutboundOut(data), promise: promise) - } catch { - promise?.fail(error) - } + private func handleWriteAuthenticationData( + _ data: ByteBuffer, complete: Bool, context: ChannelHandlerContext) throws -> ByteBuffer { + try self.stateMachine.sendAuthenticationData(data, complete: complete) + return data } } diff --git a/Sources/NIOSOCKS/State/ServerStateMachine.swift b/Sources/NIOSOCKS/State/ServerStateMachine.swift index acb51f9d..44a76d47 100644 --- a/Sources/NIOSOCKS/State/ServerStateMachine.swift +++ b/Sources/NIOSOCKS/State/ServerStateMachine.swift @@ -28,6 +28,7 @@ enum ServerState: Hashable { struct ServerStateMachine: Hashable { private var state: ServerState + private var authenticationMethod: AuthenticationMethod? var proxyEstablished: Bool { switch self.state { @@ -118,7 +119,7 @@ extension ServerStateMachine { self.state = .waitingForClientGreeting } - mutating func sendAuthenticationMethod(_ method: SelectedAuthenticationMethod) throws { + mutating func sendAuthenticationMethod(_ selected: SelectedAuthenticationMethod) throws { switch self.state { case .waitingToSendAuthenticationMethod: () @@ -131,7 +132,13 @@ extension ServerStateMachine { .error: throw SOCKSError.InvalidServerState() } - self.state = .authenticating + + self.authenticationMethod = selected.method + if selected.method == .noneRequired { + self.state = .waitingForClientRequest + } else { + self.state = .authenticating + } } mutating func sendServerResponse(_ response: SOCKSResponse) throws { @@ -155,35 +162,25 @@ extension ServerStateMachine { } } - mutating func sendData() throws { + mutating func sendAuthenticationData(_ data: ByteBuffer, complete: Bool) throws { switch self.state { case .authenticating: - () - case .inactive, - .waitingForClientGreeting, - .waitingToSendAuthenticationMethod, - .waitingForClientRequest, - .waitingToSendResponse, - .active, - .error: - throw SOCKSError.InvalidServerState() - } - } - - mutating func authenticationComplete() throws { - switch self.state { - case .authenticating: - () + break + case .waitingForClientRequest: + guard self.authenticationMethod == .noneRequired, complete, data.readableBytes == 0 else { + throw SOCKSError.InvalidServerState() + } case .inactive, .waitingForClientGreeting, .waitingToSendAuthenticationMethod, - .waitingForClientRequest, .waitingToSendResponse, .active, .error: throw SOCKSError.InvalidServerState() } - self.state = .waitingForClientRequest + if complete { + self.state = .waitingForClientRequest + } } } diff --git a/Tests/NIOSOCKSTests/SOCKSServerHandshakeHandler+Tests+XCTest.swift b/Tests/NIOSOCKSTests/SOCKSServerHandshakeHandler+Tests+XCTest.swift index f6e10eab..3c6d3583 100644 --- a/Tests/NIOSOCKSTests/SOCKSServerHandshakeHandler+Tests+XCTest.swift +++ b/Tests/NIOSOCKSTests/SOCKSServerHandshakeHandler+Tests+XCTest.swift @@ -32,6 +32,11 @@ extension SOCKSServerHandlerTests { ("testOutboundErrorsAreHandled", testOutboundErrorsAreHandled), ("testFlushOnHandlerRemoved", testFlushOnHandlerRemoved), ("testForceHandlerRemovalAfterAuth", testForceHandlerRemovalAfterAuth), + ("testAutoAuthenticationComplete", testAutoAuthenticationComplete), + ("testAutoAuthenticationCompleteWithManualCompletion", testAutoAuthenticationCompleteWithManualCompletion), + ("testEagerClientRequestBeforeAuthenticationComplete", testEagerClientRequestBeforeAuthenticationComplete), + ("testManualAuthenticationFailureExtraBytes", testManualAuthenticationFailureExtraBytes), + ("testManualAuthenticationFailureInvalidCompletion", testManualAuthenticationFailureInvalidCompletion), ] } } diff --git a/Tests/NIOSOCKSTests/SOCKSServerHandshakeHandler+Tests.swift b/Tests/NIOSOCKSTests/SOCKSServerHandshakeHandler+Tests.swift index c7ae79ad..e993ded6 100644 --- a/Tests/NIOSOCKSTests/SOCKSServerHandshakeHandler+Tests.swift +++ b/Tests/NIOSOCKSTests/SOCKSServerHandshakeHandler+Tests.swift @@ -152,7 +152,7 @@ class SOCKSServerHandlerTests: XCTestCase { // tests dripfeeding to ensure we buffer data correctly func testTypicalWorkflowDripfeed() { - let expectedGreeting = ClientGreeting(methods: [.noneRequired]) + let expectedGreeting = ClientGreeting(methods: [.gssapi]) let expectedRequest = SOCKSRequest(command: .connect, addressType: .address(try! .init(ipAddress: "127.0.0.1", port: 80))) let expectedData = ByteBuffer(string: "1234") let testHandler = PromiseTestHandler( @@ -168,16 +168,15 @@ class SOCKSServerHandlerTests: XCTestCase { self.assertOutputBuffer([]) self.writeInbound([0x01]) self.assertOutputBuffer([]) - self.writeInbound([0x00]) + self.writeInbound([0x01]) self.assertOutputBuffer([]) XCTAssertTrue(testHandler.hadGreeting) // write the auth selection - XCTAssertNoThrow(try self.channel.writeOutbound(ServerMessage.selectedAuthenticationMethod(.init(method: .noneRequired)))) - self.assertOutputBuffer([0x05, 0x00]) + XCTAssertNoThrow(try self.channel.writeOutbound(ServerMessage.selectedAuthenticationMethod(.init(method: .gssapi)))) + self.assertOutputBuffer([0x05, 0x01]) - // finish authentication - nothing should be written - // as this is informing the state machine only + // finish authentication with some bytes XCTAssertNoThrow(try self.channel.writeOutbound(ServerMessage.authenticationData(ByteBuffer(bytes: [0xFF, 0xFF]), complete: true))) self.assertOutputBuffer([0xFF, 0xFF]) @@ -217,10 +216,11 @@ class SOCKSServerHandlerTests: XCTestCase { func testForceHandlerRemovalAfterAuth() { // go through auth - self.writeInbound([0x05, 0x01, 0x00]) - self.writeOutbound(.selectedAuthenticationMethod(.init(method: .noneRequired))) - self.assertOutputBuffer([0x05, 0x00]) - XCTAssertNoThrow(try self.handler.stateMachine.authenticationComplete()) + self.writeInbound([0x05, 0x01, 0x01]) + self.writeOutbound(.selectedAuthenticationMethod(.init(method: .gssapi))) + self.assertOutputBuffer([0x05, 0x01]) + self.writeOutbound(.authenticationData(ByteBuffer(), complete: true)) + self.assertOutputBuffer([]) self.writeInbound([0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80]) self.writeOutbound(.response(.init(reply: .succeeded, boundAddress: .address(try! .init(ipAddress: "127.0.0.1", port: 80))))) self.assertOutputBuffer([0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0, 80]) @@ -229,4 +229,88 @@ class SOCKSServerHandlerTests: XCTestCase { // removing the handler, it should fail XCTAssertThrowsError(try self.channel.writeOutbound(ServerMessage.authenticationData(ByteBuffer(string: "hello, world!"), complete: false))) } + + func testAutoAuthenticationComplete() { + + // server selects none-required, this should mean we can continue without + // having to manually inform the state machine + self.writeInbound([0x05, 0x01, 0x00]) + self.writeOutbound(.selectedAuthenticationMethod(.init(method: .noneRequired))) + self.assertOutputBuffer([0x05, 0x00]) + + // if we try and write the request then the data would be read + // as authentication data, and so the server wouldn't reply + // with a response + self.writeInbound([0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80]) + self.writeOutbound(.response(.init(reply: .succeeded, boundAddress: .address(try! .init(ipAddress: "127.0.0.1", port: 80))))) + self.assertOutputBuffer([0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0, 80]) + } + + func testAutoAuthenticationCompleteWithManualCompletion() { + + // server selects none-required, this should mean we can continue without + // having to manually inform the state machine. However, informing the state + // machine manually shouldn't break anything. + self.writeInbound([0x05, 0x01, 0x00]) + self.writeOutbound(.selectedAuthenticationMethod(.init(method: .noneRequired))) + self.assertOutputBuffer([0x05, 0x00]) + + // complete authentication, but nothing should be written + // to the network + self.writeOutbound(.authenticationData(ByteBuffer(), complete: true)) + self.assertOutputBuffer([]) + + // if we try and write the request then the data would be read + // as authentication data, and so the server wouldn't reply + // with a response + self.writeInbound([0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80]) + self.writeOutbound(.response(.init(reply: .succeeded, boundAddress: .address(try! .init(ipAddress: "127.0.0.1", port: 80))))) + self.assertOutputBuffer([0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0, 80]) + } + + func testEagerClientRequestBeforeAuthenticationComplete() { + + // server selects none-required, this should mean we can continue without + // having to manually inform the state machine. However, informing the state + // machine manually shouldn't break anything. + self.writeInbound([0x05, 0x01, 0x01]) + self.assertInbound(.greeting(.init(methods: [.gssapi]))) + self.writeOutbound(.selectedAuthenticationMethod(.init(method: .gssapi))) + self.assertOutputBuffer([0x05, 0x01]) + + // at this point authentication isn't complete + // so if the client sends a request then the + // server will read those as authentication bytes + self.writeInbound([0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80]) + self.assertInbound(.authenticationData(ByteBuffer(bytes: [0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80]))) + } + + func testManualAuthenticationFailureExtraBytes() { + // server selects none-required, this should mean we can continue without + // having to manually inform the state machine. However, informing the state + // machine manually shouldn't break anything. + self.writeInbound([0x05, 0x01, 0x00]) + self.writeOutbound(.selectedAuthenticationMethod(.init(method: .noneRequired))) + self.assertOutputBuffer([0x05, 0x00]) + + // invalid authentication completion + // we've selected `noneRequired`, so no + // bytes should be written + XCTAssertThrowsError(try self.channel.writeOutbound(ServerMessage.authenticationData(ByteBuffer(bytes: [0x00]), complete: true))) + } + + func testManualAuthenticationFailureInvalidCompletion() { + // server selects none-required, this should mean we can continue without + // having to manually inform the state machine. However, informing the state + // machine manually shouldn't break anything. + self.writeInbound([0x05, 0x01, 0x00]) + self.writeOutbound(.selectedAuthenticationMethod(.init(method: .noneRequired))) + self.assertOutputBuffer([0x05, 0x00]) + + // invalid authentication completion + // authentication should have already completed + // as we selected `noneRequired`, so sending + // `complete = false` should be an error + XCTAssertThrowsError(try self.channel.writeOutbound(ServerMessage.authenticationData(ByteBuffer(bytes: []), complete: false))) + } } diff --git a/Tests/NIOSOCKSTests/ServerStateMachine+Tests.swift b/Tests/NIOSOCKSTests/ServerStateMachine+Tests.swift index ff03d8f6..e28edaed 100644 --- a/Tests/NIOSOCKSTests/ServerStateMachine+Tests.swift +++ b/Tests/NIOSOCKSTests/ServerStateMachine+Tests.swift @@ -34,9 +34,6 @@ public class ServerStateMachineTests: XCTestCase { XCTAssertNoThrow(try stateMachine.sendAuthenticationMethod(.init(method: .noneRequired))) XCTAssertFalse(stateMachine.proxyEstablished) - // authentication is now finished, as we didn't send any - XCTAssertNoThrow(try stateMachine.authenticationComplete()) - // send the client request var request = ByteBuffer(bytes: [0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80]) XCTAssertNoThrow(try stateMachine.receiveBuffer(&request)) @@ -61,7 +58,6 @@ public class ServerStateMachineTests: XCTestCase { XCTAssertNoThrow(try stateMachine.connectionEstablished()) XCTAssertNoThrow(try stateMachine.receiveBuffer(&greeting)) XCTAssertNoThrow(try stateMachine.sendAuthenticationMethod(.init(method: .noneRequired))) - XCTAssertNoThrow(try stateMachine.authenticationComplete()) // write some invalid bytes from the client // the state machine should throw