From fc31edbdad98e99211cd78c361b8faec303dbc67 Mon Sep 17 00:00:00 2001 From: Rick Newton-Rogers Date: Fri, 25 Oct 2024 15:46:16 +0100 Subject: [PATCH 1/3] Migrate CI to use GitHub Actions. Motivation: To migrate to GitHub actions and centralised infrastructure. Modifications: Changes of note: * Adopt swift-format using rules from SwiftNIO * Remove scripts and docker files which are no longer needed Result: Feature parity with old CI. --- .github/workflows/pull_request.yml | 30 + .github/workflows/pull_request_label.yml | 18 + .github/workflows/scheduled.yml | 17 + .licenseignore | 40 + .swift-format | 62 ++ CONTRIBUTING.md | 5 +- Package.swift | 66 +- .../HTTPServerWithQuiescingDemo/main.swift | 38 +- .../NIOExtras/DebugInboundEventsHandler.swift | 10 +- .../DebugOutboundEventsHandler.swift | 14 +- .../NIOExtras/FixedLengthFrameDecoder.swift | 7 +- .../NIOExtras/HTTP1ProxyConnectHandler.swift | 25 +- .../JSONRPCFraming+ContentLengthHeader.swift | 25 +- .../LengthFieldBasedFrameDecoder.swift | 39 +- Sources/NIOExtras/LengthFieldPrepender.swift | 21 +- Sources/NIOExtras/LineBasedFrameDecoder.swift | 21 +- Sources/NIOExtras/NIOExtrasError.swift | 2 +- .../NIOExtras/NIOLengthFieldBitLength.swift | 4 +- Sources/NIOExtras/PCAPRingBuffer.swift | 6 +- Sources/NIOExtras/QuiescingHelper.swift | 5 +- .../NIOExtras/RequestResponseHandler.swift | 7 +- .../RequestResponseWithIDHandler.swift | 15 +- Sources/NIOExtras/WritePCAPHandler.swift | 416 ++++++---- .../Benchmark.swift | 2 +- .../HTTP1PCAPPerformanceTests.swift | 21 +- .../HTTP1PerformanceTestFramework.swift | 31 +- .../HTTP1RawPerformanceTests.swift | 10 +- .../HTTP1RollingPCAPPerformanceTests.swift | 22 +- .../Measurement.swift | 4 +- .../PCAPPerformanceTest.swift | 13 +- .../RollingPCAPPerformanceTest.swift | 15 +- Sources/NIOExtrasPerformanceTester/main.swift | 34 +- .../NIOHTTPCompression/HTTPCompression.swift | 41 +- .../HTTPDecompression.swift | 14 +- .../HTTPRequestCompressor.swift | 36 +- .../HTTPRequestDecompressor.swift | 11 +- .../HTTPResponseCompressor.swift | 50 +- .../HTTPResponseDecompressor.swift | 19 +- .../HTTPTypeConversion.swift | 21 +- .../HTTP2HeadersStateMachine.swift | 10 +- .../NIOHTTPTypesHTTP2/HTTP2ToHTTPCodec.swift | 46 +- .../HTTPTypeConversion.swift | 11 +- Sources/NIONFS3/MountTypes+Mount.swift | 24 +- Sources/NIONFS3/MountTypes+Null.swift | 4 +- Sources/NIONFS3/MountTypes+Unmount.swift | 4 +- Sources/NIONFS3/NFSCallDecoder.swift | 2 +- .../NIONFS3/NFSFileSystem+FuturesAPI.swift | 3 +- Sources/NIONFS3/NFSFileSystem.swift | 49 +- Sources/NIONFS3/NFSFileSystemHandler.swift | 56 +- Sources/NIONFS3/NFSFileSystemInvoker.swift | 2 +- .../NIONFS3/NFSFileSystemServerHandler.swift | 66 +- Sources/NIONFS3/NFSReplyDecoder.swift | 6 +- Sources/NIONFS3/NFSTypes+Access.swift | 43 +- Sources/NIONFS3/NFSTypes+Common.swift | 288 ++++--- Sources/NIONFS3/NFSTypes+Containers.swift | 152 ++-- Sources/NIONFS3/NFSTypes+FSInfo.swift | 77 +- Sources/NIONFS3/NFSTypes+FSStat.swift | 57 +- Sources/NIONFS3/NFSTypes+Getattr.swift | 9 +- Sources/NIONFS3/NFSTypes+Lookup.swift | 29 +- Sources/NIONFS3/NFSTypes+Null.swift | 4 +- Sources/NIONFS3/NFSTypes+PathConf.swift | 56 +- Sources/NIONFS3/NFSTypes+Read.swift | 29 +- Sources/NIONFS3/NFSTypes+ReadDir.swift | 88 +- Sources/NIONFS3/NFSTypes+ReadDirPlus.swift | 109 ++- Sources/NIONFS3/NFSTypes+Readlink.swift | 11 +- Sources/NIONFS3/NFSTypes+SetAttr.swift | 37 +- Sources/NIONFS3/RPCTypes.swift | 44 +- .../Channel Handlers/SOCKSClientHandler.swift | 64 +- .../SOCKSServerHandshakeHandler.swift | 39 +- .../Messages/AuthenticationMethod.swift | 14 +- .../NIOSOCKS/Messages/ClientGreeting.swift | 20 +- Sources/NIOSOCKS/Messages/Errors.swift | 18 +- Sources/NIOSOCKS/Messages/Helpers.swift | 12 +- Sources/NIOSOCKS/Messages/Messages.swift | 16 +- Sources/NIOSOCKS/Messages/SOCKSRequest.swift | 68 +- Sources/NIOSOCKS/Messages/SOCKSResponse.swift | 44 +- .../SelectedAuthenticationMethod.swift | 16 +- .../NIOSOCKS/State/ClientStateMachine.swift | 45 +- .../NIOSOCKS/State/ServerStateMachine.swift | 102 +-- Sources/NIOSOCKSClient/main.swift | 8 +- Sources/NIOWritePCAPDemo/main.swift | 37 +- Sources/NIOWritePartialPCAPDemo/main.swift | 59 +- .../DebugInboundEventsHandlerTest.swift | 35 +- .../DebugOutboundEventsHandlerTest.swift | 25 +- .../FixedLengthFrameDecoderTest.swift | 57 +- .../HTTP1ProxyConnectHandlerTests.swift | 21 +- ...amingContentLengthHeaderDecoderTests.swift | 37 +- ...amingContentLengthHeaderEncoderTests.swift | 51 +- .../LengthFieldBasedFrameDecoderTest.swift | 657 +++++++++------ .../LengthFieldPrependerTest.swift | 373 +++++---- .../LineBasedFrameDecoderTest.swift | 158 ++-- Tests/NIOExtrasTests/PCAPRingBufferTest.swift | 118 +-- .../NIOExtrasTests/QuiescingHelperTest.swift | 5 +- .../RequestResponseHandlerTest.swift | 3 +- .../RequestResponseWithIDHandlerTest.swift | 143 +++- .../SynchronizedFileSinkTests.swift | 44 +- .../NIOExtrasTests/WritePCAPHandlerTest.swift | 779 +++++++++++------- .../HTTPRequestCompressorTest.swift | 121 +-- .../HTTPRequestDecompressorTest.swift | 101 ++- .../HTTPResponseCompressorTest.swift | 281 ++++--- .../HTTPResponseDecompressorTest.swift | 173 ++-- .../NIOHTTPTypesHTTP1Tests.swift | 94 ++- .../NIOHTTPTypesHTTP2Tests.swift | 33 +- Tests/NIONFS3Tests/NFS3FileSystemTests.swift | 56 +- Tests/NIONFS3Tests/NFS3ReplyEncoderTest.swift | 106 ++- Tests/NIONFS3Tests/NFS3RoundtripTests.swift | 469 +++++++---- .../NIOSOCKSTests/ClientGreeting+Tests.swift | 16 +- Tests/NIOSOCKSTests/ClientRequest+Tests.swift | 58 +- .../ClientStateMachine+Tests.swift | 33 +- Tests/NIOSOCKSTests/Helpers+Tests.swift | 35 +- .../NIOSOCKSTests/MethodSelection+Tests.swift | 9 +- .../SOCKSServerHandshakeHandler+Tests.swift | 147 ++-- .../NIOSOCKSTests/ServerResponse+Tests.swift | 15 +- .../ServerStateMachine+Tests.swift | 27 +- .../SocksClientHandler+Tests.swift | 155 ++-- docker/Dockerfile | 25 - docker/docker-compose.2204.510.yaml | 21 - docker/docker-compose.2204.58.yaml | 21 - docker/docker-compose.2204.59.yaml | 21 - docker/docker-compose.2204.main.yaml | 20 - docker/docker-compose.yaml | 41 - scripts/check-docs.sh | 23 - scripts/check_no_api_breakages.sh | 130 --- scripts/soundness.sh | 142 ---- 124 files changed, 4670 insertions(+), 3324 deletions(-) create mode 100644 .github/workflows/pull_request.yml create mode 100644 .github/workflows/pull_request_label.yml create mode 100644 .github/workflows/scheduled.yml create mode 100644 .licenseignore create mode 100644 .swift-format delete mode 100644 docker/Dockerfile delete mode 100644 docker/docker-compose.2204.510.yaml delete mode 100644 docker/docker-compose.2204.58.yaml delete mode 100644 docker/docker-compose.2204.59.yaml delete mode 100644 docker/docker-compose.2204.main.yaml delete mode 100644 docker/docker-compose.yaml delete mode 100755 scripts/check-docs.sh delete mode 100755 scripts/check_no_api_breakages.sh delete mode 100755 scripts/soundness.sh diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml new file mode 100644 index 00000000..1335ed2a --- /dev/null +++ b/.github/workflows/pull_request.yml @@ -0,0 +1,30 @@ +name: PR + +on: + pull_request: + types: [opened, reopened, synchronize] + +jobs: + soundness: + name: Soundness + uses: swiftlang/github-workflows/.github/workflows/soundness.yml@main + with: + license_header_check_project_name: "SwiftNIO" + unit-tests: + name: Unit tests + uses: apple/swift-nio/.github/workflows/unit_tests.yml@main + with: + linux_5_9_arguments_override: "-Xswiftc -warnings-as-errors --explicit-target-dependency-import-check error" + linux_5_10_arguments_override: "-Xswiftc -warnings-as-errors --explicit-target-dependency-import-check error" + linux_6_0_arguments_override: "-Xswiftc -warnings-as-errors --explicit-target-dependency-import-check error -Xswiftc -require-explicit-sendable" + linux_nightly_6_0_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -require-explicit-sendable" + linux_nightly_main_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -require-explicit-sendable" + + cxx-interop: + name: Cxx interop + uses: apple/swift-nio/.github/workflows/cxx_interop.yml@main + + swift-6-language-mode: + name: Swift 6 Language Mode + uses: apple/swift-nio/.github/workflows/swift_6_language_mode.yml@main + if: false # Disabled for now. diff --git a/.github/workflows/pull_request_label.yml b/.github/workflows/pull_request_label.yml new file mode 100644 index 00000000..86f199f3 --- /dev/null +++ b/.github/workflows/pull_request_label.yml @@ -0,0 +1,18 @@ +name: PR label + +on: + pull_request: + types: [labeled, unlabeled, opened, reopened, synchronize] + +jobs: + semver-label-check: + name: Semantic Version label check + runs-on: ubuntu-latest + timeout-minutes: 1 + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + persist-credentials: false + - name: Check for Semantic Version label + uses: apple/swift-nio/.github/actions/pull_request_semver_label_checker@main diff --git a/.github/workflows/scheduled.yml b/.github/workflows/scheduled.yml new file mode 100644 index 00000000..b971be4d --- /dev/null +++ b/.github/workflows/scheduled.yml @@ -0,0 +1,17 @@ +name: Scheduled + +on: + schedule: + - cron: "0 8,20 * * *" + +jobs: + unit-tests: + name: Unit tests + uses: apple/swift-nio/.github/workflows/unit_tests.yml@main + with: + linux_5_8_enabled: false + linux_5_9_arguments_override: "-Xswiftc -warnings-as-errors --explicit-target-dependency-import-check error" + linux_5_10_arguments_override: "-Xswiftc -warnings-as-errors --explicit-target-dependency-import-check error" + linux_6_0_arguments_override: "-Xswiftc -warnings-as-errors --explicit-target-dependency-import-check error -Xswiftc -require-explicit-sendable" + linux_nightly_6_0_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -require-explicit-sendable" + linux_nightly_main_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -require-explicit-sendable" diff --git a/.licenseignore b/.licenseignore new file mode 100644 index 00000000..769f312a --- /dev/null +++ b/.licenseignore @@ -0,0 +1,40 @@ +.gitignore +**/.gitignore +.licenseignore +.gitattributes +.git-blame-ignore-revs +.mailfilter +.mailmap +.spi.yml +.swift-format +.editorconfig +.github/* +*.md +*.txt +*.yml +*.yaml +*.json +Package.swift +**/Package.swift +Package@-*.swift +**/Package@-*.swift +Package.resolved +**/Package.resolved +Makefile +*.modulemap +**/*.modulemap +**/*.docc/* +*.xcprivacy +**/*.xcprivacy +*.symlink +**/*.symlink +Dockerfile +**/Dockerfile +Snippets/* +dev/alloc-limits-from-test-output +dev/boxed-existentials.d +dev/git.commit.template +dev/lldb-smoker +dev/make-single-file-spm +dev/malloc-aggregation.d +dev/update-alloc-limits-to-last-completed-ci-build diff --git a/.swift-format b/.swift-format new file mode 100644 index 00000000..7fa06fb3 --- /dev/null +++ b/.swift-format @@ -0,0 +1,62 @@ +{ + "version" : 1, + "indentation" : { + "spaces" : 4 + }, + "tabWidth" : 4, + "fileScopedDeclarationPrivacy" : { + "accessLevel" : "private" + }, + "spacesAroundRangeFormationOperators" : false, + "indentConditionalCompilationBlocks" : false, + "indentSwitchCaseLabels" : false, + "lineBreakAroundMultilineExpressionChainComponents" : false, + "lineBreakBeforeControlFlowKeywords" : false, + "lineBreakBeforeEachArgument" : true, + "lineBreakBeforeEachGenericRequirement" : true, + "lineLength" : 120, + "maximumBlankLines" : 1, + "respectsExistingLineBreaks" : true, + "prioritizeKeepingFunctionOutputTogether" : true, + "rules" : { + "AllPublicDeclarationsHaveDocumentation" : false, + "AlwaysUseLiteralForEmptyCollectionInit" : false, + "AlwaysUseLowerCamelCase" : false, + "AmbiguousTrailingClosureOverload" : true, + "BeginDocumentationCommentWithOneLineSummary" : false, + "DoNotUseSemicolons" : true, + "DontRepeatTypeInStaticProperties" : true, + "FileScopedDeclarationPrivacy" : true, + "FullyIndirectEnum" : true, + "GroupNumericLiterals" : true, + "IdentifiersMustBeASCII" : true, + "NeverForceUnwrap" : false, + "NeverUseForceTry" : false, + "NeverUseImplicitlyUnwrappedOptionals" : false, + "NoAccessLevelOnExtensionDeclaration" : true, + "NoAssignmentInExpressions" : true, + "NoBlockComments" : true, + "NoCasesWithOnlyFallthrough" : true, + "NoEmptyTrailingClosureParentheses" : true, + "NoLabelsInCasePatterns" : true, + "NoLeadingUnderscores" : false, + "NoParensAroundConditions" : true, + "NoVoidReturnOnFunctionSignature" : true, + "OmitExplicitReturns" : true, + "OneCasePerLine" : true, + "OneVariableDeclarationPerLine" : true, + "OnlyOneTrailingClosureArgument" : true, + "OrderedImports" : true, + "ReplaceForEachWithForLoop" : true, + "ReturnVoidInsteadOfEmptyTuple" : true, + "UseEarlyExits" : false, + "UseExplicitNilCheckInConditions" : false, + "UseLetInEveryBoundCaseVariable" : false, + "UseShorthandTypeNames" : true, + "UseSingleLinePropertyGetter" : false, + "UseSynthesizedInitializer" : false, + "UseTripleSlashForDocumentationComments" : true, + "UseWhereClausesInForLoops" : false, + "ValidateDocumentationComments" : false + } +} diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ea200583..d32f8b7d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -63,10 +63,9 @@ We require that your commit messages match our template. The easiest way to do t git config commit.template dev/git.commit.template -### Make sure Tests work on Linux +### Run CI checks locally -SwiftNIO uses XCTest to run tests on both macOS and Linux. While the macOS version of XCTest is able to use the Objective-C runtime to discover tests at execution time, the Linux version is not. -For this reason, whenever you add new tests **you have to run a script** that generates the hooks needed to run those tests on Linux, or our CI will complain that the tests are not all present on Linux. To do this, merely execute `ruby ./scripts/generate_linux_tests.rb` at the root of the package and check the changes it made. +You can run the Github Actions workflows locally using [act](https://github.com/nektos/act). For detailed steps on how to do this please see [https://github.com/swiftlang/github-workflows?tab=readme-ov-file#running-workflows-locally](https://github.com/swiftlang/github-workflows?tab=readme-ov-file#running-workflows-locally). ## How to contribute your work diff --git a/Package.swift b/Package.swift index 1b6323cb..56e57858 100644 --- a/Package.swift +++ b/Package.swift @@ -22,7 +22,8 @@ var targets: [PackageDescription.Target] = [ .product(name: "NIO", package: "swift-nio"), .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOHTTP1", package: "swift-nio"), - ]), + ] + ), .target( name: "NIOHTTPCompression", dependencies: [ @@ -30,7 +31,8 @@ var targets: [PackageDescription.Target] = [ .product(name: "NIO", package: "swift-nio"), .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOHTTP1", package: "swift-nio"), - ]), + ] + ), .executableTarget( name: "HTTPServerWithQuiescingDemo", dependencies: [ @@ -38,7 +40,8 @@ var targets: [PackageDescription.Target] = [ .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOPosix", package: "swift-nio"), .product(name: "NIOHTTP1", package: "swift-nio"), - ]), + ] + ), .executableTarget( name: "NIOWritePCAPDemo", dependencies: [ @@ -46,7 +49,8 @@ var targets: [PackageDescription.Target] = [ .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOPosix", package: "swift-nio"), .product(name: "NIOHTTP1", package: "swift-nio"), - ]), + ] + ), .executableTarget( name: "NIOWritePartialPCAPDemo", dependencies: [ @@ -54,7 +58,8 @@ var targets: [PackageDescription.Target] = [ .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOPosix", package: "swift-nio"), .product(name: "NIOHTTP1", package: "swift-nio"), - ]), + ] + ), .executableTarget( name: "NIOExtrasPerformanceTester", dependencies: [ @@ -63,26 +68,30 @@ var targets: [PackageDescription.Target] = [ .product(name: "NIOPosix", package: "swift-nio"), .product(name: "NIOEmbedded", package: "swift-nio"), .product(name: "NIOHTTP1", package: "swift-nio"), - ]), + ] + ), .target( name: "NIOSOCKS", dependencies: [ .product(name: "NIO", package: "swift-nio"), .product(name: "NIOCore", package: "swift-nio"), - ]), + ] + ), .executableTarget( name: "NIOSOCKSClient", dependencies: [ .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOPosix", package: "swift-nio"), - "NIOSOCKS" - ]), + "NIOSOCKS", + ] + ), .target( name: "CNIOExtrasZlib", dependencies: [], linkerSettings: [ .linkedLibrary("z") - ]), + ] + ), .testTarget( name: "NIOExtrasTests", dependencies: [ @@ -92,7 +101,8 @@ var targets: [PackageDescription.Target] = [ .product(name: "NIOPosix", package: "swift-nio"), .product(name: "NIOTestUtils", package: "swift-nio"), .product(name: "NIOConcurrencyHelpers", package: "swift-nio"), - ]), + ] + ), .testTarget( name: "NIOHTTPCompressionTests", dependencies: [ @@ -102,19 +112,22 @@ var targets: [PackageDescription.Target] = [ .product(name: "NIOEmbedded", package: "swift-nio"), .product(name: "NIOHTTP1", package: "swift-nio"), .product(name: "NIOConcurrencyHelpers", package: "swift-nio"), - ]), + ] + ), .testTarget( name: "NIOSOCKSTests", dependencies: [ "NIOSOCKS", .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOEmbedded", package: "swift-nio"), - ]), + ] + ), .target( name: "NIONFS3", dependencies: [ - .product(name: "NIOCore", package: "swift-nio"), - ]), + .product(name: "NIOCore", package: "swift-nio") + ] + ), .testTarget( name: "NIONFS3Tests", dependencies: [ @@ -122,35 +135,41 @@ var targets: [PackageDescription.Target] = [ .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOEmbedded", package: "swift-nio"), .product(name: "NIOTestUtils", package: "swift-nio"), - ]), + ] + ), .target( name: "NIOHTTPTypes", dependencies: [ .product(name: "HTTPTypes", package: "swift-http-types"), .product(name: "NIOCore", package: "swift-nio"), - ]), + ] + ), .target( name: "NIOHTTPTypesHTTP1", dependencies: [ "NIOHTTPTypes", .product(name: "NIOHTTP1", package: "swift-nio"), - ]), + ] + ), .target( name: "NIOHTTPTypesHTTP2", dependencies: [ "NIOHTTPTypes", .product(name: "NIOHTTP2", package: "swift-nio-http2"), - ]), + ] + ), .testTarget( name: "NIOHTTPTypesHTTP1Tests", dependencies: [ - "NIOHTTPTypesHTTP1", - ]), + "NIOHTTPTypesHTTP1" + ] + ), .testTarget( name: "NIOHTTPTypesHTTP2Tests", dependencies: [ - "NIOHTTPTypesHTTP2", - ]), + "NIOHTTPTypesHTTP2" + ] + ), ] let package = Package( @@ -166,7 +185,6 @@ let package = Package( dependencies: [ .package(url: "https://github.com/apple/swift-nio.git", from: "2.67.0"), .package(url: "https://github.com/apple/swift-nio-http2.git", from: "1.27.0"), - .package(url: "https://github.com/apple/swift-docc-plugin.git", from: "1.0.0"), .package(url: "https://github.com/apple/swift-http-types.git", from: "1.0.0"), ], targets: targets diff --git a/Sources/HTTPServerWithQuiescingDemo/main.swift b/Sources/HTTPServerWithQuiescingDemo/main.swift index bb300d8a..491fc241 100644 --- a/Sources/HTTPServerWithQuiescingDemo/main.swift +++ b/Sources/HTTPServerWithQuiescingDemo/main.swift @@ -13,11 +13,10 @@ //===----------------------------------------------------------------------===// import Dispatch - import NIOCore -import NIOPosix -import NIOHTTP1 import NIOExtras +import NIOHTTP1 +import NIOPosix private final class HTTPHandler: ChannelInboundHandler { typealias InboundIn = HTTPServerRequestPart @@ -28,21 +27,35 @@ private final class HTTPHandler: ChannelInboundHandler { switch req { case .head(let head): guard head.version == HTTPVersion(major: 1, minor: 1) else { - context.write(self.wrapOutboundOut(.head(HTTPResponseHead(version: head.version, status: .badRequest))), promise: nil) + context.write( + self.wrapOutboundOut(.head(HTTPResponseHead(version: head.version, status: .badRequest))), + promise: nil + ) context.writeAndFlush(self.wrapOutboundOut(.end(nil))).whenComplete { (_: Result<(), Error>) in context.close(promise: nil) } return } case .body: - () // ignore + () // ignore case .end: var buffer = context.channel.allocator.buffer(capacity: 128) buffer.writeStaticString("received request; waiting 30s then finishing up request\n") - buffer.writeStaticString("press Ctrl+C in the server's terminal or run the following command to initiate server shutdown\n") - buffer.writeString(" kill -INT \(getpid())\n") - context.write(self.wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), - status: .ok))), promise: nil) + buffer.writeStaticString( + "press Ctrl+C in the server's terminal or run the following command to initiate server shutdown\n" + ) + buffer.writeString(" kill -INT \(getpid())\n") // ignore-unacceptable-language + context.write( + self.wrapOutboundOut( + .head( + HTTPResponseHead( + version: HTTPVersion(major: 1, minor: 1), + status: .ok + ) + ) + ), + promise: nil + ) context.writeAndFlush(self.wrapOutboundOut(.body(.byteBuffer(buffer))), promise: nil) buffer.clear() buffer.writeStaticString("done with the request now\n") @@ -53,7 +66,7 @@ private final class HTTPHandler: ChannelInboundHandler { } } } - + func errorCaught(context: ChannelHandlerContext, error: Error) { print(error) } @@ -90,7 +103,10 @@ private func runServer() throws { .childChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) .childChannelOption(ChannelOptions.maxMessagesPerRead, value: 1) .childChannelInitializer { channel in - channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: true, withErrorHandling: true).flatMap { + channel.pipeline.configureHTTPServerPipeline( + withPipeliningAssistance: true, + withErrorHandling: true + ).flatMap { channel.pipeline.addHandler(HTTPHandler()) } } diff --git a/Sources/NIOExtras/DebugInboundEventsHandler.swift b/Sources/NIOExtras/DebugInboundEventsHandler.swift index b15590b4..14ad1e47 100644 --- a/Sources/NIOExtras/DebugInboundEventsHandler.swift +++ b/Sources/NIOExtras/DebugInboundEventsHandler.swift @@ -11,6 +11,8 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// +import NIOCore + #if canImport(Darwin) import Darwin #elseif canImport(Musl) @@ -19,8 +21,6 @@ import Musl import Glibc #endif -import NIOCore - /// `ChannelInboundHandler` that prints all inbound events that pass through the pipeline by default, /// overridable by providing your own closure for custom logging. See ``DebugOutboundEventsHandler`` for outbound events. public class DebugInboundEventsHandler: ChannelInboundHandler { @@ -50,12 +50,12 @@ public class DebugInboundEventsHandler: ChannelInboundHandler { /// An error was caught. case errorCaught(Error) } - - var logger: (Event, ChannelHandlerContext) -> () + + var logger: (Event, ChannelHandlerContext) -> Void /// Initialiser. /// - Parameter logger: Method for logging events which occur. - public init(logger: @escaping (Event, ChannelHandlerContext) -> () = DebugInboundEventsHandler.defaultPrint) { + public init(logger: @escaping (Event, ChannelHandlerContext) -> Void = DebugInboundEventsHandler.defaultPrint) { self.logger = logger } diff --git a/Sources/NIOExtras/DebugOutboundEventsHandler.swift b/Sources/NIOExtras/DebugOutboundEventsHandler.swift index b9e04a4c..13e9828f 100644 --- a/Sources/NIOExtras/DebugOutboundEventsHandler.swift +++ b/Sources/NIOExtras/DebugOutboundEventsHandler.swift @@ -12,6 +12,8 @@ // //===----------------------------------------------------------------------===// +import NIOCore + #if canImport(Darwin) import Darwin #elseif canImport(Musl) @@ -20,8 +22,6 @@ import Musl import Glibc #endif -import NIOCore - /// ChannelOutboundHandler that prints all outbound events that pass through the pipeline by default, /// overridable by providing your own closure for custom logging. See ``DebugInboundEventsHandler`` for inbound events. public class DebugOutboundEventsHandler: ChannelOutboundHandler { @@ -50,12 +50,12 @@ public class DebugOutboundEventsHandler: ChannelOutboundHandler { case triggerUserOutboundEvent(event: Any) } - var logger: (Event, ChannelHandlerContext) -> () + var logger: (Event, ChannelHandlerContext) -> Void /// Initialiser. /// - parameters: /// - logger: Method for logging events which happen. - public init(logger: @escaping (Event, ChannelHandlerContext) -> () = DebugOutboundEventsHandler.defaultPrint) { + public init(logger: @escaping (Event, ChannelHandlerContext) -> Void = DebugOutboundEventsHandler.defaultPrint) { self.logger = logger } @@ -73,7 +73,7 @@ public class DebugOutboundEventsHandler: ChannelOutboundHandler { /// Called to request that the `Channel` bind to a specific `SocketAddress`. /// - parameters: /// - context: The `ChannelHandlerContext` which this `ChannelHandler` belongs to. - /// - to: The `SocketAddress` to which this `Channel` should bind. + /// - address: The `SocketAddress` to which this `Channel` should bind. /// - promise: The `EventLoopPromise` which should be notified once the operation completes, or nil if no notification should take place. public func bind(context: ChannelHandlerContext, to address: SocketAddress, promise: EventLoopPromise?) { logger(.bind(address: address), context) @@ -84,7 +84,7 @@ public class DebugOutboundEventsHandler: ChannelOutboundHandler { /// Called to request that the `Channel` connect to a given `SocketAddress`. /// - parameters: /// - context: The `ChannelHandlerContext` which this `ChannelHandler` belongs to. - /// - to: The `SocketAddress` to which the the `Channel` should connect. + /// - address: The `SocketAddress` to which the the `Channel` should connect. /// - promise: The `EventLoopPromise` which should be notified once the operation completes, or nil if no notification should take place. public func connect(context: ChannelHandlerContext, to address: SocketAddress, promise: EventLoopPromise?) { logger(.connect(address: address), context) @@ -146,7 +146,7 @@ public class DebugOutboundEventsHandler: ChannelOutboundHandler { /// Print textual event description to stdout. /// - parameters: /// - event: The ``Event`` to print. - /// - in: The context the event occured in. + /// - context: The context the event occured in. public static func defaultPrint(event: Event, in context: ChannelHandlerContext) { let message: String switch event { diff --git a/Sources/NIOExtras/FixedLengthFrameDecoder.swift b/Sources/NIOExtras/FixedLengthFrameDecoder.swift index 0386ba9d..ba72537f 100644 --- a/Sources/NIOExtras/FixedLengthFrameDecoder.swift +++ b/Sources/NIOExtras/FixedLengthFrameDecoder.swift @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// import NIOCore + /// /// A decoder that splits the received `ByteBuffer` by a fixed number /// of bytes. For example, if you received the following four fragmented packets: @@ -68,7 +69,11 @@ public final class FixedLengthFrameDecoder: ByteToMessageDecoder { /// - buffer: Buffer containing data. /// - seenEOF: If end of file has been seen. /// - Returns: needMoreData always as all data is consumed. - public func decodeLast(context: ChannelHandlerContext, buffer: inout ByteBuffer, seenEOF: Bool) throws -> DecodingState { + public func decodeLast( + context: ChannelHandlerContext, + buffer: inout ByteBuffer, + seenEOF: Bool + ) throws -> DecodingState { while case .continue = try self.decode(context: context, buffer: &buffer) {} if buffer.readableBytes > 0 { context.fireErrorCaught(NIOExtrasErrors.LeftOverBytesError(leftOverBytes: buffer)) diff --git a/Sources/NIOExtras/HTTP1ProxyConnectHandler.swift b/Sources/NIOExtras/HTTP1ProxyConnectHandler.swift index 962573bf..8ee5fe9c 100644 --- a/Sources/NIOExtras/HTTP1ProxyConnectHandler.swift +++ b/Sources/NIOExtras/HTTP1ProxyConnectHandler.swift @@ -58,18 +58,20 @@ public final class NIOHTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableC /// - headers: Headers to supply to the proxy server as part of the CONNECT request /// - deadline: Deadline for the CONNECT request /// - promise: Promise with which the result of the connect operation is communicated - public init(targetHost: String, - targetPort: Int, - headers: HTTPHeaders, - deadline: NIODeadline, - promise: EventLoopPromise?) { + public init( + targetHost: String, + targetPort: Int, + headers: HTTPHeaders, + deadline: NIODeadline, + promise: EventLoopPromise? + ) { self.targetHost = targetHost self.targetPort = targetPort self.headers = headers self.deadline = deadline self.promise = promise - self.bufferedWrittenMessages = MarkedCircularBuffer(initialCapacity: 16) // matches CircularBuffer default + self.bufferedWrittenMessages = MarkedCircularBuffer(initialCapacity: 16) // matches CircularBuffer default } public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { @@ -98,7 +100,7 @@ public final class NIOHTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableC context.flush() } } - + } context.leavePipeline(removalToken: removalToken) @@ -297,7 +299,11 @@ public final class NIOHTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableC } /// Proxy response contains unexpected status - public static func invalidProxyResponseHead(_ head: HTTPResponseHead, file: String = #file, line: UInt = #line) -> Error { + public static func invalidProxyResponseHead( + _ head: HTTPResponseHead, + file: String = #file, + line: UInt = #line + ) -> Error { Error(error: .invalidProxyResponseHead(head: head), file: file, line: line) } @@ -357,7 +363,7 @@ public final class NIOHTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableC extension NIOHTTP1ProxyConnectHandler.Error: Hashable { // compare only the kind of error, not the associated response head public static func == (lhs: Self, rhs: Self) -> Bool { - return lhs.errorCode == rhs.errorCode + lhs.errorCode == rhs.errorCode } public func hash(into hasher: inout Hasher) { @@ -365,7 +371,6 @@ extension NIOHTTP1ProxyConnectHandler.Error: Hashable { } } - extension NIOHTTP1ProxyConnectHandler.Error: CustomStringConvertible { public var description: String { "\(self.store.details.description) (\(self.store.file): \(self.store.line))" diff --git a/Sources/NIOExtras/JSONRPCFraming+ContentLengthHeader.swift b/Sources/NIOExtras/JSONRPCFraming+ContentLengthHeader.swift index adc3fc44..8611b4e4 100644 --- a/Sources/NIOExtras/JSONRPCFraming+ContentLengthHeader.swift +++ b/Sources/NIOExtras/JSONRPCFraming+ContentLengthHeader.swift @@ -127,7 +127,7 @@ extension NIOJSONRPCFraming { // Given that we're waiting for the end of a header block or a new header field, it's sensible to // check if this might be the end of the block. if buffer.readableBytesView.starts(with: "\r\n".utf8) { - buffer.moveReaderIndex(forwardBy: 2) // skip \r\n\r\n + buffer.moveReaderIndex(forwardBy: 2) // skip \r\n\r\n return try self.processHeaderBlockEnd(context: context) } @@ -135,7 +135,7 @@ extension NIOJSONRPCFraming { // must always have a colon (or we don't have enough data). if let colonIndex = buffer.readableBytesView.firstIndex(of: UInt8(ascii: ":")) { let headerName = buffer.readString(length: colonIndex - buffer.readableBytesView.startIndex)! - buffer.moveReaderIndex(forwardBy: 1) // skip the colon + buffer.moveReaderIndex(forwardBy: 1) // skip the colon self.state = .waitingForHeaderValue(name: headerName.trimmed().lowercased()) return .continue } @@ -153,7 +153,8 @@ extension NIOJSONRPCFraming { if headerName == "content-length" { // Yes, let's parse the int. let headerValue = buffer.readString(length: newlineIndex - buffer.readableBytesView.startIndex + 1)! - if let length = UInt32(headerValue.trimmed()) { // anything more than 4GB or negative doesn't make sense + // anything more than 4GB or negative doesn't make sense + if let length = UInt32(headerValue.trimmed()) { self.payloadLength = .init(length) } else { throw DecodingError.illegalContentLengthHeaderValue(headerValue) @@ -166,7 +167,7 @@ extension NIOJSONRPCFraming { // but in any case, we're now waiting for a new header or the end of the header block again. self.state = .waitingForHeaderNameOrHeaderBlockEnd return .continue - case .waitingForPayload(length: let length): + case .waitingForPayload(let length): // That's the easiest case, let's just wait until we have enough data. if let payload = buffer.readSlice(length: length) { // Cool, we got enough data, let's go back waiting for a new header block. @@ -183,14 +184,16 @@ extension NIOJSONRPCFraming { /// Decode all remaining data. /// Invoked when the `Channel` is being brought down. /// Reports error through `ByteToMessageDecoderError.leftoverDataWhenDone` if not all data is consumed. - /// - parameters: + /// - Parameters: /// - context: Calling context. /// - buffer: Buffer of data to decode. /// - seenEOF: If the end of file has been seen. - /// - returns: .needMoreData always as all data should be consumed. - public mutating func decodeLast(context: ChannelHandlerContext, - buffer: inout ByteBuffer, - seenEOF: Bool) throws -> DecodingState { + /// - Returns: .needMoreData always as all data should be consumed. + public mutating func decodeLast( + context: ChannelHandlerContext, + buffer: inout ByteBuffer, + seenEOF: Bool + ) throws -> DecodingState { // Last chance to decode anything. while try self.decode(context: context, buffer: &buffer) == .continue {} @@ -211,12 +214,10 @@ extension String { } let lastElementIndex = self.reversed().firstIndex(where: { !$0.isWhitespace })! - return self[firstElementIndex ..< lastElementIndex.base] + return self[firstElementIndex.. DecodingState { - + if case .waitingForHeader = self.readState { try self.readNextLengthFieldToState(buffer: &buffer) } - + guard case .waitingForFrame(let frameLength) = self.readState else { return .needMoreData } - + guard let frameBuffer = try self.readNextFrame(buffer: &buffer, frameLength: frameLength) else { return .needMoreData } - + context.fireChannelRead(self.wrapInboundOut(frameBuffer)) return .continue @@ -170,7 +172,11 @@ public final class LengthFieldBasedFrameDecoder: ByteToMessageDecoder { /// - buffer: The data to decode /// - seenEOF: If End of File has been seen. /// - Returns: .needMoreData always as all data has been consumed. - public func decodeLast(context: ChannelHandlerContext, buffer: inout ByteBuffer, seenEOF: Bool) throws -> DecodingState { + public func decodeLast( + context: ChannelHandlerContext, + buffer: inout ByteBuffer, + seenEOF: Bool + ) throws -> DecodingState { // we'll just try to decode as much as we can as usually while case .continue = try self.decode(context: context, buffer: &buffer) {} if buffer.readableBytes > 0 { @@ -199,13 +205,13 @@ public final class LengthFieldBasedFrameDecoder: ByteToMessageDecoder { /// - buffer: The buffer containing the frame data. /// - frameLength: The length of the frame data to be read. private func readNextFrame(buffer: inout ByteBuffer, frameLength: Int) throws -> ByteBuffer? { - + guard let contentsFieldSlice = buffer.readSlice(length: frameLength) else { return nil } self.readState = .waitingForHeader - + return contentsFieldSlice } @@ -238,9 +244,10 @@ public final class LengthFieldBasedFrameDecoder: ByteToMessageDecoder { return size } } - + if let frameLength = frameLength, - frameLength > LengthFieldBasedFrameDecoder.maxSupportedLengthFieldSize { + frameLength > LengthFieldBasedFrameDecoder.maxSupportedLengthFieldSize + { throw NIOLengthFieldBasedFrameDecoderError.lengthFieldValueLargerThanMaxSupportedSize } return frameLength diff --git a/Sources/NIOExtras/LengthFieldPrepender.swift b/Sources/NIOExtras/LengthFieldPrepender.swift index af8d8612..591f7d05 100644 --- a/Sources/NIOExtras/LengthFieldPrepender.swift +++ b/Sources/NIOExtras/LengthFieldPrepender.swift @@ -24,16 +24,15 @@ extension ByteBuffer { precondition(integer & 0xFF_FF_FF == integer, "integer value does not fit into 24 bit integer") switch endianness { case .little: - return writeInteger(UInt8(integer & 0xFF), endianness: .little) + - writeInteger(UInt16((integer >> 8) & 0xFF_FF), endianness: .little) + return writeInteger(UInt8(integer & 0xFF), endianness: .little) + + writeInteger(UInt16((integer >> 8) & 0xFF_FF), endianness: .little) case .big: - return writeInteger(UInt16((integer >> 8) & 0xFF_FF), endianness: .big) + - writeInteger(UInt8(integer & 0xFF), endianness: .big) + return writeInteger(UInt16((integer >> 8) & 0xFF_FF), endianness: .big) + + writeInteger(UInt8(integer & 0xFF), endianness: .big) } } } - /// Error types from ``LengthFieldPrepender`` public enum LengthFieldPrependerError: Error { /// More data was given than the maximum encodable length value. @@ -64,7 +63,7 @@ public final class LengthFieldPrepender: ChannelOutboundHandler { case four /// Eight bytes case eight - + fileprivate var bitLength: NIOLengthFieldBitLength { switch self { case .one: return .oneByte @@ -82,7 +81,7 @@ public final class LengthFieldPrepender: ChannelOutboundHandler { private let lengthFieldLength: NIOLengthFieldBitLength private let lengthFieldEndianness: Endianness - + private var lengthBuffer: ByteBuffer? /// Create ``LengthFieldPrepender`` with a given length field length. @@ -101,8 +100,8 @@ public final class LengthFieldPrepender: ChannelOutboundHandler { public init(lengthFieldBitLength: NIOLengthFieldBitLength, lengthFieldEndianness: Endianness = .big) { // The value contained in the length field must be able to be represented by an integer type on the platform. // ie. .eight == 64bit which would not fit into the Int type on a 32bit platform. - precondition(lengthFieldBitLength.length <= Int.bitWidth/8) - + precondition(lengthFieldBitLength.length <= Int.bitWidth / 8) + self.lengthFieldLength = lengthFieldBitLength self.lengthFieldEndianness = lengthFieldEndianness } @@ -111,12 +110,12 @@ public final class LengthFieldPrepender: ChannelOutboundHandler { let dataBuffer = self.unwrapOutboundIn(data) let dataLength = dataBuffer.readableBytes - + guard dataLength <= self.lengthFieldLength.max else { promise?.fail(LengthFieldPrependerError.messageDataTooLongForLengthField) return } - + var dataLengthBuffer: ByteBuffer if let existingBuffer = self.lengthBuffer { diff --git a/Sources/NIOExtras/LineBasedFrameDecoder.swift b/Sources/NIOExtras/LineBasedFrameDecoder.swift index 53c5454b..ffaf379e 100644 --- a/Sources/NIOExtras/LineBasedFrameDecoder.swift +++ b/Sources/NIOExtras/LineBasedFrameDecoder.swift @@ -35,13 +35,13 @@ public class LineBasedFrameDecoder: ByteToMessageDecoder & NIOSingleStepByteToMe public typealias InboundIn = ByteBuffer /// `ByteBuffer`s will be passed to the next stage. public typealias InboundOut = ByteBuffer - + @available(*, deprecated, message: "No longer used") public var cumulationBuffer: ByteBuffer? // keep track of the last scan offset from the buffer's reader index (if we didn't find the delimiter) private var lastScanOffset = 0 - - public init() { } + + public init() {} /// Decode data in the supplied buffer. /// - Parameters: @@ -62,7 +62,7 @@ public class LineBasedFrameDecoder: ByteToMessageDecoder & NIOSingleStepByteToMe /// - buffer: Buffer containing data to decode. /// - Returns: The decoded object or `nil` if we require more bytes. public func decode(buffer: inout NIOCore.ByteBuffer) throws -> NIOCore.ByteBuffer? { - return try self.findNextFrame(buffer: &buffer) + try self.findNextFrame(buffer: &buffer) } /// Decode all remaining data. @@ -72,7 +72,11 @@ public class LineBasedFrameDecoder: ByteToMessageDecoder & NIOSingleStepByteToMe /// - buffer: Buffer containing the data to decode. /// - seenEOF: Has end of file been seen. /// - Returns: Always .needMoreData as all data will be consumed. - public func decodeLast(context: ChannelHandlerContext, buffer: inout ByteBuffer, seenEOF: Bool) throws -> DecodingState { + public func decodeLast( + context: ChannelHandlerContext, + buffer: inout ByteBuffer, + seenEOF: Bool + ) throws -> DecodingState { while try self.decode(context: context, buffer: &buffer) == .continue {} if buffer.readableBytes > 0 { context.fireErrorCaught(NIOExtrasErrors.LeftOverBytesError(leftOverBytes: buffer)) @@ -102,10 +106,11 @@ public class LineBasedFrameDecoder: ByteToMessageDecoder & NIOSingleStepByteToMe private func findNextFrame(buffer: inout ByteBuffer) throws -> ByteBuffer? { let view = buffer.readableBytesView.dropFirst(self.lastScanOffset) // look for the delimiter - if let delimiterIndex = view.firstIndex(of: 0x0A) { // '\n' + if let delimiterIndex = view.firstIndex(of: 0x0A) { // '\n' let length = delimiterIndex - buffer.readerIndex - let dropCarriageReturn = delimiterIndex > buffer.readableBytesView.startIndex && - buffer.readableBytesView[delimiterIndex - 1] == 0x0D // '\r' + let dropCarriageReturn = + delimiterIndex > buffer.readableBytesView.startIndex + && buffer.readableBytesView[delimiterIndex - 1] == 0x0D // '\r' let buff = buffer.readSlice(length: dropCarriageReturn ? length - 1 : length) // drop the delimiter (and trailing carriage return if appicable) buffer.moveReaderIndex(forwardBy: dropCarriageReturn ? 2 : 1) diff --git a/Sources/NIOExtras/NIOExtrasError.swift b/Sources/NIOExtras/NIOExtrasError.swift index 9750f563..122e5a09 100644 --- a/Sources/NIOExtras/NIOExtrasError.swift +++ b/Sources/NIOExtras/NIOExtrasError.swift @@ -14,7 +14,7 @@ import NIOCore /// Base type for errors from NIOExtras -public protocol NIOExtrasError: Equatable, Error { } +public protocol NIOExtrasError: Equatable, Error {} /// Errors that are raised in NIOExtras. public enum NIOExtrasErrors { diff --git a/Sources/NIOExtras/NIOLengthFieldBitLength.swift b/Sources/NIOExtras/NIOLengthFieldBitLength.swift index cb3aa955..2cb03df9 100644 --- a/Sources/NIOExtras/NIOLengthFieldBitLength.swift +++ b/Sources/NIOExtras/NIOLengthFieldBitLength.swift @@ -46,7 +46,7 @@ public struct NIOLengthFieldBitLength: Sendable { public static let thirtyTwoBits = NIOLengthFieldBitLength(bitLength: .bits32) /// Sixty-four bits - the same as ``eightBytes`` public static let sixtyFourBits = NIOLengthFieldBitLength(bitLength: .bits64) - + internal var length: Int { switch bitLength { case .bits8: @@ -61,7 +61,7 @@ public struct NIOLengthFieldBitLength: Sendable { return 8 } } - + internal var max: UInt { switch bitLength { case .bits8: diff --git a/Sources/NIOExtras/PCAPRingBuffer.swift b/Sources/NIOExtras/PCAPRingBuffer.swift index dabcd775..30379788 100644 --- a/Sources/NIOExtras/PCAPRingBuffer.swift +++ b/Sources/NIOExtras/PCAPRingBuffer.swift @@ -50,7 +50,7 @@ public class NIOPCAPRingBuffer { public convenience init(maximumFragments: Int) { self.init(maximumFragments: maximumFragments, maximumBytes: .max) } - + @discardableResult private func popFirst() -> ByteBuffer? { let popped = self.pcapFragments.popFirst() @@ -59,7 +59,7 @@ public class NIOPCAPRingBuffer { } return popped } - + private func append(_ buffer: ByteBuffer) { self.pcapFragments.append(buffer) self.pcapCurrentBytes += buffer.readableBytes @@ -91,7 +91,7 @@ public class NIOPCAPRingBuffer { /// Emit the captured data to a consuming function; then clear the captured data. /// - Returns: A circular buffer of captured fragments. public func emitPCAP() -> CircularBuffer { - let toReturn = self.pcapFragments // Copy before clearing. + let toReturn = self.pcapFragments // Copy before clearing. self.pcapFragments.removeAll(keepingCapacity: true) self.pcapCurrentBytes = 0 return toReturn diff --git a/Sources/NIOExtras/QuiescingHelper.swift b/Sources/NIOExtras/QuiescingHelper.swift index e66de895..cc8e2123 100644 --- a/Sources/NIOExtras/QuiescingHelper.swift +++ b/Sources/NIOExtras/QuiescingHelper.swift @@ -106,7 +106,10 @@ private final class ChannelCollector { if openChannels.isEmpty { self.shutdownCompleted() } else { - self.lifecycleState = .shuttingDown(openChannels: openChannels, fullyShutdownPromise: fullyShutdownPromise) + self.lifecycleState = .shuttingDown( + openChannels: openChannels, + fullyShutdownPromise: fullyShutdownPromise + ) } case .shutdownCompleted: diff --git a/Sources/NIOExtras/RequestResponseHandler.swift b/Sources/NIOExtras/RequestResponseHandler.swift index fadf88ad..d0d332a1 100644 --- a/Sources/NIOExtras/RequestResponseHandler.swift +++ b/Sources/NIOExtras/RequestResponseHandler.swift @@ -52,7 +52,6 @@ public final class RequestResponseHandler: ChannelDuplexHandl private var state: State = .operational private var promiseBuffer: CircularBuffer> - /// Create a new ``RequestResponseHandler``. /// /// - parameters: @@ -72,7 +71,7 @@ public final class RequestResponseHandler: ChannelDuplexHandl case .operational: let promiseBuffer = self.promiseBuffer self.promiseBuffer.removeAll() - promiseBuffer.forEach { promise in + for promise in promiseBuffer { promise.fail(NIOExtrasErrors.ClosedBeforeReceivingResponse()) } } @@ -101,8 +100,8 @@ public final class RequestResponseHandler: ChannelDuplexHandl let promiseBuffer = self.promiseBuffer self.promiseBuffer.removeAll() context.close(promise: nil) - promiseBuffer.forEach { - $0.fail(error) + for promise in promiseBuffer { + promise.fail(error) } } diff --git a/Sources/NIOExtras/RequestResponseWithIDHandler.swift b/Sources/NIOExtras/RequestResponseWithIDHandler.swift index ace1f591..7ff6e4a0 100644 --- a/Sources/NIOExtras/RequestResponseWithIDHandler.swift +++ b/Sources/NIOExtras/RequestResponseWithIDHandler.swift @@ -26,9 +26,11 @@ import NIOCore /// /// `NIORequestResponseWithIDHandler` does _not_ require that the `Response`s arrive on `Channel` in the same order as /// the `Request`s were submitted. They are matched by their `requestID` property (from `NIORequestIdentifiable`). -public final class NIORequestResponseWithIDHandler: ChannelDuplexHandler - where Request.RequestID == Response.RequestID { +public final class NIORequestResponseWithIDHandler< + Request: NIORequestIdentifiable, + Response: NIORequestIdentifiable +>: ChannelDuplexHandler +where Request.RequestID == Response.RequestID { public typealias InboundIn = Response public typealias InboundOut = Never public typealias OutboundIn = (Request, EventLoopPromise) @@ -78,7 +80,7 @@ public final class NIORequestResponseWithIDHandler Void private let mode: Mode - private let maxPayloadSize = Int(UInt16.max - 40 /* needs to fit into the IPv4 header which adds 40 */) + private let maxPayloadSize = Int(UInt16.max - 40) // needs to fit into the IPv4 header which adds 40 private let settings: Settings private var buffer: ByteBuffer! private var readInboundBytes: UInt64 = 0 @@ -182,7 +182,7 @@ public class NIOWritePCAPHandler: RemovableChannelHandler { private static let fakeLocalAddress = try! SocketAddress(ipAddress: "111.111.111.111", port: 1111) private static let fakeRemoteAddress = try! SocketAddress(ipAddress: "222.222.222.222", port: 2222) - + private var localAddress: SocketAddress? private var remoteAddress: SocketAddress? @@ -196,17 +196,20 @@ public class NIOWritePCAPHandler: RemovableChannelHandler { /// Initialize a ``NIOWritePCAPHandler``. /// /// - parameters: + /// - mode: Whether the handler should behave in the client or server mode. /// - fakeLocalAddress: Allows you to optionally override the local address to be different from the real one. /// - fakeRemoteAddress: Allows you to optionally override the remote address to be different from the real one. /// - settings: The settings for the ``NIOWritePCAPHandler``. /// - fileSink: The `fileSink` closure is called every time a new chunk of the `.pcap` file is ready to be /// written to disk or elsewhere. See ``SynchronizedFileSink`` for a convenient way to write to /// disk. - public init(mode: Mode, - fakeLocalAddress: SocketAddress? = nil, - fakeRemoteAddress: SocketAddress? = nil, - settings: Settings, - fileSink: @escaping (ByteBuffer) -> Void) { + public init( + mode: Mode, + fakeLocalAddress: SocketAddress? = nil, + fakeRemoteAddress: SocketAddress? = nil, + settings: Settings, + fileSink: @escaping (ByteBuffer) -> Void + ) { self.settings = settings self.fileSink = fileSink self.mode = mode @@ -221,26 +224,31 @@ public class NIOWritePCAPHandler: RemovableChannelHandler { /// Initialize a ``NIOWritePCAPHandler`` with default settings. /// /// - parameters: + /// - mode: Whether the handler should behave in the client or server mode. /// - fakeLocalAddress: Allows you to optionally override the local address to be different from the real one. /// - fakeRemoteAddress: Allows you to optionally override the remote address to be different from the real one. /// - fileSink: The `fileSink` closure is called every time a new chunk of the `.pcap` file is ready to be /// written to disk or elsewhere. See `NIOSynchronizedFileSink` for a convenient way to write to /// disk. - public convenience init(mode: Mode, - fakeLocalAddress: SocketAddress? = nil, - fakeRemoteAddress: SocketAddress? = nil, - fileSink: @escaping (ByteBuffer) -> Void) { - self.init(mode: mode, - fakeLocalAddress: fakeLocalAddress, - fakeRemoteAddress: fakeRemoteAddress, - settings: Settings(), - fileSink: fileSink) + public convenience init( + mode: Mode, + fakeLocalAddress: SocketAddress? = nil, + fakeRemoteAddress: SocketAddress? = nil, + fileSink: @escaping (ByteBuffer) -> Void + ) { + self.init( + mode: mode, + fakeLocalAddress: fakeLocalAddress, + fakeRemoteAddress: fakeRemoteAddress, + settings: Settings(), + fileSink: fileSink + ) } - + private func writeBuffer(_ buffer: ByteBuffer) { self.fileSink(buffer) } - + private func localAddress(context: ChannelHandlerContext) -> SocketAddress { if let localAddress = self.localAddress { return localAddress @@ -278,17 +286,17 @@ public class NIOWritePCAPHandler: RemovableChannelHandler { return self.localAddress(context: context) } } - + private func takeSensiblySizedPayload(buffer: inout ByteBuffer) -> ByteBuffer? { guard buffer.readableBytes > 0 else { return nil } - + return buffer.readSlice(length: min(buffer.readableBytes, self.maxPayloadSize)) } private func sequenceNumber(byteCount: UInt64) -> UInt32 { - return UInt32(byteCount % (UInt64(UInt32.max) + 1)) + UInt32(byteCount % (UInt64(UInt32.max) + 1)) } } @@ -304,7 +312,7 @@ extension NIOWritePCAPHandler: ChannelDuplexHandler { public func handlerAdded(context: ChannelHandlerContext) { self.buffer = context.channel.allocator.buffer(capacity: 256) } - + public func channelActive(context: ChannelHandlerContext) { self.buffer.clear() self.readInboundBytes = 1 @@ -312,37 +320,55 @@ extension NIOWritePCAPHandler: ChannelDuplexHandler { do { let clientAddress = self.clientAddress(context: context) let serverAddress = self.serverAddress(context: context) - try self.buffer.writePCAPRecord(.init(payloadLength: 0, - src: clientAddress, - dst: serverAddress, - tcp: TCPHeader(flags: [.syn], - ackNumber: nil, - sequenceNumber: 0, - srcPort: .init(clientAddress.port!), - dstPort: .init(serverAddress.port!)))) - try self.buffer.writePCAPRecord(.init(payloadLength: 0, - src: serverAddress, - dst: clientAddress, - tcp: TCPHeader(flags: [.syn, .ack], - ackNumber: 1, - sequenceNumber: 0, - srcPort: .init(serverAddress.port!), - dstPort: .init(clientAddress.port!)))) - try self.buffer.writePCAPRecord(.init(payloadLength: 0, - src: clientAddress, - dst: serverAddress, - tcp: TCPHeader(flags: [.ack], - ackNumber: 1, - sequenceNumber: 1, - srcPort: .init(clientAddress.port!), - dstPort: .init(serverAddress.port!)))) + try self.buffer.writePCAPRecord( + .init( + payloadLength: 0, + src: clientAddress, + dst: serverAddress, + tcp: TCPHeader( + flags: [.syn], + ackNumber: nil, + sequenceNumber: 0, + srcPort: .init(clientAddress.port!), + dstPort: .init(serverAddress.port!) + ) + ) + ) + try self.buffer.writePCAPRecord( + .init( + payloadLength: 0, + src: serverAddress, + dst: clientAddress, + tcp: TCPHeader( + flags: [.syn, .ack], + ackNumber: 1, + sequenceNumber: 0, + srcPort: .init(serverAddress.port!), + dstPort: .init(clientAddress.port!) + ) + ) + ) + try self.buffer.writePCAPRecord( + .init( + payloadLength: 0, + src: clientAddress, + dst: serverAddress, + tcp: TCPHeader( + flags: [.ack], + ackNumber: 1, + sequenceNumber: 1, + srcPort: .init(clientAddress.port!), + dstPort: .init(serverAddress.port!) + ) + ) + ) self.writeBuffer(self.buffer) } catch { context.fireErrorCaught(error) } context.fireChannelActive() } - + public func channelInactive(context: ChannelHandlerContext) { let didLocalInitiateTheClose: Bool switch self.closeState { @@ -354,48 +380,70 @@ extension NIOWritePCAPHandler: ChannelDuplexHandler { self.closeState = .closedInitiatorRemote didLocalInitiateTheClose = false } - + self.buffer.clear() do { - let closeInitiatorAddress = didLocalInitiateTheClose ? self.localAddress(context: context) : self.remoteAddress(context: context) - let closeRecipientAddress = didLocalInitiateTheClose ? self.remoteAddress(context: context) : self.localAddress(context: context) - let initiatorSeq = self.sequenceNumber(byteCount: didLocalInitiateTheClose ? - self.writtenOutboundBytes : self.readInboundBytes) - let recipientSeq = self.sequenceNumber(byteCount: didLocalInitiateTheClose ? - self.readInboundBytes : self.writtenOutboundBytes) - + let closeInitiatorAddress = + didLocalInitiateTheClose ? self.localAddress(context: context) : self.remoteAddress(context: context) + let closeRecipientAddress = + didLocalInitiateTheClose ? self.remoteAddress(context: context) : self.localAddress(context: context) + let initiatorSeq = self.sequenceNumber( + byteCount: didLocalInitiateTheClose ? self.writtenOutboundBytes : self.readInboundBytes + ) + let recipientSeq = self.sequenceNumber( + byteCount: didLocalInitiateTheClose ? self.readInboundBytes : self.writtenOutboundBytes + ) + // terminate the connection cleanly - try self.buffer.writePCAPRecord(.init(payloadLength: 0, - src: closeInitiatorAddress, - dst: closeRecipientAddress, - tcp: TCPHeader(flags: [.fin], - ackNumber: nil, - sequenceNumber: initiatorSeq, - srcPort: .init(closeInitiatorAddress.port!), - dstPort: .init(closeRecipientAddress.port!)))) - try self.buffer.writePCAPRecord(.init(payloadLength: 0, - src: closeRecipientAddress, - dst: closeInitiatorAddress, - tcp: TCPHeader(flags: [.ack, .fin], - ackNumber: initiatorSeq + 1, - sequenceNumber: recipientSeq, - srcPort: .init(closeRecipientAddress.port!), - dstPort: .init(closeInitiatorAddress.port!)))) - try self.buffer.writePCAPRecord(.init(payloadLength: 0, - src: closeInitiatorAddress, - dst: closeRecipientAddress, - tcp: TCPHeader(flags: [.ack], - ackNumber: recipientSeq + 1, - sequenceNumber: initiatorSeq + 1, - srcPort: .init(closeInitiatorAddress.port!), - dstPort: .init(closeRecipientAddress.port!)))) + try self.buffer.writePCAPRecord( + .init( + payloadLength: 0, + src: closeInitiatorAddress, + dst: closeRecipientAddress, + tcp: TCPHeader( + flags: [.fin], + ackNumber: nil, + sequenceNumber: initiatorSeq, + srcPort: .init(closeInitiatorAddress.port!), + dstPort: .init(closeRecipientAddress.port!) + ) + ) + ) + try self.buffer.writePCAPRecord( + .init( + payloadLength: 0, + src: closeRecipientAddress, + dst: closeInitiatorAddress, + tcp: TCPHeader( + flags: [.ack, .fin], + ackNumber: initiatorSeq + 1, + sequenceNumber: recipientSeq, + srcPort: .init(closeRecipientAddress.port!), + dstPort: .init(closeInitiatorAddress.port!) + ) + ) + ) + try self.buffer.writePCAPRecord( + .init( + payloadLength: 0, + src: closeInitiatorAddress, + dst: closeRecipientAddress, + tcp: TCPHeader( + flags: [.ack], + ackNumber: recipientSeq + 1, + sequenceNumber: initiatorSeq + 1, + srcPort: .init(closeInitiatorAddress.port!), + dstPort: .init(closeRecipientAddress.port!) + ) + ) + ) self.writeBuffer(self.buffer) } catch { context.fireErrorCaught(error) } context.fireChannelInactive() } - + public func channelRead(context: ChannelHandlerContext, data: NIOAny) { defer { context.fireChannelRead(data) @@ -403,24 +451,30 @@ extension NIOWritePCAPHandler: ChannelDuplexHandler { guard self.closeState == .notClosing else { return } - + let data = self.unwrapInboundIn(data) guard data.readableBytes > 0 else { return } - + self.buffer.clear() do { var data = data while var payloadToSend = self.takeSensiblySizedPayload(buffer: &data) { - try self.buffer.writePCAPRecord(.init(payloadLength: payloadToSend.readableBytes, - src: self.remoteAddress(context: context), - dst: self.localAddress(context: context), - tcp: TCPHeader(flags: [], - ackNumber: nil, - sequenceNumber: self.sequenceNumber(byteCount: self.readInboundBytes), - srcPort: .init(self.remoteAddress(context: context).port!), - dstPort: .init(self.localAddress(context: context).port!)))) + try self.buffer.writePCAPRecord( + .init( + payloadLength: payloadToSend.readableBytes, + src: self.remoteAddress(context: context), + dst: self.localAddress(context: context), + tcp: TCPHeader( + flags: [], + ackNumber: nil, + sequenceNumber: self.sequenceNumber(byteCount: self.readInboundBytes), + srcPort: .init(self.remoteAddress(context: context).port!), + dstPort: .init(self.localAddress(context: context).port!) + ) + ) + ) self.readInboundBytes += UInt64(payloadToSend.readableBytes) self.buffer.writeBuffer(&payloadToSend) } @@ -430,7 +484,7 @@ extension NIOWritePCAPHandler: ChannelDuplexHandler { context.fireErrorCaught(error) } } - + public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { var buffer = self.unwrapInboundIn(data) @@ -438,14 +492,20 @@ extension NIOWritePCAPHandler: ChannelDuplexHandler { do { self.buffer.clear() while var payloadToSend = self.takeSensiblySizedPayload(buffer: &buffer) { - try self.buffer.writePCAPRecord(.init(payloadLength: payloadToSend.readableBytes, - src: self.localAddress(context: context), - dst: self.remoteAddress(context: context), - tcp: TCPHeader(flags: [], - ackNumber: nil, - sequenceNumber: self.sequenceNumber(byteCount: self.writtenOutboundBytes), - srcPort: .init(self.localAddress(context: context).port!), - dstPort: .init(self.remoteAddress(context: context).port!)))) + try self.buffer.writePCAPRecord( + .init( + payloadLength: payloadToSend.readableBytes, + src: self.localAddress(context: context), + dst: self.remoteAddress(context: context), + tcp: TCPHeader( + flags: [], + ackNumber: nil, + sequenceNumber: self.sequenceNumber(byteCount: self.writtenOutboundBytes), + srcPort: .init(self.localAddress(context: context).port!), + dstPort: .init(self.remoteAddress(context: context).port!) + ) + ) + ) self.writtenOutboundBytes += UInt64(payloadToSend.readableBytes) self.buffer.writeBuffer(&payloadToSend) } @@ -467,15 +527,15 @@ extension NIOWritePCAPHandler: ChannelDuplexHandler { context.write(data, promise: promise) } } - + public func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { if let event = event as? ChannelEvent { if event == .inputClosed { switch self.closeState { case .closedInitiatorLocal: - () // fair enough, we already closed locally + () // fair enough, we already closed locally case .closedInitiatorRemote: - () // that's odd but okay + () // that's odd but okay case .notClosing: self.closeState = .closedInitiatorRemote } @@ -483,13 +543,13 @@ extension NIOWritePCAPHandler: ChannelDuplexHandler { } context.fireUserInboundEventTriggered(event) } - + public func close(context: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise?) { switch self.closeState { case .closedInitiatorLocal: - () // weird, this looks like a double-close + () // weird, this looks like a double-close case .closedInitiatorRemote: - () // fair enough, already closed I guess + () // fair enough, already closed I guess case .notClosing: self.closeState = .closedInitiatorLocal } @@ -499,75 +559,75 @@ extension NIOWritePCAPHandler: ChannelDuplexHandler { extension ByteBuffer { mutating func writePCAPHeader() { - // guint32 magic_number; /* magic number */ - self.writeInteger(0xa1b2c3d4, endianness: .host, as: UInt32.self) - // guint16 version_major; /* major version number */ + // guint32 magic_number; // magic number + self.writeInteger(0xa1b2_c3d4, endianness: .host, as: UInt32.self) + // guint16 version_major; // major version number self.writeInteger(2, endianness: .host, as: UInt16.self) - // guint16 version_minor; /* minor version number * + // guint16 version_minor; // minor version number self.writeInteger(4, endianness: .host, as: UInt16.self) - // gint32 thiszone; /* GMT to local correction */ + // gint32 thiszone; // GMT to local correction self.writeInteger(0, endianness: .host, as: UInt32.self) - // guint32 sigfigs; /* accuracy of timestamps */ + // guint32 sigfigs; // accuracy of timestamps self.writeInteger(0, endianness: .host, as: UInt32.self) - // guint32 snaplen; /* max length of captured packets, in octets */ + // guint32 snaplen; // max length of captured packets, in octets self.writeInteger(.max, endianness: .host, as: UInt32.self) - // guint32 network; /* data link type */ + // guint32 network; // data link type self.writeInteger(0, endianness: .host, as: UInt32.self) } - + mutating func writePCAPRecord(_ record: PCAPRecordHeader) throws { let rawDataLength = record.payloadLength - let tcpLength = rawDataLength + 20 /* TCP header length */ + let tcpLength = rawDataLength + 20 // TCP header length // record - // guint32 ts_sec; /* timestamp seconds */ + // guint32 ts_sec; // timestamp seconds self.writeInteger(.init(record.time.tv_sec), endianness: .host, as: UInt32.self) - // guint32 ts_usec; /* timestamp microseconds */ + // guint32 ts_usec; // timestamp microseconds self.writeInteger(.init(record.time.tv_usec), endianness: .host, as: UInt32.self) // continued below ... switch record.addresses { case .v4(let la, let ra): - let ipv4WholeLength = tcpLength + 20 /* IPv4 header length, included in IPv4 */ - let recordLength = ipv4WholeLength + 4 /* 32 bits for protocol id */ - + let ipv4WholeLength = tcpLength + 20 // IPv4 header length, included in IPv4 + let recordLength = ipv4WholeLength + 4 // 32 bits for protocol id + // record, continued - // guint32 incl_len; /* number of octets of packet saved in file */ + // guint32 incl_len; // number of octets of packet saved in file self.writeInteger(.init(recordLength), endianness: .host, as: UInt32.self) - // guint32 orig_len; /* actual length of packet */ + // guint32 orig_len; // actual length of packet self.writeInteger(.init(recordLength), endianness: .host, as: UInt32.self) - - self.writeInteger(2, endianness: .host, as: UInt32.self) // IPv4 + + self.writeInteger(2, endianness: .host, as: UInt32.self) // IPv4 // IPv4 packet - self.writeInteger(0x45, as: UInt8.self) // IP version (4) & IHL (5) - self.writeInteger(0, as: UInt8.self) // DSCP + self.writeInteger(0x45, as: UInt8.self) // IP version (4) & IHL (5) + self.writeInteger(0, as: UInt8.self) // DSCP self.writeInteger(.init(ipv4WholeLength), as: UInt16.self) - - self.writeInteger(0, as: UInt16.self) // identification - self.writeInteger(0x4000 /* this set's "don't fragment" */, as: UInt16.self) // flags & fragment offset - self.writeInteger(.max /* we don't care about TTL */, as: UInt8.self) // TTL - self.writeInteger(6, as: UInt8.self) // TCP - self.writeInteger(0, as: UInt16.self) // checksum + + self.writeInteger(0, as: UInt16.self) // identification + self.writeInteger(0x4000, as: UInt16.self) // flags & fragment offset, 0x4000 sets "don't fragment" + self.writeInteger(.max, as: UInt8.self) // TTL, `.max` as we don't care about the TTL + self.writeInteger(6, as: UInt8.self) // TCP + self.writeInteger(0, as: UInt16.self) // checksum self.writeInteger(la.address.sin_addr.s_addr, endianness: .host, as: UInt32.self) self.writeInteger(ra.address.sin_addr.s_addr, endianness: .host, as: UInt32.self) case .v6(let la, let ra): let ipv6PayloadLength = tcpLength - let recordLength = ipv6PayloadLength + 4 /* 32 bits for protocol id */ + 40 /* IPv6 header length */ - + let recordLength = ipv6PayloadLength + 4 + 40 // IPv6 header length (+4 gives 32 bits for protocol id) + // record, continued - // guint32 incl_len; /* number of octets of packet saved in file */ + // guint32 incl_len; // number of octets of packet saved in file self.writeInteger(.init(recordLength), endianness: .host, as: UInt32.self) - // guint32 orig_len; /* actual length of packet */ + // guint32 orig_len; // actual length of packet self.writeInteger(.init(recordLength), endianness: .host, as: UInt32.self) - - self.writeInteger(24, endianness: .host, as: UInt32.self) // IPv6 - + + self.writeInteger(24, endianness: .host, as: UInt32.self) // IPv6 + // IPv6 packet - self.writeInteger(/* version */ (6 << 28), as: UInt32.self) // IP version (6) & fancy stuff + self.writeInteger((6 << 28), as: UInt32.self) // IP version (6) & fancy stuff self.writeInteger(.init(ipv6PayloadLength), as: UInt16.self) - self.writeInteger(6, as: UInt8.self) // TCP - self.writeInteger(.max /* we don't care about TTL */, as: UInt8.self) // hop limit (like TTL) + self.writeInteger(6, as: UInt8.self) // TCP + self.writeInteger(.max, as: UInt8.self) // hop limit (like TTL & we don't care about TTL) var laAddress = la.address withUnsafeBytes(of: &laAddress.sin6_addr) { ptr in @@ -585,13 +645,15 @@ extension ByteBuffer { self.writeInteger(record.tcp.srcPort, as: UInt16.self) self.writeInteger(record.tcp.dstPort, as: UInt16.self) - self.writeInteger(record.tcp.sequenceNumber, as: UInt32.self) // seq no - self.writeInteger(record.tcp.ackNumber ?? 0, as: UInt32.self) // ack no + self.writeInteger(record.tcp.sequenceNumber, as: UInt32.self) // seq no + self.writeInteger(record.tcp.ackNumber ?? 0, as: UInt32.self) // ack no - self.writeInteger(5 << 12 | UInt16(record.tcp.flags.rawValue), as: UInt16.self) // data offset + reserved bits + fancy stuff - self.writeInteger(.max /* we don't do actual window sizes */, as: UInt16.self) // window size - self.writeInteger(0xbad /* fake */, as: UInt16.self) // checksum - self.writeInteger(0, as: UInt16.self) // urgent pointer + // data offset + reserved bits + fancy stuff + self.writeInteger(5 << 12 | UInt16(record.tcp.flags.rawValue), as: UInt16.self) + + self.writeInteger(.max, as: UInt16.self) // window size, `.max` as we don't do actual window sizes + self.writeInteger(0xbad, as: UInt16.self) // checksum (a fake one) + self.writeInteger(0, as: UInt16.self) // urgent pointer } } @@ -606,27 +668,27 @@ extension NIOWritePCAPHandler { private let workQueue: DispatchQueue private let writesGroup = DispatchGroup() private let errorHandler: (Swift.Error) -> Void - private var state: State = .running /* protected by `workQueue` */ - + private var state: State = .running // protected by `workQueue` + public enum FileWritingMode { case appendToExistingPCAPFile case createNewPCAPFile } - + public struct Error: Swift.Error { public var errorCode: Int - + internal enum ErrorCode: Int { case cannotOpenFileError = 1 case cannotWriteToFileError } } - + private enum State { case running case error(Swift.Error) } - + /// Creates a `SynchronizedFileSink` for writing to a `.pcap` file at `path`. /// /// Typically, after you created a `SynchronizedFileSink`, you will hand `myFileSink.write` to @@ -652,9 +714,11 @@ extension NIOWritePCAPHandler { /// you must then `syncClose` the `SynchronizedFileSink`. When `errorHandler` has been /// called, no further writes will be attempted and `errorHandler` will also not be called /// again. - public static func fileSinkWritingToFile(path: String, - fileWritingMode: FileWritingMode = .createNewPCAPFile, - errorHandler: @escaping (Swift.Error) -> Void) throws -> SynchronizedFileSink { + public static func fileSinkWritingToFile( + path: String, + fileWritingMode: FileWritingMode = .createNewPCAPFile, + errorHandler: @escaping (Swift.Error) -> Void + ) throws -> SynchronizedFileSink { let oflag: CInt = fileWritingMode == FileWritingMode.createNewPCAPFile ? (O_TRUNC | O_CREAT) : O_APPEND let fd = try path.withCString { pathPtr -> CInt in let fd = open(pathPtr, O_WRONLY | oflag, 0o600) @@ -663,21 +727,25 @@ extension NIOWritePCAPHandler { } return fd } - + if fileWritingMode == .createNewPCAPFile { let writeOk = NIOWritePCAPHandler.pcapFileHeader.withUnsafeReadableBytes { ptr in - return sysWrite(fd, ptr.baseAddress, ptr.count) == ptr.count + sysWrite(fd, ptr.baseAddress, ptr.count) == ptr.count } guard writeOk else { throw SynchronizedFileSink.Error(errorCode: Error.ErrorCode.cannotWriteToFileError.rawValue) } } - return SynchronizedFileSink(fileHandle: NIOFileHandle(descriptor: fd), - errorHandler: errorHandler) + return SynchronizedFileSink( + fileHandle: NIOFileHandle(descriptor: fd), + errorHandler: errorHandler + ) } - - private init(fileHandle: NIOFileHandle, - errorHandler: @escaping (Swift.Error) -> Void) { + + private init( + fileHandle: NIOFileHandle, + errorHandler: @escaping (Swift.Error) -> Void + ) { self.fileHandle = fileHandle self.workQueue = DispatchQueue(label: "io.swiftnio.extras.WritePCAPHandler.SynchronizedFileSink.workQueue") self.errorHandler = errorHandler @@ -710,7 +778,7 @@ extension NIOWritePCAPHandler { } } } - + public func write(buffer: ByteBuffer) { self.workQueue.async(group: self.writesGroup) { guard case .running = self.state else { diff --git a/Sources/NIOExtrasPerformanceTester/Benchmark.swift b/Sources/NIOExtrasPerformanceTester/Benchmark.swift index 20c12c39..3aedc63a 100644 --- a/Sources/NIOExtrasPerformanceTester/Benchmark.swift +++ b/Sources/NIOExtrasPerformanceTester/Benchmark.swift @@ -23,6 +23,6 @@ func measureAndPrint(desc: String, benchmark bench: B) throws { bench.tearDown() } try measureAndPrint(desc: desc) { - return try bench.run() + try bench.run() } } diff --git a/Sources/NIOExtrasPerformanceTester/HTTP1PCAPPerformanceTests.swift b/Sources/NIOExtrasPerformanceTester/HTTP1PCAPPerformanceTests.swift index c4d24f6d..c8f3ffbe 100644 --- a/Sources/NIOExtrasPerformanceTester/HTTP1PCAPPerformanceTests.swift +++ b/Sources/NIOExtrasPerformanceTester/HTTP1PCAPPerformanceTests.swift @@ -12,9 +12,9 @@ // //===----------------------------------------------------------------------===// +import Foundation import NIOCore import NIOExtras -import Foundation class HTTP1ThreadedPCapPerformanceTest: HTTP1ThreadedPerformanceTest { private class SinkHolder { @@ -22,7 +22,8 @@ class HTTP1ThreadedPCapPerformanceTest: HTTP1ThreadedPerformanceTest { func setUp() throws { let outputFile = NSTemporaryDirectory() + "/" + UUID().uuidString - self.fileSink = try NIOWritePCAPHandler.SynchronizedFileSink.fileSinkWritingToFile(path: outputFile) { error in + self.fileSink = try NIOWritePCAPHandler.SynchronizedFileSink.fileSinkWritingToFile(path: outputFile) { + error in print("ERROR: \(error)") exit(1) } @@ -36,16 +37,20 @@ class HTTP1ThreadedPCapPerformanceTest: HTTP1ThreadedPerformanceTest { init() { let sinkHolder = SinkHolder() func addPCap(channel: Channel) -> EventLoopFuture { - let pcapHandler = NIOWritePCAPHandler(mode: .client, - fileSink: sinkHolder.fileSink.write) + let pcapHandler = NIOWritePCAPHandler( + mode: .client, + fileSink: sinkHolder.fileSink.write + ) return channel.pipeline.addHandler(pcapHandler, position: .first) } self.sinkHolder = sinkHolder - super.init(numberOfRepeats: 50, - numberOfClients: System.coreCount, - requestsPerClient: 500, - extraInitialiser: { channel in return addPCap(channel: channel) }) + super.init( + numberOfRepeats: 50, + numberOfClients: System.coreCount, + requestsPerClient: 500, + extraInitialiser: { channel in addPCap(channel: channel) } + ) } private let sinkHolder: SinkHolder diff --git a/Sources/NIOExtrasPerformanceTester/HTTP1PerformanceTestFramework.swift b/Sources/NIOExtrasPerformanceTester/HTTP1PerformanceTestFramework.swift index 46d1e8c0..01bb90d4 100644 --- a/Sources/NIOExtrasPerformanceTester/HTTP1PerformanceTestFramework.swift +++ b/Sources/NIOExtrasPerformanceTester/HTTP1PerformanceTestFramework.swift @@ -13,8 +13,8 @@ //===----------------------------------------------------------------------===// import NIOCore -import NIOPosix import NIOHTTP1 +import NIOPosix // MARK: Handlers final class SimpleHTTPServer: ChannelInboundHandler { @@ -96,7 +96,7 @@ final class RepeatedRequests: ChannelInboundHandler { return reqs } - var completedFuture: EventLoopFuture { return self.isDonePromise.futureResult } + var completedFuture: EventLoopFuture { self.isDonePromise.futureResult } func errorCaught(context: ChannelHandlerContext, error: Error) { context.channel.close(promise: nil) @@ -131,10 +131,12 @@ class HTTP1ThreadedPerformanceTest: Benchmark { var group: MultiThreadedEventLoopGroup! var serverChannel: Channel! - init(numberOfRepeats: Int, - numberOfClients: Int, - requestsPerClient: Int, - extraInitialiser: @escaping (Channel) -> EventLoopFuture) { + init( + numberOfRepeats: Int, + numberOfClients: Int, + requestsPerClient: Int, + extraInitialiser: @escaping (Channel) -> EventLoopFuture + ) { self.numberOfRepeats = numberOfRepeats self.numberOfClients = numberOfClients self.requestsPerClient = requestsPerClient @@ -169,13 +171,15 @@ class HTTP1ThreadedPerformanceTest: Benchmark { requestHandlers.reserveCapacity(self.numberOfClients) var clientChannels: [Channel] = [] clientChannels.reserveCapacity(self.numberOfClients) - for _ in 0 ..< self.numberOfClients { + for _ in 0...reduce(0, streamCompletedFutures, on: streamCompletedFutures.first!.eventLoop, +) + let requestsServed = EventLoopFuture.reduce( + 0, + streamCompletedFutures, + on: streamCompletedFutures.first!.eventLoop, + + + ) reqs.append(try! requestsServed.wait()) } return reqs.reduce(0, +) / self.numberOfRepeats diff --git a/Sources/NIOExtrasPerformanceTester/HTTP1RawPerformanceTests.swift b/Sources/NIOExtrasPerformanceTester/HTTP1RawPerformanceTests.swift index af5179c5..d3541461 100644 --- a/Sources/NIOExtrasPerformanceTester/HTTP1RawPerformanceTests.swift +++ b/Sources/NIOExtrasPerformanceTester/HTTP1RawPerformanceTests.swift @@ -18,9 +18,11 @@ import NIOHTTP1 class HTTP1ThreadedRawPerformanceTest: HTTP1ThreadedPerformanceTest { init() { - super.init(numberOfRepeats: 50, - numberOfClients: System.coreCount, - requestsPerClient: 500, - extraInitialiser: { channel in return channel.eventLoop.makeSucceededFuture(()) }) + super.init( + numberOfRepeats: 50, + numberOfClients: System.coreCount, + requestsPerClient: 500, + extraInitialiser: { channel in channel.eventLoop.makeSucceededFuture(()) } + ) } } diff --git a/Sources/NIOExtrasPerformanceTester/HTTP1RollingPCAPPerformanceTests.swift b/Sources/NIOExtrasPerformanceTester/HTTP1RollingPCAPPerformanceTests.swift index fb5fcecc..cc1314a5 100644 --- a/Sources/NIOExtrasPerformanceTester/HTTP1RollingPCAPPerformanceTests.swift +++ b/Sources/NIOExtrasPerformanceTester/HTTP1RollingPCAPPerformanceTests.swift @@ -18,16 +18,22 @@ import NIOExtras class HTTP1ThreadedRollingPCapPerformanceTest: HTTP1ThreadedPerformanceTest { init() { func addRollingPCap(channel: Channel) -> EventLoopFuture { - let pcapRingBuffer = NIOPCAPRingBuffer(maximumFragments: 25, - maximumBytes: 1_000_000) - let pcapHandler = NIOWritePCAPHandler(mode: .client, - fileSink: pcapRingBuffer.addFragment) + let pcapRingBuffer = NIOPCAPRingBuffer( + maximumFragments: 25, + maximumBytes: 1_000_000 + ) + let pcapHandler = NIOWritePCAPHandler( + mode: .client, + fileSink: pcapRingBuffer.addFragment + ) return channel.pipeline.addHandler(pcapHandler, position: .first) } - super.init(numberOfRepeats: 50, - numberOfClients: System.coreCount, - requestsPerClient: 500, - extraInitialiser: { channel in return addRollingPCap(channel: channel) }) + super.init( + numberOfRepeats: 50, + numberOfClients: System.coreCount, + requestsPerClient: 500, + extraInitialiser: { channel in addRollingPCap(channel: channel) } + ) } } diff --git a/Sources/NIOExtrasPerformanceTester/Measurement.swift b/Sources/NIOExtrasPerformanceTester/Measurement.swift index 549c4813..8a249d1a 100644 --- a/Sources/NIOExtrasPerformanceTester/Measurement.swift +++ b/Sources/NIOExtrasPerformanceTester/Measurement.swift @@ -23,7 +23,7 @@ public func measure(_ fn: () throws -> Int) rethrows -> [Double] { return Double(end - start) / Double(TimeAmount.seconds(1).nanoseconds) } - _ = try measureOne(fn) /* pre-heat and throw away */ + _ = try measureOne(fn) // pre-heat and throw away var measurements = Array(repeating: 0.0, count: 10) for i in 0..<10 { measurements[i] = try measureOne(fn) @@ -32,7 +32,7 @@ public func measure(_ fn: () throws -> Int) rethrows -> [Double] { return measurements } -public func measureAndPrint(desc: String, fn: () throws -> Int) rethrows -> Void { +public func measureAndPrint(desc: String, fn: () throws -> Int) rethrows { print("measuring\(warning): \(desc): ", terminator: "") let measurements = try measure(fn) print(measurements.reduce(into: "") { $0.append("\($1), ") }) diff --git a/Sources/NIOExtrasPerformanceTester/PCAPPerformanceTest.swift b/Sources/NIOExtrasPerformanceTester/PCAPPerformanceTest.swift index 19fc5ead..c4a93b14 100644 --- a/Sources/NIOExtrasPerformanceTester/PCAPPerformanceTest.swift +++ b/Sources/NIOExtrasPerformanceTester/PCAPPerformanceTest.swift @@ -12,10 +12,10 @@ // //===----------------------------------------------------------------------===// +import Foundation import NIOCore import NIOEmbedded import NIOExtras -import Foundation final class PCAPPerformanceTest: Benchmark { let numberOfRepeats: Int @@ -33,7 +33,7 @@ final class PCAPPerformanceTest: Benchmark { func tearDown() { try! FileManager.default.removeItem(atPath: self.outputFile) } - + func run() throws -> Int { let fileSink = try NIOWritePCAPHandler.SynchronizedFileSink.fileSinkWritingToFile(path: self.outputFile) { error in @@ -49,12 +49,13 @@ final class PCAPPerformanceTest: Benchmark { _ = try! channel.finish() } - let pcapHandler = NIOWritePCAPHandler(mode: .client, - fileSink: fileSink.write) + let pcapHandler = NIOWritePCAPHandler( + mode: .client, + fileSink: fileSink.write + ) try channel.pipeline.addHandler(pcapHandler, position: .first).wait() - - for _ in 0 ..< self.numberOfRepeats { + for _ in 0.. ByteBuffer { + mutating func compress( + inputBuffer: inout ByteBuffer, + allocator: ByteBufferAllocator, + finalise: Bool + ) -> ByteBuffer { assert(isActive) let flags = finalise ? Z_FINISH : Z_SYNC_FLUSH // don't compress an empty buffer if we aren't finishing the compress @@ -101,13 +112,13 @@ public enum NIOCompression { stream.oneShotDeflate(from: &inputBuffer, to: &outputBuffer, flag: flags) return outputBuffer } - + mutating func shutdown() { assert(isActive) isActive = false deflateEnd(&stream) } - + mutating func shutdownIfActive() { if isActive { isActive = false @@ -131,8 +142,10 @@ extension z_stream { from.readWithUnsafeMutableReadableBytes { dataPtr in let typedPtr = dataPtr.baseAddress!.assumingMemoryBound(to: UInt8.self) - let typedDataPtr = UnsafeMutableBufferPointer(start: typedPtr, - count: dataPtr.count) + let typedDataPtr = UnsafeMutableBufferPointer( + start: typedPtr, + count: dataPtr.count + ) self.avail_in = UInt32(typedDataPtr.count) self.next_in = typedDataPtr.baseAddress! @@ -151,8 +164,10 @@ extension z_stream { var rc = Z_OK buffer.writeWithUnsafeMutableBytes(minimumWritableBytes: buffer.capacity) { outputPtr in - let typedOutputPtr = UnsafeMutableBufferPointer(start: outputPtr.baseAddress!.assumingMemoryBound(to: UInt8.self), - count: outputPtr.count) + let typedOutputPtr = UnsafeMutableBufferPointer( + start: outputPtr.baseAddress!.assumingMemoryBound(to: UInt8.self), + count: outputPtr.count + ) self.avail_out = UInt32(typedOutputPtr.count) self.next_out = typedOutputPtr.baseAddress! rc = deflate(&self, flag) diff --git a/Sources/NIOHTTPCompression/HTTPDecompression.swift b/Sources/NIOHTTPCompression/HTTPDecompression.swift index e173a1c7..4ff0d44b 100644 --- a/Sources/NIOHTTPCompression/HTTPDecompression.swift +++ b/Sources/NIOHTTPCompression/HTTPDecompression.swift @@ -31,9 +31,9 @@ public enum NIOHTTPDecompression { /// - warning: Setting `limit` to `.none` leaves you vulnerable to denial of service attacks. public static let none = DecompressionLimit(limit: .none) /// Limit will be set on the request body size. - public static func size(_ value: Int) -> DecompressionLimit { return DecompressionLimit(limit: .size(value)) } + public static func size(_ value: Int) -> DecompressionLimit { DecompressionLimit(limit: .size(value)) } /// Limit will be set on a ratio between compressed body size and decompressed result. - public static func ratio(_ value: Int) -> DecompressionLimit { return DecompressionLimit(limit: .ratio(value)) } + public static func ratio(_ value: Int) -> DecompressionLimit { DecompressionLimit(limit: .ratio(value)) } func exceeded(compressed: Int, decompressed: Int) -> Bool { switch self.limit { @@ -76,7 +76,7 @@ public enum NIOHTTPDecompression { public static let truncatedData = Self(.truncatedData) public var description: String { - return String(describing: self.backing) + String(describing: self.backing) } } @@ -128,7 +128,7 @@ public enum NIOHTTPDecompression { /// The state would need to be reset to continue decoding a subsequent gzip member. /// This must be done if there is more data after a gzip member, in order for the decompression to be compliant with the gzip standard (RFC 1952). static let windowBitsWithAutomaticCompressionFormatDetection: Int32 = 15 + 32 - + private let limit: NIOHTTPDecompression.DecompressionLimit private var stream = z_stream() private var inflated = 0 @@ -137,7 +137,11 @@ public enum NIOHTTPDecompression { self.limit = limit } - mutating func decompress(part: inout ByteBuffer, buffer: inout ByteBuffer, compressedLength: Int) throws -> InflateResult { + mutating func decompress( + part: inout ByteBuffer, + buffer: inout ByteBuffer, + compressedLength: Int + ) throws -> InflateResult { let result = try self.stream.inflatePart(input: &part, output: &buffer) self.inflated += result.written diff --git a/Sources/NIOHTTPCompression/HTTPRequestCompressor.swift b/Sources/NIOHTTPCompression/HTTPRequestCompressor.swift index d8b21bb9..c32715c0 100644 --- a/Sources/NIOHTTPCompression/HTTPRequestCompressor.swift +++ b/Sources/NIOHTTPCompression/HTTPRequestCompressor.swift @@ -54,7 +54,7 @@ public final class NIOHTTPRequestCompressor: ChannelOutboundHandler, RemovableCh var compressor: NIOCompression.Compressor /// pending write promise var pendingWritePromise: EventLoopPromise! - + /// Initialize a ``NIOHTTPRequestCompressor`` /// - Parameter encoding: Compression algorithm to use public init(encoding: NIOCompression.Algorithm) { @@ -83,7 +83,7 @@ public final class NIOHTTPRequestCompressor: ChannelOutboundHandler, RemovableCh /// - promise: The eventloop promise that should be notified when the operation completes public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { pendingWritePromise.futureResult.cascade(to: promise) - + let httpData = unwrapOutboundIn(data) switch httpData { case .head(let head): @@ -116,7 +116,7 @@ public final class NIOHTTPRequestCompressor: ChannelOutboundHandler, RemovableCh default: preconditionFailure("Unexpected Body") } - + case .end: switch state { case .head(let head): @@ -125,14 +125,22 @@ public final class NIOHTTPRequestCompressor: ChannelOutboundHandler, RemovableCh context.write(data, promise: pendingWritePromise) case .body(var head, var body): // have head and the whole of the body. Compress body, set content length header and write it all out, including the end - let outputBuffer = compressor.compress(inputBuffer: &body, allocator: context.channel.allocator, finalise: true) + let outputBuffer = compressor.compress( + inputBuffer: &body, + allocator: context.channel.allocator, + finalise: true + ) head.headers.replaceOrAdd(name: "Content-Length", value: outputBuffer.readableBytes.description) context.write(wrapOutboundOut(.head(head)), promise: nil) context.write(wrapOutboundOut(.body(.byteBuffer(outputBuffer))), promise: nil) context.write(data, promise: pendingWritePromise) case .partialBody(var body): // have a section of the body. Compress that section of the body and write it out along with the end - let outputBuffer = compressor.compress(inputBuffer: &body, allocator: context.channel.allocator, finalise: true) + let outputBuffer = compressor.compress( + inputBuffer: &body, + allocator: context.channel.allocator, + finalise: true + ) context.write(wrapOutboundOut(.body(.byteBuffer(outputBuffer))), promise: nil) context.write(data, promise: pendingWritePromise) default: @@ -142,7 +150,7 @@ public final class NIOHTTPRequestCompressor: ChannelOutboundHandler, RemovableCh compressor.shutdown() } } - + public func flush(context: ChannelHandlerContext) { switch state { case .head(var head): @@ -156,19 +164,27 @@ public final class NIOHTTPRequestCompressor: ChannelOutboundHandler, RemovableCh case .body(var head, var body): // Write out head with transfer-encoding set to "chunked" as we cannot set the content length // Compress and write out what we have of the the body - let outputBuffer = compressor.compress(inputBuffer: &body, allocator: context.channel.allocator, finalise: false) + let outputBuffer = compressor.compress( + inputBuffer: &body, + allocator: context.channel.allocator, + finalise: false + ) head.headers.remove(name: "Content-Length") head.headers.replaceOrAdd(name: "Transfer-Encoding", value: "chunked") context.write(wrapOutboundOut(.head(head)), promise: nil) context.write(wrapOutboundOut(.body(.byteBuffer(outputBuffer))), promise: pendingWritePromise) state = .partialBody(context.channel.allocator.buffer(capacity: 0)) - + case .partialBody(var body): // Compress and write out what we have of the body - let outputBuffer = compressor.compress(inputBuffer: &body, allocator: context.channel.allocator, finalise: false) + let outputBuffer = compressor.compress( + inputBuffer: &body, + allocator: context.channel.allocator, + finalise: false + ) context.write(wrapOutboundOut(.body(.byteBuffer(outputBuffer))), promise: pendingWritePromise) state = .partialBody(context.channel.allocator.buffer(capacity: 0)) - + default: context.flush() return diff --git a/Sources/NIOHTTPCompression/HTTPRequestDecompressor.swift b/Sources/NIOHTTPCompression/HTTPRequestDecompressor.swift index 0e6d6307..15c05089 100644 --- a/Sources/NIOHTTPCompression/HTTPRequestDecompressor.swift +++ b/Sources/NIOHTTPCompression/HTTPRequestDecompressor.swift @@ -13,8 +13,8 @@ //===----------------------------------------------------------------------===// import CNIOExtrasZlib -import NIOHTTP1 import NIOCore +import NIOHTTP1 /// Channel hander to decompress incoming HTTP data. public final class NIOHTTPRequestDecompressor: ChannelDuplexHandler, RemovableChannelHandler { @@ -49,8 +49,7 @@ public final class NIOHTTPRequestDecompressor: ChannelDuplexHandler, RemovableCh switch request { case .head(let head): - if - let encoding = head.headers[canonicalForm: "Content-Encoding"].first?.lowercased(), + if let encoding = head.headers[canonicalForm: "Content-Encoding"].first?.lowercased(), let algorithm = NIOHTTPDecompression.CompressionAlgorithm(header: encoding), let length = head.headers[canonicalForm: "Content-Length"].first.flatMap({ Int($0) }) { @@ -73,7 +72,11 @@ public final class NIOHTTPRequestDecompressor: ChannelDuplexHandler, RemovableCh while part.readableBytes > 0 && !self.decompressionComplete { do { var buffer = context.channel.allocator.buffer(capacity: 16384) - let result = try self.decompressor.decompress(part: &part, buffer: &buffer, compressedLength: compression.contentLength) + let result = try self.decompressor.decompress( + part: &part, + buffer: &buffer, + compressedLength: compression.contentLength + ) if result.complete { self.decompressionComplete = true } diff --git a/Sources/NIOHTTPCompression/HTTPResponseCompressor.swift b/Sources/NIOHTTPCompression/HTTPResponseCompressor.swift index ebdd9e1a..9bfa05cf 100644 --- a/Sources/NIOHTTPCompression/HTTPResponseCompressor.swift +++ b/Sources/NIOHTTPCompression/HTTPResponseCompressor.swift @@ -25,11 +25,10 @@ extension StringProtocol { /// - prefix: The string to match at the beginning of `self` /// - returns: Whether or not `self` starts with the same unicode scalars as `prefix`. func startsWithExactly(_ prefix: S) -> Bool { - return self.utf8.starts(with: prefix.utf8) + self.utf8.starts(with: prefix.utf8) } } - /// Given a header value, extracts the q value if there is one present. If one is not present, /// returns the default q value, 1.0. private func qValueFromHeader(_ text: S) -> Float { @@ -73,7 +72,7 @@ public final class HTTPResponseCompressor: ChannelDuplexHandler, RemovableChanne public typealias OutboundIn = HTTPServerResponsePart /// This class emits `HTTPServerResponsePart` outbound. public typealias OutboundOut = HTTPServerResponsePart - + /// A closure that accepts a response header, optionally modifies it, and returns `true` if the response it belongs to should be compressed. /// /// - Parameter responseHeaders: The headers that will be used for the response. These can be modified as needed at this stage, to clean up any marker headers used to statelessly determine if compression should occur, and the new headers will be used when writing the response. Compression headers are not yet provided and should not be set; ``HTTPResponseCompressor`` will set them accordingly based on the result of this predicate. @@ -85,7 +84,7 @@ public final class HTTPResponseCompressor: ChannelDuplexHandler, RemovableChanne _ responseHeaders: inout HTTPResponseHead, _ isCompressionSupported: Bool ) -> CompressionIntent - + /// A signal a ``ResponseCompressionPredicate`` returns to indicate if it intends for compression to be used or not when supported by HTTP. public struct CompressionIntent: Sendable, Hashable { /// The internal type ``CompressionIntent`` uses. @@ -95,15 +94,15 @@ public final class HTTPResponseCompressor: ChannelDuplexHandler, RemovableChanne /// The response should not be compressed even if supported by the HTTP protocol. case doNotCompress } - + /// The raw value of the intent. let rawValue: RawValue - + /// Initialize the raw value with an internal intent. init(_ rawValue: RawValue) { self.rawValue = rawValue } - + /// The response should be compressed if supported by the HTTP protocol. public static let compressIfPossible = CompressionIntent(.compressIfPossible) /// The response should not be compressed even if supported by the HTTP protocol. @@ -117,7 +116,7 @@ public final class HTTPResponseCompressor: ChannelDuplexHandler, RemovableChanne /// Data was somehow lost without being written. case noDataToWrite } - + private var compressor: NIOCompression.Compressor // A queue of accept headers. @@ -135,11 +134,14 @@ public final class HTTPResponseCompressor: ChannelDuplexHandler, RemovableChanne // TODO: This version is kept around for backwards compatibility and should be merged with the signature below in the next major version: https://github.com/apple/swift-nio-extras/issues/226 self.init(initialByteBufferCapacity: initialByteBufferCapacity, responseCompressionPredicate: nil) } - + /// Initialize a ``HTTPResponseCompressor``. /// - Parameter initialByteBufferCapacity: Initial size of buffer to allocate when hander is first added. /// - Parameter responseCompressionPredicate: The predicate used to determine if the response should be compressed or not based on its headers. Defaults to `nil`, which will compress every response this handler sees. This predicate is always called whether the client supports compression for this response or not, so it can be used to clean up any marker headers you may use to determine if compression should be performed or not. Please see ``ResponseCompressionPredicate`` for more details. - public init(initialByteBufferCapacity: Int = 1024, responseCompressionPredicate: ResponseCompressionPredicate? = nil) { + public init( + initialByteBufferCapacity: Int = 1024, + responseCompressionPredicate: ResponseCompressionPredicate? = nil + ) { self.initialByteBufferCapacity = initialByteBufferCapacity self.responseCompressionPredicate = responseCompressionPredicate self.compressor = NIOCompression.Compressor() @@ -148,7 +150,9 @@ public final class HTTPResponseCompressor: ChannelDuplexHandler, RemovableChanne /// Setup and add to the pipeline. /// - Parameter context: Calling context. public func handlerAdded(context: ChannelHandlerContext) { - pendingResponse = PartialHTTPResponse(bodyBuffer: context.channel.allocator.buffer(capacity: initialByteBufferCapacity)) + pendingResponse = PartialHTTPResponse( + bodyBuffer: context.channel.allocator.buffer(capacity: initialByteBufferCapacity) + ) pendingWritePromise = context.eventLoop.makePromise() } @@ -174,19 +178,20 @@ public final class HTTPResponseCompressor: ChannelDuplexHandler, RemovableChanne /// Grab the algorithm to use from the bottom of the accept queue, which will help determine if we support compression for this response or not. let algorithm = compressionAlgorithm() let requestSupportsCompression = algorithm != nil && responseHead.status.mayHaveResponseBody - + /// If a predicate was set, ask it if we should compress when compression is supported, and give the predicate a chance to clean up any marker headers that may have been set even if compression were not supported. - let predicateCompressionIntent = responseCompressionPredicate?(&responseHead, requestSupportsCompression) ?? .compressIfPossible - + let predicateCompressionIntent = + responseCompressionPredicate?(&responseHead, requestSupportsCompression) ?? .compressIfPossible + /// Make sure that compression should proceed, otherwise stop here and supply the response headers before configuring the compressor. guard let algorithm, requestSupportsCompression, predicateCompressionIntent == .compressIfPossible else { context.write(wrapOutboundOut(.head(responseHead)), promise: promise) return } - + /// Previous handlers in the pipeline might have already set this header even though they should not have as it is compressor responsibility to decide what encoding to use. responseHead.headers.replaceOrAdd(name: "Content-Encoding", value: algorithm.description) - + /// Initialize the compressor and write the header data, which marks the compressor as "active" allowing the `.body` and `.end` cases to properly compress the response rather than passing it as is. compressor.initialize(encoding: algorithm) pendingResponse.bufferResponseHead(responseHead) @@ -303,11 +308,11 @@ private struct PartialHTTPResponse { private let initialBufferSize: Int var isCompleteResponse: Bool { - return head != nil && end != nil + head != nil && end != nil } var mustFlush: Bool { - return end != nil + end != nil } init(bodyBuffer: ByteBuffer) { @@ -357,15 +362,17 @@ private struct PartialHTTPResponse { /// /// Calling this function resets the buffer, freeing any excess memory allocated in the internal /// buffer and losing all copies of the other HTTP data. At this point it may freely be reused. - mutating func flush(compressor: inout NIOCompression.Compressor, allocator: ByteBufferAllocator) -> (HTTPResponseHead?, ByteBuffer?, HTTPServerResponsePart?) { + mutating func flush( + compressor: inout NIOCompression.Compressor, + allocator: ByteBufferAllocator + ) -> (HTTPResponseHead?, ByteBuffer?, HTTPServerResponsePart?) { var outputBody: ByteBuffer? = nil if self.body.readableBytes > 0 || mustFlush { let compressedBody = compressor.compress(inputBuffer: &self.body, allocator: allocator, finalise: mustFlush) if isCompleteResponse { head!.headers.remove(name: "transfer-encoding") head!.headers.replaceOrAdd(name: "content-length", value: "\(compressedBody.readableBytes)") - } - else if head != nil && head!.status.mayHaveResponseBody { + } else if head != nil && head!.status.mayHaveResponseBody { head!.headers.remove(name: "content-length") head!.headers.replaceOrAdd(name: "transfer-encoding", value: "chunked") } @@ -377,4 +384,3 @@ private struct PartialHTTPResponse { return response } } - diff --git a/Sources/NIOHTTPCompression/HTTPResponseDecompressor.swift b/Sources/NIOHTTPCompression/HTTPResponseDecompressor.swift index e4f40178..cfc9815b 100644 --- a/Sources/NIOHTTPCompression/HTTPResponseDecompressor.swift +++ b/Sources/NIOHTTPCompression/HTTPResponseDecompressor.swift @@ -28,10 +28,10 @@ public final class NIOHTTPResponseDecompressor: ChannelDuplexHandler, RemovableC /// this struct encapsulates the state of a single http response decompression private struct Compression { - + /// the used algorithm var algorithm: NIOHTTPDecompression.CompressionAlgorithm - + /// the number of already consumed compressed bytes var compressedLength: Int } @@ -73,7 +73,7 @@ public final class NIOHTTPResponseDecompressor: ChannelDuplexHandler, RemovableC self.compression = Compression(algorithm: algorithm, compressedLength: 0) try self.decompressor.initializeDecoder() } - + context.fireChannelRead(data) } catch { context.fireErrorCaught(error) @@ -83,26 +83,29 @@ public final class NIOHTTPResponseDecompressor: ChannelDuplexHandler, RemovableC context.fireChannelRead(data) return } - + do { compression.compressedLength += part.readableBytes while part.readableBytes > 0 && !self.decompressionComplete { var buffer = context.channel.allocator.buffer(capacity: 16384) - let result = try self.decompressor.decompress(part: &part, buffer: &buffer, compressedLength: compression.compressedLength) + let result = try self.decompressor.decompress( + part: &part, + buffer: &buffer, + compressedLength: compression.compressedLength + ) if result.complete { self.decompressionComplete = true } context.fireChannelRead(self.wrapInboundOut(.body(buffer))) } - + // assign the changed local property back to the class state self.compression = compression if part.readableBytes > 0 { context.fireErrorCaught(NIOHTTPDecompression.ExtraDecompressionError.invalidTrailingData) } - } - catch { + } catch { context.fireErrorCaught(error) } case .end: diff --git a/Sources/NIOHTTPTypesHTTP1/HTTPTypeConversion.swift b/Sources/NIOHTTPTypesHTTP1/HTTPTypeConversion.swift index f1047746..8bad494d 100644 --- a/Sources/NIOHTTPTypesHTTP1/HTTPTypeConversion.swift +++ b/Sources/NIOHTTPTypesHTTP1/HTTPTypeConversion.swift @@ -27,11 +27,11 @@ public struct HTTP1TypeConversionError: Error, Equatable { } /// Failed to create HTTPRequest.Method from HTTPMethod - public static var invalidMethod: Self { .init(.invalidMethod)} + public static var invalidMethod: Self { .init(.invalidMethod) } /// Failed to extract a path from HTTPRequest - public static var missingPath: Self { .init(.missingPath)} + public static var missingPath: Self { .init(.missingPath) } /// HTTPResponseHead had an invalid status code - public static var invalidStatusCode: Self { .init(.invalidStatusCode)} + public static var invalidStatusCode: Self { .init(.invalidStatusCode) } } extension HTTPMethod { @@ -117,7 +117,7 @@ extension HTTPRequest.Method { case .MKACTIVITY: self = .init("MKACTIVITY")! case .UNSUBSCRIBE: self = .init("UNSUBSCRIBE")! case .SOURCE: self = .init("SOURCE")! - case .RAW(value: let value): + case .RAW(let value): guard let method = HTTPRequest.Method(value) else { throw HTTP1TypeConversionError.invalidMethod } @@ -145,9 +145,11 @@ extension HTTPFields { } if let name = HTTPField.Name(field.name) { if splitCookie && name == .cookie, #available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *) { - self.append(contentsOf: field.value.split(separator: "; ", omittingEmptySubsequences: false).map { - HTTPField(name: name, value: String($0)) - }) + self.append( + contentsOf: field.value.split(separator: "; ", omittingEmptySubsequences: false).map { + HTTPField(name: name, value: String($0)) + } + ) } else { self.append(HTTPField(name: name, value: field.value)) } @@ -219,7 +221,10 @@ extension HTTPResponse { guard oldResponse.status.code <= 999 else { throw HTTP1TypeConversionError.invalidStatusCode } - let status = HTTPResponse.Status(code: Int(oldResponse.status.code), reasonPhrase: oldResponse.status.reasonPhrase) + let status = HTTPResponse.Status( + code: Int(oldResponse.status.code), + reasonPhrase: oldResponse.status.reasonPhrase + ) self.init(status: status, headerFields: HTTPFields(oldResponse.headers, splitCookie: false)) } } diff --git a/Sources/NIOHTTPTypesHTTP2/HTTP2HeadersStateMachine.swift b/Sources/NIOHTTPTypesHTTP2/HTTP2HeadersStateMachine.swift index 64a88a98..0246b93c 100644 --- a/Sources/NIOHTTPTypesHTTP2/HTTP2HeadersStateMachine.swift +++ b/Sources/NIOHTTPTypesHTTP2/HTTP2HeadersStateMachine.swift @@ -91,23 +91,23 @@ struct HTTP2HeadersStateMachine { // The first header block received on a server mode stream must be a request block. newType = .requestHead case (.client, .none), - (.client, .some(.informationalResponseHead)): + (.client, .some(.informationalResponseHead)): // The first header block received on a client mode stream may be either informational or final, // depending on the value of the :status pseudo-header. Alternatively, if the previous // header block was informational, the same possibilities apply. newType = try block.isInformationalResponse() ? .informationalResponseHead : .finalResponseHead case (.server, .some(.requestHead)), - (.client, .some(.finalResponseHead)): + (.client, .some(.finalResponseHead)): // If the server has already received a request head, or the client has already received a final response, // this is a trailer block. newType = .trailer case (.server, .some(.informationalResponseHead)), - (.server, .some(.finalResponseHead)), - (.client, .some(.requestHead)): + (.server, .some(.finalResponseHead)), + (.client, .some(.requestHead)): // These states should not be reachable! preconditionFailure("Invalid internal state!") case (.server, .some(.trailer)), - (.client, .some(.trailer)): + (.client, .some(.trailer)): // TODO(cory): This should probably throw, as this can happen in malformed programs without the world ending. preconditionFailure("Sending too many header blocks.") } diff --git a/Sources/NIOHTTPTypesHTTP2/HTTP2ToHTTPCodec.swift b/Sources/NIOHTTPTypesHTTP2/HTTP2ToHTTPCodec.swift index 58e44968..10c81adb 100644 --- a/Sources/NIOHTTPTypesHTTP2/HTTP2ToHTTPCodec.swift +++ b/Sources/NIOHTTPTypesHTTP2/HTTP2ToHTTPCodec.swift @@ -25,7 +25,9 @@ private struct BaseClientCodec { private var outgoingHTTP1RequestHead: HTTPRequest? - mutating func processInboundData(_ data: HTTP2Frame.FramePayload) throws -> (first: HTTPResponsePart?, second: HTTPResponsePart?) { + mutating func processInboundData( + _ data: HTTP2Frame.FramePayload + ) throws -> (first: HTTPResponsePart?, second: HTTPResponsePart?) { switch data { case .headers(let headerContent): switch try self.headerStateMachine.newHeaders(block: headerContent.headers) { @@ -68,13 +70,17 @@ private struct BaseClientCodec { } } return (first: first, second: second) - case .alternativeService, .rstStream, .priority, .windowUpdate, .settings, .pushPromise, .ping, .goAway, .origin: + case .alternativeService, .rstStream, .priority, .windowUpdate, .settings, .pushPromise, .ping, .goAway, + .origin: // These are not meaningful in HTTP messaging, so drop them. return (first: nil, second: nil) } } - mutating func processOutboundData(_ data: HTTPRequestPart, allocator: ByteBufferAllocator) throws -> HTTP2Frame.FramePayload { + mutating func processOutboundData( + _ data: HTTPRequestPart, + allocator: ByteBufferAllocator + ) throws -> HTTP2Frame.FramePayload { switch data { case .head(let head): precondition(self.outgoingHTTP1RequestHead == nil, "Only a single HTTP request allowed per HTTP2 stream") @@ -85,10 +91,12 @@ private struct BaseClientCodec { return .data(HTTP2Frame.FramePayload.Data(data: .byteBuffer(body))) case .end(let trailers): if let trailers { - return .headers(.init( - headers: HPACKHeaders(trailers), - endStream: true - )) + return .headers( + .init( + headers: HPACKHeaders(trailers), + endStream: true + ) + ) } else { return .data(.init(data: .byteBuffer(allocator.buffer(capacity: 0)), endStream: true)) } @@ -134,7 +142,10 @@ public final class HTTP2FramePayloadToHTTPClientCodec: ChannelDuplexHandler, Rem let requestPart = self.unwrapOutboundIn(data) do { - let transformedPayload = try self.baseCodec.processOutboundData(requestPart, allocator: context.channel.allocator) + let transformedPayload = try self.baseCodec.processOutboundData( + requestPart, + allocator: context.channel.allocator + ) context.write(self.wrapOutboundOut(transformedPayload), promise: promise) } catch { promise?.fail(error) @@ -148,7 +159,9 @@ public final class HTTP2FramePayloadToHTTPClientCodec: ChannelDuplexHandler, Rem private struct BaseServerCodec { private var headerStateMachine: HTTP2HeadersStateMachine = .init(mode: .server) - mutating func processInboundData(_ data: HTTP2Frame.FramePayload) throws -> (first: HTTPRequestPart?, second: HTTPRequestPart?) { + mutating func processInboundData( + _ data: HTTP2Frame.FramePayload + ) throws -> (first: HTTPRequestPart?, second: HTTPRequestPart?) { switch data { case .headers(let headerContent): if case .trailer = try self.headerStateMachine.newHeaders(block: headerContent.headers) { @@ -183,7 +196,10 @@ private struct BaseServerCodec { } } - mutating func processOutboundData(_ data: HTTPResponsePart, allocator: ByteBufferAllocator) -> HTTP2Frame.FramePayload { + mutating func processOutboundData( + _ data: HTTPResponsePart, + allocator: ByteBufferAllocator + ) -> HTTP2Frame.FramePayload { switch data { case .head(let head): let payload = HTTP2Frame.FramePayload.Headers(headers: HPACKHeaders(head)) @@ -193,10 +209,12 @@ private struct BaseServerCodec { return .data(payload) case .end(let trailers): if let trailers { - return .headers(.init( - headers: HPACKHeaders(trailers), - endStream: true - )) + return .headers( + .init( + headers: HPACKHeaders(trailers), + endStream: true + ) + ) } else { return .data(.init(data: .byteBuffer(allocator.buffer(capacity: 0)), endStream: true)) } diff --git a/Sources/NIOHTTPTypesHTTP2/HTTPTypeConversion.swift b/Sources/NIOHTTPTypesHTTP2/HTTPTypeConversion.swift index 202facb6..e0f3777f 100644 --- a/Sources/NIOHTTPTypesHTTP2/HTTPTypeConversion.swift +++ b/Sources/NIOHTTPTypesHTTP2/HTTPTypeConversion.swift @@ -32,8 +32,8 @@ private enum HTTP2TypeConversionError: Error { case pseudoFieldInTrailers } -private extension HPACKIndexing { - init(_ newIndexingStrategy: HTTPField.DynamicTableIndexingStrategy) { +extension HPACKIndexing { + fileprivate init(_ newIndexingStrategy: HTTPField.DynamicTableIndexingStrategy) { switch newIndexingStrategy { case .avoid: self = .nonIndexable case .disallow: self = .neverIndexed @@ -42,8 +42,8 @@ private extension HPACKIndexing { } } -private extension HTTPField.DynamicTableIndexingStrategy { - init(_ oldIndexing: HPACKIndexing) { +extension HTTPField.DynamicTableIndexingStrategy { + fileprivate init(_ oldIndexing: HPACKIndexing) { switch oldIndexing { case .indexable: self = .automatic case .nonIndexable: self = .avoid @@ -221,7 +221,8 @@ extension HTTPResponse { throw HTTP2TypeConversionError.missingStatus } guard let status = Int(statusString), - (0 ... 999).contains(status) else { + (0...999).contains(status) + else { throw HTTP2TypeConversionError.invalidStatus } diff --git a/Sources/NIONFS3/MountTypes+Mount.swift b/Sources/NIONFS3/MountTypes+Mount.swift index 2ada4c70..08a3d093 100644 --- a/Sources/NIONFS3/MountTypes+Mount.swift +++ b/Sources/NIONFS3/MountTypes+Mount.swift @@ -57,8 +57,10 @@ extension ByteBuffer { switch reply.result { case .okay(let reply): bytesWritten += self.writeNFS3FileHandle(reply.fileHandle) - precondition(reply.authFlavors == [.unix] || reply.authFlavors == [.noAuth], - "Sorry, anything but [.unix] / [.system] / [.noAuth] unimplemented.") + precondition( + reply.authFlavors == [.unix] || reply.authFlavors == [.noAuth], + "Sorry, anything but [.unix] / [.system] / [.noAuth] unimplemented." + ) bytesWritten += self.writeInteger(UInt32(reply.authFlavors.count), as: UInt32.self) for flavor in reply.authFlavors { bytesWritten += self.writeInteger(flavor.rawValue, as: UInt32.self) @@ -71,15 +73,17 @@ extension ByteBuffer { } public mutating func readNFS3ReplyMount() throws -> MountReplyMount { - let result = try self.readNFS3Result(readOkay: { buffer -> MountReplyMount.Okay in - let fileHandle = try buffer.readNFS3FileHandle() - let authFlavors = try buffer.readNFS3List(readEntry: { buffer in - try buffer.readRPCAuthFlavor() - }) - return MountReplyMount.Okay(fileHandle: fileHandle, authFlavors: authFlavors) + let result = try self.readNFS3Result( + readOkay: { buffer -> MountReplyMount.Okay in + let fileHandle = try buffer.readNFS3FileHandle() + let authFlavors = try buffer.readNFS3List(readEntry: { buffer in + try buffer.readRPCAuthFlavor() + }) + return MountReplyMount.Okay(fileHandle: fileHandle, authFlavors: authFlavors) - }, - readFail: { _ in NFS3Nothing() }) + }, + readFail: { _ in NFS3Nothing() } + ) return MountReplyMount(result: result) } } diff --git a/Sources/NIONFS3/MountTypes+Null.swift b/Sources/NIONFS3/MountTypes+Null.swift index 8dc5f98d..cefdb546 100644 --- a/Sources/NIONFS3/MountTypes+Null.swift +++ b/Sources/NIONFS3/MountTypes+Null.swift @@ -21,10 +21,10 @@ public struct MountCallNull: Hashable & Sendable { extension ByteBuffer { public mutating func readMountCallNull() throws -> MountCallNull { - return MountCallNull() + MountCallNull() } @discardableResult public mutating func writeMountCallNull(_ call: MountCallNull) -> Int { - return 0 + 0 } } diff --git a/Sources/NIONFS3/MountTypes+Unmount.swift b/Sources/NIONFS3/MountTypes+Unmount.swift index d1122dd4..169d954c 100644 --- a/Sources/NIONFS3/MountTypes+Unmount.swift +++ b/Sources/NIONFS3/MountTypes+Unmount.swift @@ -38,10 +38,10 @@ extension ByteBuffer { } @discardableResult public mutating func writeNFS3ReplyUnmount(_ reply: MountReplyUnmount) -> Int { - return 0 + 0 } public mutating func readNFS3ReplyUnmount() throws -> MountReplyUnmount { - return MountReplyUnmount() + MountReplyUnmount() } } diff --git a/Sources/NIONFS3/NFSCallDecoder.swift b/Sources/NIONFS3/NFSCallDecoder.swift index 25616655..ba9dbab4 100644 --- a/Sources/NIONFS3/NFSCallDecoder.swift +++ b/Sources/NIONFS3/NFSCallDecoder.swift @@ -32,6 +32,6 @@ public struct NFS3CallDecoder: NIOSingleStepByteToMessageDecoder { } public mutating func decodeLast(buffer: inout ByteBuffer, seenEOF: Bool) throws -> RPCNFS3Call? { - return try self.decode(buffer: &buffer) + try self.decode(buffer: &buffer) } } diff --git a/Sources/NIONFS3/NFSFileSystem+FuturesAPI.swift b/Sources/NIONFS3/NFSFileSystem+FuturesAPI.swift index 7929e6e4..949255b2 100644 --- a/Sources/NIONFS3/NFSFileSystem+FuturesAPI.swift +++ b/Sources/NIONFS3/NFSFileSystem+FuturesAPI.swift @@ -111,7 +111,8 @@ extension NFS3FileSystemNoAuth { return promise.futureResult } - public func readdirplus(_ call: NFS3CallReadDirPlus, eventLoop: EventLoop) -> EventLoopFuture { + public func readdirplus(_ call: NFS3CallReadDirPlus, eventLoop: EventLoop) -> EventLoopFuture + { let promise = eventLoop.makePromise(of: NFS3ReplyReadDirPlus.self) if eventLoop.inEventLoop { self.readdirplus(call, promise: promise) diff --git a/Sources/NIONFS3/NFSFileSystem.swift b/Sources/NIONFS3/NFSFileSystem.swift index 3b07648b..67a10c95 100644 --- a/Sources/NIONFS3/NFSFileSystem.swift +++ b/Sources/NIONFS3/NFSFileSystem.swift @@ -35,28 +35,49 @@ public protocol NFS3FileSystemNoAuth { extension NFS3FileSystemNoAuth { public func readdir(_ call: NFS3CallReadDir, promise originalPromise: EventLoopPromise) { let promise = originalPromise.futureResult.eventLoop.makePromise(of: NFS3ReplyReadDirPlus.self) - self.readdirplus(NFS3CallReadDirPlus(fileHandle: call.fileHandle, - cookie: call.cookie, - cookieVerifier: call.cookieVerifier, - dirCount: NFS3Count(rawValue: .max), - maxCount: call.maxResultByteCount), + self.readdirplus( + NFS3CallReadDirPlus( + fileHandle: call.fileHandle, + cookie: call.cookie, + cookieVerifier: call.cookieVerifier, + dirCount: NFS3Count(rawValue: .max), + maxCount: call.maxResultByteCount + ), - promise: promise) + promise: promise + ) promise.futureResult.whenComplete { readDirPlusResult in switch readDirPlusResult { case .success(let readDirPlusSuccessResult): switch readDirPlusSuccessResult.result { case .okay(let readDirPlusOkay): - originalPromise.succeed(NFS3ReplyReadDir(result: .okay(.init(cookieVerifier: readDirPlusOkay.cookieVerifier, - entries: readDirPlusOkay.entries.map { readDirPlusEntry in - NFS3ReplyReadDir.Entry(fileID: readDirPlusEntry.fileID, - fileName: readDirPlusEntry.fileName, - cookie: readDirPlusEntry.cookie) - }, eof: readDirPlusOkay.eof)))) + originalPromise.succeed( + NFS3ReplyReadDir( + result: .okay( + .init( + cookieVerifier: readDirPlusOkay.cookieVerifier, + entries: readDirPlusOkay.entries.map { readDirPlusEntry in + NFS3ReplyReadDir.Entry( + fileID: readDirPlusEntry.fileID, + fileName: readDirPlusEntry.fileName, + cookie: readDirPlusEntry.cookie + ) + }, + eof: readDirPlusOkay.eof + ) + ) + ) + ) case .fail(let nfsStatus, let readDirPlusFailure): - originalPromise.succeed(NFS3ReplyReadDir(result: .fail(nfsStatus, - .init(dirAttributes: readDirPlusFailure.dirAttributes)))) + originalPromise.succeed( + NFS3ReplyReadDir( + result: .fail( + nfsStatus, + .init(dirAttributes: readDirPlusFailure.dirAttributes) + ) + ) + ) } case .failure(let error): diff --git a/Sources/NIONFS3/NFSFileSystemHandler.swift b/Sources/NIONFS3/NFSFileSystemHandler.swift index cc2739ee..c49ad0d0 100644 --- a/Sources/NIONFS3/NFSFileSystemHandler.swift +++ b/Sources/NIONFS3/NFSFileSystemHandler.swift @@ -25,15 +25,22 @@ import NIOCore /// because NFS3 tranditionally just trusts the UNIX uid/gid that the client provided. So there's no security value /// added by verifying them. However, the client may rely on the server to check the UNIX permissions (whilst trusting /// the uid/gid) which cannot be done with this handler. -public final class NFS3FileSystemNoAuthHandler: ChannelDuplexHandler, NFS3FileSystemResponder { +public final class NFS3FileSystemNoAuthHandler: ChannelDuplexHandler, NFS3FileSystemResponder +{ public typealias OutboundIn = Never public typealias InboundIn = RPCNFS3Call public typealias OutboundOut = RPCNFS3Reply private let filesystem: FS - private let rpcReplySuccess: RPCReplyStatus = .messageAccepted(.init(verifier: .init(flavor: .noAuth, - opaque: nil), - status: .success)) + private let rpcReplySuccess: RPCReplyStatus = .messageAccepted( + .init( + verifier: .init( + flavor: .noAuth, + opaque: nil + ), + status: .success + ) + ) private var invoker: NFS3FileSystemInvoker>? private var context: ChannelHandlerContext? = nil @@ -53,25 +60,38 @@ public final class NFS3FileSystemNoAuthHandler: Channe func sendSuccessfulReply(_ reply: NFS3Reply, call: RPCNFS3Call) { if let context = self.context { - let reply = RPCNFS3Reply(rpcReply: .init(xid: call.rpcCall.xid, - status: self.rpcReplySuccess), - nfsReply: reply) + let reply = RPCNFS3Reply( + rpcReply: .init( + xid: call.rpcCall.xid, + status: self.rpcReplySuccess + ), + nfsReply: reply + ) context.writeAndFlush(self.wrapOutboundOut(reply), promise: nil) } } func sendError(_ error: Error, call: RPCNFS3Call) { if let context = self.context { - let nfsErrorReply = RPCNFS3Reply(rpcReply: .init(xid: call.rpcCall.xid, - status: self.rpcReplySuccess), - nfsReply: .mount(.init(result: .fail(.errorSERVERFAULT, - NFS3Nothing())))) + let nfsErrorReply = RPCNFS3Reply( + rpcReply: .init( + xid: call.rpcCall.xid, + status: self.rpcReplySuccess + ), + nfsReply: .mount( + .init( + result: .fail( + .errorSERVERFAULT, + NFS3Nothing() + ) + ) + ) + ) context.writeAndFlush(self.wrapOutboundOut(nfsErrorReply), promise: nil) context.fireErrorCaught(error) } } - public func channelRead(context: ChannelHandlerContext, data: NIOAny) { let call = self.unwrapInboundIn(data) // ! is safe here because it's set on `handlerAdded` (and unset in `handlerRemoved`). Calling this outside that @@ -82,10 +102,14 @@ public final class NFS3FileSystemNoAuthHandler: Channe public func errorCaught(context: ChannelHandlerContext, error: Error) { switch error as? NFS3Error { case .unknownProgramOrProcedure(.call(let call)): - let acceptedReply = RPCAcceptedReply(verifier: RPCOpaqueAuth(flavor: .noAuth, opaque: nil), - status: .procedureUnavailable) - let reply = RPCNFS3Reply(rpcReply: RPCReply(xid: call.xid, status: .messageAccepted(acceptedReply)), - nfsReply: .null) + let acceptedReply = RPCAcceptedReply( + verifier: RPCOpaqueAuth(flavor: .noAuth, opaque: nil), + status: .procedureUnavailable + ) + let reply = RPCNFS3Reply( + rpcReply: RPCReply(xid: call.xid, status: .messageAccepted(acceptedReply)), + nfsReply: .null + ) context.writeAndFlush(self.wrapOutboundOut(reply), promise: nil) return default: diff --git a/Sources/NIONFS3/NFSFileSystemInvoker.swift b/Sources/NIONFS3/NFSFileSystemInvoker.swift index f3cba7df..c117efa5 100644 --- a/Sources/NIONFS3/NFSFileSystemInvoker.swift +++ b/Sources/NIONFS3/NFSFileSystemInvoker.swift @@ -31,7 +31,7 @@ internal struct NFS3FileSystemInvoker EventLoopFuture { - return self.fs.shutdown(eventLoop: self.eventLoop) + self.fs.shutdown(eventLoop: self.eventLoop) } func handleNFS3Call(_ callMessage: RPCNFS3Call) { diff --git a/Sources/NIONFS3/NFSFileSystemServerHandler.swift b/Sources/NIONFS3/NFSFileSystemServerHandler.swift index 97b3f923..88c74552 100644 --- a/Sources/NIONFS3/NFSFileSystemServerHandler.swift +++ b/Sources/NIONFS3/NFSFileSystemServerHandler.swift @@ -19,12 +19,20 @@ public final class NFS3FileSystemServerHandler { public typealias OutboundOut = ByteBuffer private var error: Error? = nil - private var b2md = NIOSingleStepByteToMessageProcessor(NFS3CallDecoder(), - maximumBufferSize: 4 * 1024 * 1024) + private var b2md = NIOSingleStepByteToMessageProcessor( + NFS3CallDecoder(), + maximumBufferSize: 4 * 1024 * 1024 + ) private let filesystem: FS - private let rpcReplySuccess: RPCReplyStatus = .messageAccepted(.init(verifier: .init(flavor: .noAuth, - opaque: nil), - status: .success)) + private let rpcReplySuccess: RPCReplyStatus = .messageAccepted( + .init( + verifier: .init( + flavor: .noAuth, + opaque: nil + ), + status: .success + ) + ) private var invoker: NFS3FileSystemInvoker>? private var context: ChannelHandlerContext? = nil private var writeBuffer = ByteBuffer() @@ -49,8 +57,12 @@ extension NFS3FileSystemServerHandler: ChannelInboundHandler { public func channelRead(context: ChannelHandlerContext, data: NIOAny) { let data = self.unwrapInboundIn(data) guard self.error == nil else { - context.fireErrorCaught(ByteToMessageDecoderError.dataReceivedInErrorState(self.error!, - data)) + context.fireErrorCaught( + ByteToMessageDecoderError.dataReceivedInErrorState( + self.error!, + data + ) + ) return } @@ -68,10 +80,14 @@ extension NFS3FileSystemServerHandler: ChannelInboundHandler { public func errorCaught(context: ChannelHandlerContext, error: Error) { switch error as? NFS3Error { case .unknownProgramOrProcedure(.call(let call)): - let acceptedReply = RPCAcceptedReply(verifier: .init(flavor: .noAuth, opaque: nil), - status: .procedureUnavailable) - let reply = RPCNFS3Reply(rpcReply: RPCReply(xid: call.xid, status: .messageAccepted(acceptedReply)), - nfsReply: .null) + let acceptedReply = RPCAcceptedReply( + verifier: .init(flavor: .noAuth, opaque: nil), + status: .procedureUnavailable + ) + let reply = RPCNFS3Reply( + rpcReply: RPCReply(xid: call.xid, status: .messageAccepted(acceptedReply)), + nfsReply: .null + ) self.writeBuffer.clear() self.writeBuffer.writeRPCNFS3Reply(reply) return @@ -85,9 +101,13 @@ extension NFS3FileSystemServerHandler: ChannelInboundHandler { extension NFS3FileSystemServerHandler: NFS3FileSystemResponder { func sendSuccessfulReply(_ reply: NFS3Reply, call: RPCNFS3Call) { if let context = self.context { - let reply = RPCNFS3Reply(rpcReply: .init(xid: call.rpcCall.xid, - status: self.rpcReplySuccess), - nfsReply: reply) + let reply = RPCNFS3Reply( + rpcReply: .init( + xid: call.rpcCall.xid, + status: self.rpcReplySuccess + ), + nfsReply: reply + ) self.writeBuffer.clear() switch self.writeBuffer.writeRPCNFS3ReplyPartially(reply).1 { @@ -107,10 +127,20 @@ extension NFS3FileSystemServerHandler: NFS3FileSystemResponder { func sendError(_ error: Error, call: RPCNFS3Call) { if let context = self.context { - let reply = RPCNFS3Reply(rpcReply: .init(xid: call.rpcCall.xid, - status: self.rpcReplySuccess), - nfsReply: .mount(.init(result: .fail(.errorSERVERFAULT, - NFS3Nothing())))) + let reply = RPCNFS3Reply( + rpcReply: .init( + xid: call.rpcCall.xid, + status: self.rpcReplySuccess + ), + nfsReply: .mount( + .init( + result: .fail( + .errorSERVERFAULT, + NFS3Nothing() + ) + ) + ) + ) self.writeBuffer.clear() self.writeBuffer.writeRPCNFS3Reply(reply) diff --git a/Sources/NIONFS3/NFSReplyDecoder.swift b/Sources/NIONFS3/NFSReplyDecoder.swift index 5c346424..cb98020e 100644 --- a/Sources/NIONFS3/NFSReplyDecoder.swift +++ b/Sources/NIONFS3/NFSReplyDecoder.swift @@ -27,8 +27,10 @@ public struct NFS3ReplyDecoder: WriteObservingByteToMessageDecoder { /// - prepopulatedProcecedures: For testing and other more obscure purposes it might be useful to pre-seed the /// decoder with some RPC numbers and their respective type. /// - allowDuplicateReplies: Whether to fail when receiving more than one response for a given call. - public init(prepopulatedProcecedures: [UInt32: RPCNFS3ProcedureID]? = nil, - allowDuplicateReplies: Bool = false) { + public init( + prepopulatedProcecedures: [UInt32: RPCNFS3ProcedureID]? = nil, + allowDuplicateReplies: Bool = false + ) { self.procedures = prepopulatedProcecedures ?? [:] self.allowDuplicateReplies = allowDuplicateReplies } diff --git a/Sources/NIONFS3/NFSTypes+Access.swift b/Sources/NIONFS3/NFSTypes+Access.swift index ddcff0e1..3b7c32a5 100644 --- a/Sources/NIONFS3/NFSTypes+Access.swift +++ b/Sources/NIONFS3/NFSTypes+Access.swift @@ -59,24 +59,29 @@ extension ByteBuffer { } @discardableResult public mutating func writeNFS3CallAccess(_ call: NFS3CallAccess) -> Int { - return self.writeNFS3FileHandle(call.object) - + self.writeInteger(call.access.rawValue) + self.writeNFS3FileHandle(call.object) + + self.writeInteger(call.access.rawValue) } public mutating func readNFS3ReplyAccess() throws -> NFS3ReplyAccess { - return NFS3ReplyAccess(result: try self.readNFS3Result( - readOkay: { buffer in - let attrs = try buffer.readNFS3Optional { buffer in - try buffer.readNFS3FileAttr() + NFS3ReplyAccess( + result: try self.readNFS3Result( + readOkay: { buffer in + let attrs = try buffer.readNFS3Optional { buffer in + try buffer.readNFS3FileAttr() + } + let access = try buffer.readNFS3Access() + return NFS3ReplyAccess.Okay(dirAttributes: attrs, access: access) + }, + readFail: { buffer in + NFS3ReplyAccess.Fail( + dirAttributes: try buffer.readNFS3Optional { buffer in + try buffer.readNFS3FileAttr() + } + ) } - let access = try buffer.readNFS3Access() - return NFS3ReplyAccess.Okay(dirAttributes: attrs, access: access) - }, - readFail: { buffer in - return NFS3ReplyAccess.Fail(dirAttributes: try buffer.readNFS3Optional { buffer in - try buffer.readNFS3FileAttr() - }) - })) + ) + ) } @discardableResult public mutating func writeNFS3ReplyAccess(_ accessResult: NFS3ReplyAccess) -> Int { @@ -86,8 +91,9 @@ extension ByteBuffer { case .okay(let result): bytesWritten += self.writeInteger(NFS3Status.ok.rawValue) if let attrs = result.dirAttributes { - bytesWritten += self.writeInteger(1, as: UInt32.self) - + self.writeNFS3FileAttr(attrs) + bytesWritten += + self.writeInteger(1, as: UInt32.self) + + self.writeNFS3FileAttr(attrs) } else { bytesWritten += self.writeInteger(0, as: UInt32.self) } @@ -96,8 +102,9 @@ extension ByteBuffer { precondition(status != .ok) bytesWritten += self.writeInteger(status.rawValue) if let attrs = fail.dirAttributes { - bytesWritten += self.writeInteger(1, as: UInt32.self) - + self.writeNFS3FileAttr(attrs) + bytesWritten += + self.writeInteger(1, as: UInt32.self) + + self.writeNFS3FileAttr(attrs) } else { bytesWritten += self.writeInteger(0, as: UInt32.self) } diff --git a/Sources/NIONFS3/NFSTypes+Common.swift b/Sources/NIONFS3/NFSTypes+Common.swift index d927ea2e..453d7a7d 100644 --- a/Sources/NIONFS3/NFSTypes+Common.swift +++ b/Sources/NIONFS3/NFSTypes+Common.swift @@ -21,17 +21,21 @@ public struct RPCNFS3Call: Hashable & Sendable { self.nfsCall = nfsCall } - public init(nfsCall: NFS3Call, - xid: UInt32, - credentials: RPCCredentials = .init(flavor: 0, length: 0, otherBytes: ByteBuffer()), - verifier: RPCOpaqueAuth = RPCOpaqueAuth(flavor: .noAuth)) { - var rpcCall = RPCCall(xid: xid, - rpcVersion: 2, - program: .max, // placeholder, overwritten below - programVersion: 3, - procedure: .max, // placeholder, overwritten below - credentials: credentials, - verifier: verifier) + public init( + nfsCall: NFS3Call, + xid: UInt32, + credentials: RPCCredentials = .init(flavor: 0, length: 0, otherBytes: ByteBuffer()), + verifier: RPCOpaqueAuth = RPCOpaqueAuth(flavor: .noAuth) + ) { + var rpcCall = RPCCall( + xid: xid, + rpcVersion: 2, + program: .max, // placeholder, overwritten below + programVersion: 3, + procedure: .max, // placeholder, overwritten below + credentials: credentials, + verifier: verifier + ) switch nfsCall { case .mountNull: @@ -80,7 +84,7 @@ extension RPCNFS3Call: Identifiable { public typealias ID = UInt32 public var id: ID { - return self.rpcCall.xid + self.rpcCall.xid } } @@ -98,7 +102,7 @@ extension RPCNFS3Reply: Identifiable { public typealias ID = UInt32 public var id: ID { - return self.rpcReply.xid + self.rpcReply.xid } } @@ -149,12 +153,12 @@ extension NFS3FileMode: ExpressibleByIntegerLiteral { extension ByteBuffer { public mutating func readNFS3FileMode() throws -> NFS3FileMode { - return NFS3FileMode(rawValue: try self.readNFS3Integer(as: NFS3FileMode.RawValue.self)) + NFS3FileMode(rawValue: try self.readNFS3Integer(as: NFS3FileMode.RawValue.self)) } @discardableResult public mutating func writeNFS3FileMode(_ value: NFS3FileMode) -> Int { - return self.writeInteger(value.rawValue) + self.writeInteger(value.rawValue) } } @@ -178,12 +182,12 @@ extension NFS3UID: ExpressibleByIntegerLiteral { extension ByteBuffer { public mutating func readNFS3UID() throws -> NFS3UID { - return NFS3UID(rawValue: try self.readNFS3Integer(as: NFS3UID.RawValue.self)) + NFS3UID(rawValue: try self.readNFS3Integer(as: NFS3UID.RawValue.self)) } @discardableResult public mutating func writeNFS3UID(_ value: NFS3UID) -> Int { - return self.writeInteger(value.rawValue) + self.writeInteger(value.rawValue) } } @@ -207,12 +211,12 @@ extension NFS3GID: ExpressibleByIntegerLiteral { extension ByteBuffer { public mutating func readNFS3GID() throws -> NFS3GID { - return NFS3GID(rawValue: try self.readNFS3Integer(as: NFS3GID.RawValue.self)) + NFS3GID(rawValue: try self.readNFS3Integer(as: NFS3GID.RawValue.self)) } @discardableResult public mutating func writeNFS3GID(_ value: NFS3GID) -> Int { - return self.writeInteger(value.rawValue) + self.writeInteger(value.rawValue) } } @@ -236,12 +240,12 @@ extension NFS3Size: ExpressibleByIntegerLiteral { extension ByteBuffer { public mutating func readNFS3Size() throws -> NFS3Size { - return NFS3Size(rawValue: try self.readNFS3Integer(as: NFS3Size.RawValue.self)) + NFS3Size(rawValue: try self.readNFS3Integer(as: NFS3Size.RawValue.self)) } @discardableResult public mutating func writeNFS3Size(_ value: NFS3Size) -> Int { - return self.writeInteger(value.rawValue) + self.writeInteger(value.rawValue) } } @@ -265,12 +269,12 @@ extension NFS3SpecData: ExpressibleByIntegerLiteral { extension ByteBuffer { public mutating func readNFS3SpecData() throws -> NFS3SpecData { - return NFS3SpecData(rawValue: try self.readNFS3Integer(as: NFS3SpecData.RawValue.self)) + NFS3SpecData(rawValue: try self.readNFS3Integer(as: NFS3SpecData.RawValue.self)) } @discardableResult public mutating func writeNFS3SpecData(_ value: NFS3SpecData) -> Int { - return self.writeInteger(value.rawValue) + self.writeInteger(value.rawValue) } } @@ -294,12 +298,12 @@ extension NFS3FileID: ExpressibleByIntegerLiteral { extension ByteBuffer { public mutating func readNFS3FileID() throws -> NFS3FileID { - return NFS3FileID(rawValue: try self.readNFS3Integer(as: NFS3FileID.RawValue.self)) + NFS3FileID(rawValue: try self.readNFS3Integer(as: NFS3FileID.RawValue.self)) } @discardableResult public mutating func writeNFS3FileID(_ value: NFS3FileID) -> Int { - return self.writeInteger(value.rawValue) + self.writeInteger(value.rawValue) } } @@ -323,12 +327,12 @@ extension NFS3Cookie: ExpressibleByIntegerLiteral { extension ByteBuffer { public mutating func readNFS3Cookie() throws -> NFS3Cookie { - return NFS3Cookie(rawValue: try self.readNFS3Integer(as: NFS3Cookie.RawValue.self)) + NFS3Cookie(rawValue: try self.readNFS3Integer(as: NFS3Cookie.RawValue.self)) } @discardableResult public mutating func writeNFS3Cookie(_ value: NFS3Cookie) -> Int { - return self.writeInteger(value.rawValue) + self.writeInteger(value.rawValue) } } @@ -352,12 +356,12 @@ extension NFS3CookieVerifier: ExpressibleByIntegerLiteral { extension ByteBuffer { public mutating func readNFS3CookieVerifier() throws -> NFS3CookieVerifier { - return NFS3CookieVerifier(rawValue: try self.readNFS3Integer(as: NFS3CookieVerifier.RawValue.self)) + NFS3CookieVerifier(rawValue: try self.readNFS3Integer(as: NFS3CookieVerifier.RawValue.self)) } @discardableResult public mutating func writeNFS3CookieVerifier(_ value: NFS3CookieVerifier) -> Int { - return self.writeInteger(value.rawValue) + self.writeInteger(value.rawValue) } } @@ -381,12 +385,12 @@ extension NFS3Offset: ExpressibleByIntegerLiteral { extension ByteBuffer { public mutating func readNFS3Offset() throws -> NFS3Offset { - return NFS3Offset(rawValue: try self.readNFS3Integer(as: NFS3Offset.RawValue.self)) + NFS3Offset(rawValue: try self.readNFS3Integer(as: NFS3Offset.RawValue.self)) } @discardableResult public mutating func writeNFS3Offset(_ value: NFS3Offset) -> Int { - return self.writeInteger(value.rawValue) + self.writeInteger(value.rawValue) } } @@ -410,16 +414,15 @@ extension NFS3Count: ExpressibleByIntegerLiteral { extension ByteBuffer { public mutating func readNFS3Count() throws -> NFS3Count { - return NFS3Count(rawValue: try self.readNFS3Integer(as: NFS3Count.RawValue.self)) + NFS3Count(rawValue: try self.readNFS3Integer(as: NFS3Count.RawValue.self)) } @discardableResult public mutating func writeNFS3Count(_ value: NFS3Count) -> Int { - return self.writeInteger(value.rawValue) + self.writeInteger(value.rawValue) } } - public struct NFS3Nothing: Hashable & Sendable { public init() {} } @@ -430,33 +433,33 @@ public struct NFS3Nothing: Hashable & Sendable { public enum NFS3Status: UInt32, Sendable { case ok = 0 case errorPERM = 1 - case errorNOENT = 2 - case errorIO = 5 - case errorNXIO = 6 - case errorACCES = 13 - case errorEXIST = 17 - case errorXDEV = 18 - case errorNODEV = 19 - case errorNOTDIR = 20 - case errorISDIR = 21 - case errorINVAL = 22 - case errorFBIG = 27 - case errorNOSPC = 28 - case errorROFS = 30 - case errorMLINK = 31 + case errorNOENT = 2 + case errorIO = 5 + case errorNXIO = 6 + case errorACCES = 13 + case errorEXIST = 17 + case errorXDEV = 18 + case errorNODEV = 19 + case errorNOTDIR = 20 + case errorISDIR = 21 + case errorINVAL = 22 + case errorFBIG = 27 + case errorNOSPC = 28 + case errorROFS = 30 + case errorMLINK = 31 case errorNAMETOOLONG = 63 - case errorNOTEMPTY = 66 - case errorDQUOT = 69 - case errorSTALE = 70 - case errorREMOTE = 71 - case errorBADHANDLE = 10001 - case errorNOT_SYNC = 10002 - case errorBAD_COOKIE = 10003 - case errorNOTSUPP = 10004 - case errorTOOSMALL = 10005 + case errorNOTEMPTY = 66 + case errorDQUOT = 69 + case errorSTALE = 70 + case errorREMOTE = 71 + case errorBADHANDLE = 10001 + case errorNOT_SYNC = 10002 + case errorBAD_COOKIE = 10003 + case errorNOTSUPP = 10004 + case errorTOOSMALL = 10005 case errorSERVERFAULT = 10006 - case errorBADTYPE = 10007 - case errorJUKEBOX = 10008 + case errorBADTYPE = 10007 + case errorJUKEBOX = 10008 } /// Check the access rights to a file. @@ -484,7 +487,7 @@ public struct NFS3Access: OptionSet & Hashable & Sendable { extension ByteBuffer { public mutating func readNFS3Access() throws -> NFS3Access { - return NFS3Access(rawValue: try self.readNFS3Integer(as: UInt32.self)) + NFS3Access(rawValue: try self.readNFS3Integer(as: UInt32.self)) } public mutating func writeNFS3Access(_ access: NFS3Access) { @@ -522,14 +525,16 @@ public struct NFS3FileHandle: Hashable & Sendable & CustomStringConvertible { /// - seealso: https://www.rfc-editor.org/rfc/rfc1813#page-106 public init(_ bytes: ByteBuffer) { precondition(bytes.readableBytes <= 64, "NFS3 mandates that file handles are NFS3_FHSIZE (64) bytes or less.") - precondition(bytes.readableBytes == MemoryLayout.size, - "Sorry, at the moment only file handles with exactly 8 bytes are implemented.") + precondition( + bytes.readableBytes == MemoryLayout.size, + "Sorry, at the moment only file handles with exactly 8 bytes are implemented." + ) var bytes = bytes self = NFS3FileHandle(bytes.readInteger(as: UInt64.self)!) } public var description: String { - return "NFS3FileHandle(\(self._value))" + "NFS3FileHandle(\(self._value))" } } @@ -553,7 +558,6 @@ extension UInt32 { } } - public struct NFS3Time: Hashable & Sendable { public init(seconds: UInt32, nanoseconds: UInt32) { self.seconds = seconds @@ -565,9 +569,21 @@ public struct NFS3Time: Hashable & Sendable { } public struct NFS3FileAttr: Hashable & Sendable { - public init(type: NFS3FileType, mode: NFS3FileMode, nlink: UInt32, uid: NFS3UID, gid: NFS3GID, size: NFS3Size, - used: NFS3Size, rdev: NFS3SpecData, fsid: UInt64, fileid: NFS3FileID, atime: NFS3Time, mtime: NFS3Time, - ctime: NFS3Time) { + public init( + type: NFS3FileType, + mode: NFS3FileMode, + nlink: UInt32, + uid: NFS3UID, + gid: NFS3GID, + size: NFS3Size, + used: NFS3Size, + rdev: NFS3SpecData, + fsid: UInt64, + fileid: NFS3FileID, + atime: NFS3Time, + mtime: NFS3Time, + ctime: NFS3Time + ) { self.type = type self.mode = mode self.nlink = nlink @@ -629,10 +645,12 @@ extension ByteBuffer { return .init(size: size, mtime: mtime, ctime: ctime) } - @discardableResult public mutating func writeNFS3WeakCacheConsistencyAttr(_ wccAttr: NFS3WeakCacheConsistencyAttr) -> Int { - return self.writeNFS3Size(wccAttr.size) - + self.writeNFS3Time(wccAttr.mtime) - + self.writeNFS3Time(wccAttr.ctime) + @discardableResult public mutating func writeNFS3WeakCacheConsistencyAttr( + _ wccAttr: NFS3WeakCacheConsistencyAttr + ) -> Int { + self.writeNFS3Size(wccAttr.size) + + self.writeNFS3Time(wccAttr.mtime) + + self.writeNFS3Time(wccAttr.ctime) } public mutating func readNFS3WeakCacheConsistencyData() throws -> NFS3WeakCacheConsistencyData { @@ -642,9 +660,11 @@ extension ByteBuffer { return .init(before: before, after: after) } - @discardableResult public mutating func writeNFS3WeakCacheConsistencyData(_ wccData: NFS3WeakCacheConsistencyData) -> Int { - return self.writeNFS3Optional(wccData.before, writer: { $0.writeNFS3WeakCacheConsistencyAttr($1) }) - + self.writeNFS3Optional(wccData.after, writer: { $0.writeNFS3FileAttr($1) }) + @discardableResult public mutating func writeNFS3WeakCacheConsistencyData( + _ wccData: NFS3WeakCacheConsistencyData + ) -> Int { + self.writeNFS3Optional(wccData.before, writer: { $0.writeNFS3WeakCacheConsistencyAttr($1) }) + + self.writeNFS3Optional(wccData.after, writer: { $0.writeNFS3FileAttr($1) }) } public mutating func readNFS3Integer(as: I.Type = I.self) throws -> I { @@ -658,7 +678,8 @@ extension ByteBuffer { public mutating func readNFS3Blob() throws -> ByteBuffer { let length = try self.readNFS3Integer(as: UInt32.self) guard let blob = self.readSlice(length: Int(length)), - let _ = self.readSlice(length: nfsStringFillBytes(Int(length))) else { + let _ = self.readSlice(length: nfsStringFillBytes(Int(length))) + else { throw NFS3Error.illegalRPCTooShort } return blob @@ -667,8 +688,8 @@ extension ByteBuffer { @discardableResult public mutating func writeNFS3Blob(_ blob: ByteBuffer) -> Int { let byteCount = blob.readableBytes return self.writeInteger(UInt32(byteCount)) - + self.writeImmutableBuffer(blob) - + self.writeRepeatingByte(0x42, count: nfsStringFillBytes(byteCount)) + + self.writeImmutableBuffer(blob) + + self.writeRepeatingByte(0x42, count: nfsStringFillBytes(byteCount)) } public mutating func readNFS3String() throws -> String { @@ -679,8 +700,8 @@ extension ByteBuffer { @discardableResult public mutating func writeNFS3String(_ string: String) -> Int { let byteCount = string.utf8.count return self.writeInteger(UInt32(byteCount)) - + self.writeString(string) - + self.writeRepeatingByte(0x42, count: nfsStringFillBytes(byteCount)) + + self.writeString(string) + + self.writeRepeatingByte(0x42, count: nfsStringFillBytes(byteCount)) } public mutating func readNFS3FileHandle() throws -> NFS3FileHandle { @@ -715,15 +736,26 @@ extension ByteBuffer { guard let values = self.readMultipleIntegers(as: (UInt32, UInt32, UInt32, UInt32, UInt32, UInt32).self) else { throw NFS3Error.illegalRPCTooShort } - return (NFS3Time(seconds: values.0, nanoseconds: values.1), - NFS3Time(seconds: values.2, nanoseconds: values.3), - NFS3Time(seconds: values.4, nanoseconds: values.5)) - } - - @discardableResult public mutating func write3NFS3Times(_ time1: NFS3Time, _ time2: NFS3Time, _ time3: NFS3Time) -> Int { - self.writeMultipleIntegers(time1.seconds, time1.nanoseconds, - time2.seconds, time2.nanoseconds, - time3.seconds, time3.nanoseconds) + return ( + NFS3Time(seconds: values.0, nanoseconds: values.1), + NFS3Time(seconds: values.2, nanoseconds: values.3), + NFS3Time(seconds: values.4, nanoseconds: values.5) + ) + } + + @discardableResult public mutating func write3NFS3Times( + _ time1: NFS3Time, + _ time2: NFS3Time, + _ time3: NFS3Time + ) -> Int { + self.writeMultipleIntegers( + time1.seconds, + time1.nanoseconds, + time2.seconds, + time2.nanoseconds, + time3.seconds, + time3.nanoseconds + ) } public mutating func readNFS3Time() throws -> NFS3Time { @@ -745,9 +777,15 @@ extension ByteBuffer { public mutating func readNFS3FileAttr() throws -> NFS3FileAttr { let type = try self.readNFS3FileType() - guard let values = self.readMultipleIntegers(as: (UInt32, UInt32, UInt32, UInt32, NFS3Size.RawValue, - NFS3Size.RawValue, UInt64, UInt64, NFS3FileID.RawValue, - UInt32, UInt32, UInt32, UInt32, UInt32, UInt32).self) else { + guard + let values = self.readMultipleIntegers( + as: ( + UInt32, UInt32, UInt32, UInt32, NFS3Size.RawValue, + NFS3Size.RawValue, UInt64, UInt64, NFS3FileID.RawValue, + UInt32, UInt32, UInt32, UInt32, UInt32, UInt32 + ).self + ) + else { throw NFS3Error.illegalRPCTooShort } let mode = values.0 @@ -763,31 +801,42 @@ extension ByteBuffer { let mtime = NFS3Time(seconds: values.11, nanoseconds: values.12) let ctime = NFS3Time(seconds: values.13, nanoseconds: values.14) - return .init(type: type, mode: NFS3FileMode(rawValue: mode), nlink: nlink, - uid: NFS3UID(rawValue: uid), gid: NFS3GID(rawValue: gid), - size: NFS3Size(rawValue: size), used: NFS3Size(rawValue: used), - rdev: NFS3SpecData(rawValue: rdev), fsid: fsid, fileid: NFS3FileID(rawValue: fileid), - atime: atime, mtime: mtime, ctime: ctime) + return .init( + type: type, + mode: NFS3FileMode(rawValue: mode), + nlink: nlink, + uid: NFS3UID(rawValue: uid), + gid: NFS3GID(rawValue: gid), + size: NFS3Size(rawValue: size), + used: NFS3Size(rawValue: used), + rdev: NFS3SpecData(rawValue: rdev), + fsid: fsid, + fileid: NFS3FileID(rawValue: fileid), + atime: atime, + mtime: mtime, + ctime: ctime + ) } @discardableResult public mutating func writeNFS3FileAttr(_ attributes: NFS3FileAttr) -> Int { - return self.writeNFS3FileType(attributes.type) - + self.writeMultipleIntegers( - attributes.mode.rawValue, - attributes.nlink, - attributes.uid.rawValue, - attributes.gid.rawValue, - attributes.size.rawValue, - attributes.used.rawValue, - attributes.rdev.rawValue, - attributes.fsid, - attributes.fileid.rawValue, - attributes.atime.seconds, - attributes.atime.nanoseconds, - attributes.mtime.seconds, - attributes.mtime.nanoseconds, - attributes.ctime.seconds, - attributes.ctime.nanoseconds) + self.writeNFS3FileType(attributes.type) + + self.writeMultipleIntegers( + attributes.mode.rawValue, + attributes.nlink, + attributes.uid.rawValue, + attributes.gid.rawValue, + attributes.size.rawValue, + attributes.used.rawValue, + attributes.rdev.rawValue, + attributes.fsid, + attributes.fileid.rawValue, + attributes.atime.seconds, + attributes.atime.nanoseconds, + attributes.mtime.seconds, + attributes.mtime.nanoseconds, + attributes.ctime.seconds, + attributes.ctime.nanoseconds + ) } @discardableResult public mutating func writeNFS3Bool(_ bool: NFS3Bool) -> Int { @@ -807,10 +856,13 @@ extension ByteBuffer { } } - @discardableResult public mutating func writeNFS3Optional(_ value: T?, writer: (inout ByteBuffer, T) -> Int) -> Int { + @discardableResult public mutating func writeNFS3Optional( + _ value: T?, + writer: (inout ByteBuffer, T) -> Int + ) -> Int { if let value = value { return self.writeInteger(1, as: UInt32.self) - + writer(&self, value) + + writer(&self, value) } else { return self.writeInteger(0, as: UInt32.self) } @@ -850,8 +902,10 @@ extension ByteBuffer { } } - public mutating func readNFS3Result(readOkay: (inout ByteBuffer) throws -> O, - readFail: (inout ByteBuffer) throws -> F) throws -> NFS3Result { + public mutating func readNFS3Result( + readOkay: (inout ByteBuffer) throws -> O, + readFail: (inout ByteBuffer) throws -> F + ) throws -> NFS3Result { let status = try self.readNFS3Status() switch status { case .ok: diff --git a/Sources/NIONFS3/NFSTypes+Containers.swift b/Sources/NIONFS3/NFSTypes+Containers.swift index 60c5354d..e1ede881 100644 --- a/Sources/NIONFS3/NFSTypes+Containers.swift +++ b/Sources/NIONFS3/NFSTypes+Containers.swift @@ -20,10 +20,10 @@ public struct RPCNFS3ProcedureID: Hashable & Sendable { public static let mountNull: Self = .init(program: 100005, procedure: 0) public static let mountMount: Self = .init(program: 100005, procedure: 1) - public static let mountDump: Self = .init(program: 100005, procedure: 2) // unimplemented + public static let mountDump: Self = .init(program: 100005, procedure: 2) // unimplemented public static let mountUnmount: Self = .init(program: 100005, procedure: 3) - public static let mountUnmountAll: Self = .init(program: 100005, procedure: 4) // unimplemented - public static let mountExport: Self = .init(program: 100005, procedure: 5) // unimplemented + public static let mountUnmountAll: Self = .init(program: 100005, procedure: 4) // unimplemented + public static let mountExport: Self = .init(program: 100005, procedure: 5) // unimplemented // The source of truth for the values in the NFS program (`1000003`) can be found in the NFS RFC at // https://www.rfc-editor.org/rfc/rfc1813#page-28 @@ -35,15 +35,15 @@ public struct RPCNFS3ProcedureID: Hashable & Sendable { public static let nfsReadLink: Self = .init(program: 100003, procedure: 5) public static let nfsRead: Self = .init(program: 100003, procedure: 6) - public static let nfsWrite: Self = .init(program: 100003, procedure: 7) // unimplemented - public static let nfsCreate: Self = .init(program: 100003, procedure: 8) // unimplemented - public static let nfsMkDir: Self = .init(program: 100003, procedure: 9) // unimplemented - public static let nfsSymlink: Self = .init(program: 100003, procedure: 10) // unimplemented - public static let nfsMkNod: Self = .init(program: 100003, procedure: 11) // unimplemented - public static let nfsRemove: Self = .init(program: 100003, procedure: 12) // unimplemented - public static let nfsRmDir: Self = .init(program: 100003, procedure: 13) // unimplemented - public static let nfsRename: Self = .init(program: 100003, procedure: 14) // unimplemented - public static let nfsLink: Self = .init(program: 100003, procedure: 15) // unimplemented + public static let nfsWrite: Self = .init(program: 100003, procedure: 7) // unimplemented + public static let nfsCreate: Self = .init(program: 100003, procedure: 8) // unimplemented + public static let nfsMkDir: Self = .init(program: 100003, procedure: 9) // unimplemented + public static let nfsSymlink: Self = .init(program: 100003, procedure: 10) // unimplemented + public static let nfsMkNod: Self = .init(program: 100003, procedure: 11) // unimplemented + public static let nfsRemove: Self = .init(program: 100003, procedure: 12) // unimplemented + public static let nfsRmDir: Self = .init(program: 100003, procedure: 13) // unimplemented + public static let nfsRename: Self = .init(program: 100003, procedure: 14) // unimplemented + public static let nfsLink: Self = .init(program: 100003, procedure: 15) // unimplemented public static let nfsReadDir: Self = .init(program: 100003, procedure: 16) public static let nfsReadDirPlus: Self = .init(program: 100003, procedure: 17) @@ -51,7 +51,7 @@ public struct RPCNFS3ProcedureID: Hashable & Sendable { public static let nfsFSInfo: Self = .init(program: 100003, procedure: 19) public static let nfsPathConf: Self = .init(program: 100003, procedure: 20) - public static let nfsCommit: Self = .init(program: 100003, procedure: 21) // unimplemented + public static let nfsCommit: Self = .init(program: 100003, procedure: 21) // unimplemented } extension RPCNFS3ProcedureID { @@ -148,13 +148,13 @@ public enum NFS3Error: Error { } internal func nfsStringFillBytes(_ byteCount: Int) -> Int { - return (4 - (byteCount % 4)) % 4 + (4 - (byteCount % 4)) % 4 } extension ByteBuffer { public mutating func readRPCVerifier() throws -> RPCOpaqueAuth { guard let (flavor, length) = self.readMultipleIntegers(as: (UInt32, UInt32).self) else { - throw NFS3Error.illegalRPCTooShort + throw NFS3Error.illegalRPCTooShort } guard (flavor == RPCAuthFlavor.system.rawValue || flavor == RPCAuthFlavor.noAuth.rawValue) && length == 0 else { throw RPCErrors.unknownVerifier(flavor) @@ -181,8 +181,8 @@ extension ByteBuffer { } @discardableResult public mutating func writeRPCCredentials(_ credentials: RPCCredentials) -> Int { - return self.writeInteger(credentials.flavor) - + self.writeNFS3Blob(credentials.otherBytes) + self.writeInteger(credentials.flavor) + + self.writeNFS3Blob(credentials.otherBytes) } public mutating func readRPCFragmentHeader() throws -> RPCFragmentHeader? { @@ -196,7 +196,7 @@ extension ByteBuffer { @discardableResult public mutating func setRPCFragmentHeader(_ header: RPCFragmentHeader, at index: Int) -> Int { - return self.setInteger(header.rawValue, at: index) + self.setInteger(header.rawValue, at: index) } @discardableResult public mutating func writeRPCFragmentHeader(_ header: RPCFragmentHeader) -> Int { @@ -208,41 +208,48 @@ extension ByteBuffer { mutating func readRPCReply(xid: UInt32) throws -> RPCReply { let acceptedOrDenied = try self.readNFS3Integer(as: UInt32.self) switch acceptedOrDenied { - case 0: // MSG_ACCEPTED + case 0: // MSG_ACCEPTED let verifier = try self.readRPCVerifier() let status = try self.readNFS3Integer(as: UInt32.self) let acceptedReplyStatus: RPCAcceptedReplyStatus switch status { - case 0: // SUCCESS + case 0: // SUCCESS acceptedReplyStatus = .success - case 1: //PROG_UNAVAIL + case 1: //PROG_UNAVAIL acceptedReplyStatus = .programUnavailable - case 2: //PROG_MISMATCH + case 2: //PROG_MISMATCH guard let values = self.readMultipleIntegers(as: (UInt32, UInt32).self) else { throw NFS3Error.illegalRPCTooShort } acceptedReplyStatus = .programMismatch(low: values.0, high: values.1) - case 3: //PROC_UNAVAIL + case 3: //PROC_UNAVAIL acceptedReplyStatus = .procedureUnavailable - case 4: //GARBAGE_ARGS + case 4: //GARBAGE_ARGS acceptedReplyStatus = .garbageArguments - case 5: //SYSTEM_ERR + case 5: //SYSTEM_ERR acceptedReplyStatus = .systemError default: throw RPCErrors.illegalReplyAcceptanceStatus(status) } - return RPCReply(xid: xid, status: .messageAccepted(.init(verifier: verifier, - status: acceptedReplyStatus))) - case 1: // MSG_DENIED + return RPCReply( + xid: xid, + status: .messageAccepted( + .init( + verifier: verifier, + status: acceptedReplyStatus + ) + ) + ) + case 1: // MSG_DENIED let rejectionKind = try self.readNFS3Integer(as: UInt32.self) switch rejectionKind { - case 0: // RPC_MISMATCH: RPC version number != 2 + case 0: // RPC_MISMATCH: RPC version number != 2 guard let values = self.readMultipleIntegers(as: (UInt32, UInt32).self) else { throw NFS3Error.illegalRPCTooShort } return RPCReply(xid: xid, status: .messageDenied(.rpcMismatch(low: values.0, high: values.1))) - case 1: // AUTH_ERROR + case 1: // AUTH_ERROR let rawValue = try self.readNFS3Integer(as: UInt32.self) if let value = RPCAuthStatus(rawValue: rawValue) { return RPCReply(xid: xid, status: .messageDenied(.authError(value))) @@ -258,14 +265,15 @@ extension ByteBuffer { } @discardableResult public mutating func writeRPCCall(_ call: RPCCall) -> Int { - return self.writeMultipleIntegers( + self.writeMultipleIntegers( RPCMessageType.call.rawValue, call.rpcVersion, call.program, call.programVersion, - call.procedure) - + self.writeRPCCredentials(call.credentials) - + self.writeRPCVerifier(call.verifier) + call.procedure + ) + + self.writeRPCCredentials(call.credentials) + + self.writeRPCVerifier(call.verifier) } @discardableResult public mutating func writeRPCReply(_ reply: RPCReply) -> Int { @@ -273,17 +281,17 @@ extension ByteBuffer { switch reply.status { case .messageAccepted(_): - bytesWritten += self.writeInteger(0 /* accepted */, as: UInt32.self) + bytesWritten += self.writeInteger(0, as: UInt32.self) // 0 -> accepted case .messageDenied(_): // FIXME: MSG_DENIED (spec name) isn't actually handled correctly here. - bytesWritten += self.writeInteger(1 /* denied */, as: UInt32.self) + bytesWritten += self.writeInteger(1, as: UInt32.self) // 1 -> denied } - bytesWritten += self.writeInteger(0 /* verifier */, as: UInt64.self) - + self.writeInteger(0 /* executed successfully */, as: UInt32.self) + bytesWritten += + self.writeInteger(0, as: UInt64.self) // 0 -> verifier + + self.writeInteger(0, as: UInt32.self) // 0 -> executed successfully return bytesWritten } - public mutating func readRPCCall(xid: UInt32) throws -> RPCCall { guard let values = self.readMultipleIntegers(as: (UInt32, UInt32, UInt32, UInt32).self) else { throw NFS3Error.illegalRPCTooShort @@ -297,16 +305,21 @@ extension ByteBuffer { throw RPCErrors.unknownVersion(version) } - return RPCCall(xid: xid, - rpcVersion: version, - program: program, - programVersion: programVersion, - procedure: procedure, - credentials: credentials, - verifier: verifier) + return RPCCall( + xid: xid, + rpcVersion: version, + program: program, + programVersion: programVersion, + procedure: procedure, + credentials: credentials, + verifier: verifier + ) } - public mutating func readNFS3Reply(programAndProcedure: RPCNFS3ProcedureID, rpcReply: RPCReply) throws -> RPCNFS3Reply { + public mutating func readNFS3Reply( + programAndProcedure: RPCNFS3ProcedureID, + rpcReply: RPCReply + ) throws -> RPCNFS3Reply { switch programAndProcedure { case .mountNull: return .init(rpcReply: rpcReply, nfsReply: .mountNull) @@ -382,20 +395,20 @@ extension ByteBuffer { @discardableResult public mutating func writeRPCNFS3Call(_ rpcNFS3Call: RPCNFS3Call) -> Int { let startWriterIndex = self.writerIndex - self.writeRPCFragmentHeader(.init(length: 12345678, last: false)) // placeholder, overwritten later + self.writeRPCFragmentHeader(.init(length: 12_345_678, last: false)) // placeholder, overwritten later self.writeInteger(rpcNFS3Call.rpcCall.xid) self.writeRPCCall(rpcNFS3Call.rpcCall) switch rpcNFS3Call.nfsCall { case .mountNull: - () // noop + () // noop case .mount(let nfsCallMount): self.writeNFS3CallMount(nfsCallMount) case .unmount(let nfsCallUnmount): self.writeNFS3CallUnmount(nfsCallUnmount) case .null: - () // noop + () // noop case .getattr(let nfsCallGetAttr): self.writeNFS3CallGetattr(nfsCallGetAttr) case .fsinfo(let nfsCallFSInfo): @@ -423,30 +436,36 @@ extension ByteBuffer { preconditionFailure("unknown NFS3 call, this should never happen. Please report a bug.") } - self.setRPCFragmentHeader(.init(length: UInt32(self.writerIndex - startWriterIndex - 4), - last: true), - at: startWriterIndex) + self.setRPCFragmentHeader( + .init( + length: UInt32(self.writerIndex - startWriterIndex - 4), + last: true + ), + at: startWriterIndex + ) return self.writerIndex - startWriterIndex } - @discardableResult public mutating func writeRPCNFS3ReplyPartially(_ rpcNFS3Reply: RPCNFS3Reply) -> (Int, NFS3PartialWriteNextStep) { + @discardableResult public mutating func writeRPCNFS3ReplyPartially( + _ rpcNFS3Reply: RPCNFS3Reply + ) -> (Int, NFS3PartialWriteNextStep) { var nextStep: NFS3PartialWriteNextStep = .doNothing let startWriterIndex = self.writerIndex - self.writeRPCFragmentHeader(.init(length: 12345678, last: false)) // placeholder, overwritten later + self.writeRPCFragmentHeader(.init(length: 12_345_678, last: false)) // placeholder, overwritten later self.writeInteger(rpcNFS3Reply.rpcReply.xid) self.writeRPCReply(rpcNFS3Reply.rpcReply) switch rpcNFS3Reply.nfsReply { case .mountNull: - () // noop + () // noop case .mount(let nfsReplyMount): self.writeNFS3ReplyMount(nfsReplyMount) case .unmount(let nfsReplyUnmount): self.writeNFS3ReplyUnmount(nfsReplyUnmount) case .null: - () // noop + () // noop case .getattr(let nfsReplyGetAttr): self.writeNFS3ReplyGetAttr(nfsReplyGetAttr) case .fsinfo(let nfsReplyFSInfo): @@ -474,9 +493,13 @@ extension ByteBuffer { preconditionFailure("unknown NFS3 reply, this should never happen. Please report a bug.") } - self.setRPCFragmentHeader(.init(length: UInt32(self.writerIndex - startWriterIndex - 4 + nextStep.bytesToFollow), - last: true), - at: startWriterIndex) + self.setRPCFragmentHeader( + .init( + length: UInt32(self.writerIndex - startWriterIndex - 4 + nextStep.bytesToFollow), + last: true + ), + at: startWriterIndex + ) return (self.writerIndex - startWriterIndex, nextStep) } @@ -488,16 +511,17 @@ extension ByteBuffer { return bytesWritten case .writeBlob(let buffer, numberOfFillBytes: let fillBytes): return bytesWritten - &+ self.writeImmutableBuffer(buffer) - &+ self.writeRepeatingByte(0x41, count: fillBytes) + &+ self.writeImmutableBuffer(buffer) + &+ self.writeRepeatingByte(0x41, count: fillBytes) } } public mutating func readRPCMessage() throws -> (RPCMessage, ByteBuffer)? { let save = self guard let fragmentHeader = try self.readRPCFragmentHeader(), - let xid = self.readInteger(as: UInt32.self), - let messageType = self.readInteger(as: UInt32.self) else { + let xid = self.readInteger(as: UInt32.self), + let messageType = self.readInteger(as: UInt32.self) + else { self = save return nil } diff --git a/Sources/NIONFS3/NFSTypes+FSInfo.swift b/Sources/NIONFS3/NFSTypes+FSInfo.swift index 26aba1ce..4f8f4a26 100644 --- a/Sources/NIONFS3/NFSTypes+FSInfo.swift +++ b/Sources/NIONFS3/NFSTypes+FSInfo.swift @@ -45,13 +45,19 @@ public struct NFS3ReplyFSInfo: Hashable & Sendable { } public struct Okay: Hashable & Sendable { - public init(attributes: NFS3FileAttr?, - rtmax: UInt32, rtpref: UInt32, rtmult: UInt32, - wtmax: UInt32, wtpref: UInt32, wtmult: UInt32, - dtpref: UInt32, - maxFileSize: NFS3Size, - timeDelta: NFS3Time, - properties: NFS3ReplyFSInfo.Properties) { + public init( + attributes: NFS3FileAttr?, + rtmax: UInt32, + rtpref: UInt32, + rtmult: UInt32, + wtmax: UInt32, + wtpref: UInt32, + wtmult: UInt32, + dtpref: UInt32, + maxFileSize: NFS3Size, + timeDelta: NFS3Time, + properties: NFS3ReplyFSInfo.Properties + ) { self.attributes = attributes self.rtmax = rtmax self.rtpref = rtpref @@ -109,18 +115,20 @@ extension ByteBuffer { switch reply.result { case .okay(let reply): - bytesWritten += self.writeNFS3Optional(reply.attributes, writer: { $0.writeNFS3FileAttr($1) }) - + self.writeMultipleIntegers( - reply.rtmax, - reply.rtpref, - reply.rtmult, - reply.wtmax, - reply.wtpref, - reply.wtmult, - reply.dtpref, - reply.maxFileSize.rawValue) - + self.writeNFS3Time(reply.timeDelta) - + self.writeInteger(reply.properties.rawValue) + bytesWritten += + self.writeNFS3Optional(reply.attributes, writer: { $0.writeNFS3FileAttr($1) }) + + self.writeMultipleIntegers( + reply.rtmax, + reply.rtpref, + reply.rtmult, + reply.wtmax, + reply.wtpref, + reply.wtmult, + reply.dtpref, + reply.maxFileSize.rawValue + ) + + self.writeNFS3Time(reply.timeDelta) + + self.writeInteger(reply.properties.rawValue) case .fail(_, let fail): bytesWritten += self.writeNFS3Optional(fail.attributes, writer: { $0.writeNFS3FileAttr($1) }) } @@ -129,7 +137,8 @@ extension ByteBuffer { private mutating func readNFS3ReplyFSInfoOkay() throws -> NFS3ReplyFSInfo.Okay { let fileAttr = try self.readNFS3Optional { try $0.readNFS3FileAttr() } - guard let values = self.readMultipleIntegers(as: (UInt32, UInt32, UInt32, UInt32, UInt32, UInt32, UInt32).self) else { + guard let values = self.readMultipleIntegers(as: (UInt32, UInt32, UInt32, UInt32, UInt32, UInt32, UInt32).self) + else { throw NFS3Error.illegalRPCTooShort } let rtmax = values.0 @@ -143,17 +152,27 @@ extension ByteBuffer { let timeDelta = try self.readNFS3Time() let properties = try self.readNFS3CallFSInfoProperties() - return .init(attributes: fileAttr, - rtmax: rtmax, rtpref: rtpref, rtmult: rtmult, - wtmax: wtmax, wtpref: wtpref, wtmult: wtmult, - dtpref: dtpref, - maxFileSize: maxFileSize, timeDelta: timeDelta, properties: properties) + return .init( + attributes: fileAttr, + rtmax: rtmax, + rtpref: rtpref, + rtmult: rtmult, + wtmax: wtmax, + wtpref: wtpref, + wtmult: wtmult, + dtpref: dtpref, + maxFileSize: maxFileSize, + timeDelta: timeDelta, + properties: properties + ) } public mutating func readNFS3ReplyFSInfo() throws -> NFS3ReplyFSInfo { - return NFS3ReplyFSInfo(result: try self.readNFS3Result( - readOkay: { try $0.readNFS3ReplyFSInfoOkay() }, - readFail: { NFS3ReplyFSInfo.Fail(attributes: try $0.readNFS3FileAttr()) } - )) + NFS3ReplyFSInfo( + result: try self.readNFS3Result( + readOkay: { try $0.readNFS3ReplyFSInfoOkay() }, + readFail: { NFS3ReplyFSInfo.Fail(attributes: try $0.readNFS3FileAttr()) } + ) + ) } } diff --git a/Sources/NIONFS3/NFSTypes+FSStat.swift b/Sources/NIONFS3/NFSTypes+FSStat.swift index 5a6b6d94..16e778f3 100644 --- a/Sources/NIONFS3/NFSTypes+FSStat.swift +++ b/Sources/NIONFS3/NFSTypes+FSStat.swift @@ -29,10 +29,16 @@ public struct NFS3ReplyFSStat: Hashable & Sendable { } public struct Okay: Hashable & Sendable { - public init(attributes: NFS3FileAttr?, - tbytes: NFS3Size, fbytes: NFS3Size, abytes: NFS3Size, - tfiles: NFS3Size, ffiles: NFS3Size, afiles: NFS3Size, - invarsec: UInt32) { + public init( + attributes: NFS3FileAttr?, + tbytes: NFS3Size, + fbytes: NFS3Size, + abytes: NFS3Size, + tfiles: NFS3Size, + ffiles: NFS3Size, + afiles: NFS3Size, + invarsec: UInt32 + ) { self.attributes = attributes self.tbytes = tbytes self.fbytes = fbytes @@ -79,21 +85,23 @@ extension ByteBuffer { try buffer.readNFS3FileAttr() } if let values = self.readMultipleIntegers(as: (UInt64, UInt64, UInt64, UInt64, UInt64, UInt64, UInt32).self) { - return .init(attributes: attrs, - tbytes: NFS3Size(rawValue: values.0), - fbytes: NFS3Size(rawValue: values.1), - abytes: NFS3Size(rawValue: values.2), - tfiles: NFS3Size(rawValue: values.3), - ffiles: NFS3Size(rawValue: values.4), - afiles: NFS3Size(rawValue: values.5), - invarsec: values.6) + return .init( + attributes: attrs, + tbytes: NFS3Size(rawValue: values.0), + fbytes: NFS3Size(rawValue: values.1), + abytes: NFS3Size(rawValue: values.2), + tfiles: NFS3Size(rawValue: values.3), + ffiles: NFS3Size(rawValue: values.4), + afiles: NFS3Size(rawValue: values.5), + invarsec: values.6 + ) } else { throw NFS3Error.illegalRPCTooShort } } public mutating func readNFS3ReplyFSStat() throws -> NFS3ReplyFSStat { - return NFS3ReplyFSStat( + NFS3ReplyFSStat( result: try self.readNFS3Result( readOkay: { buffer in try buffer.readNFS3ReplyFSStatOkay() @@ -104,7 +112,8 @@ extension ByteBuffer { try buffer.readNFS3FileAttr() } ) - }) + } + ) ) } @@ -113,15 +122,17 @@ extension ByteBuffer { switch reply.result { case .okay(let okay): - bytesWritten += self.writeNFS3Optional(okay.attributes, writer: { $0.writeNFS3FileAttr($1) }) - + self.writeMultipleIntegers( - okay.tbytes.rawValue, - okay.fbytes.rawValue, - okay.abytes.rawValue, - okay.tfiles.rawValue, - okay.ffiles.rawValue, - okay.afiles.rawValue, - okay.invarsec) + bytesWritten += + self.writeNFS3Optional(okay.attributes, writer: { $0.writeNFS3FileAttr($1) }) + + self.writeMultipleIntegers( + okay.tbytes.rawValue, + okay.fbytes.rawValue, + okay.abytes.rawValue, + okay.tfiles.rawValue, + okay.ffiles.rawValue, + okay.afiles.rawValue, + okay.invarsec + ) case .fail(_, let fail): bytesWritten += self.writeNFS3Optional(fail.attributes, writer: { $0.writeNFS3FileAttr($1) }) } diff --git a/Sources/NIONFS3/NFSTypes+Getattr.swift b/Sources/NIONFS3/NFSTypes+Getattr.swift index 6863a0c0..7f579a60 100644 --- a/Sources/NIONFS3/NFSTypes+Getattr.swift +++ b/Sources/NIONFS3/NFSTypes+Getattr.swift @@ -50,14 +50,15 @@ extension ByteBuffer { } public mutating func readNFS3ReplyGetAttr() throws -> NFS3ReplyGetAttr { - return NFS3ReplyGetAttr( + NFS3ReplyGetAttr( result: try self.readNFS3Result( readOkay: { buffer in - return NFS3ReplyGetAttr.Okay(attributes: try buffer.readNFS3FileAttr()) + NFS3ReplyGetAttr.Okay(attributes: try buffer.readNFS3FileAttr()) }, readFail: { _ in - return NFS3Nothing() - }) + NFS3Nothing() + } + ) ) } diff --git a/Sources/NIONFS3/NFSTypes+Lookup.swift b/Sources/NIONFS3/NFSTypes+Lookup.swift index c81170a0..45176e7d 100644 --- a/Sources/NIONFS3/NFSTypes+Lookup.swift +++ b/Sources/NIONFS3/NFSTypes+Lookup.swift @@ -61,12 +61,12 @@ extension ByteBuffer { } @discardableResult public mutating func writeNFS3CallLookup(_ call: NFS3CallLookup) -> Int { - return self.writeNFS3FileHandle(call.dir) - + self.writeNFS3String(call.name) + self.writeNFS3FileHandle(call.dir) + + self.writeNFS3String(call.name) } public mutating func readNFS3ReplyLookup() throws -> NFS3ReplyLookup { - return NFS3ReplyLookup( + NFS3ReplyLookup( result: try self.readNFS3Result( readOkay: { buffer in let fileHandle = try buffer.readNFS3FileHandle() @@ -84,7 +84,8 @@ extension ByteBuffer { try buffer.readNFS3FileAttr() } return NFS3ReplyLookup.Fail(dirAttributes: attrs) - }) + } + ) ) } @@ -93,17 +94,20 @@ extension ByteBuffer { switch lookupResult.result { case .okay(let result): - bytesWritten += self.writeInteger(NFS3Status.ok.rawValue) - + self.writeNFS3FileHandle(result.fileHandle) + bytesWritten += + self.writeInteger(NFS3Status.ok.rawValue) + + self.writeNFS3FileHandle(result.fileHandle) if let attrs = result.attributes { - bytesWritten += self.writeInteger(1, as: UInt32.self) - + self.writeNFS3FileAttr(attrs) + bytesWritten += + self.writeInteger(1, as: UInt32.self) + + self.writeNFS3FileAttr(attrs) } else { bytesWritten += self.writeInteger(0, as: UInt32.self) } if let attrs = result.dirAttributes { - bytesWritten += self.writeInteger(1, as: UInt32.self) - + self.writeNFS3FileAttr(attrs) + bytesWritten += + self.writeInteger(1, as: UInt32.self) + + self.writeNFS3FileAttr(attrs) } else { bytesWritten += self.writeInteger(0, as: UInt32.self) } @@ -111,8 +115,9 @@ extension ByteBuffer { precondition(status != .ok) bytesWritten += self.writeInteger(status.rawValue) if let attrs = fail.dirAttributes { - bytesWritten += self.writeInteger(1, as: UInt32.self) - + self.writeNFS3FileAttr(attrs) + bytesWritten += + self.writeInteger(1, as: UInt32.self) + + self.writeNFS3FileAttr(attrs) } else { bytesWritten += self.writeInteger(0, as: UInt32.self) } diff --git a/Sources/NIONFS3/NFSTypes+Null.swift b/Sources/NIONFS3/NFSTypes+Null.swift index da0703fd..5fe28200 100644 --- a/Sources/NIONFS3/NFSTypes+Null.swift +++ b/Sources/NIONFS3/NFSTypes+Null.swift @@ -21,10 +21,10 @@ public struct NFS3CallNull: Hashable & Sendable { extension ByteBuffer { public mutating func readNFS3CallNull() throws -> NFS3CallNull { - return NFS3CallNull() + NFS3CallNull() } @discardableResult public mutating func writeNFS3CallNull(_ call: NFS3CallNull) -> Int { - return 0 + 0 } } diff --git a/Sources/NIONFS3/NFSTypes+PathConf.swift b/Sources/NIONFS3/NFSTypes+PathConf.swift index afb0fd70..bbb418f0 100644 --- a/Sources/NIONFS3/NFSTypes+PathConf.swift +++ b/Sources/NIONFS3/NFSTypes+PathConf.swift @@ -29,7 +29,15 @@ public struct NFS3ReplyPathConf: Hashable & Sendable { } public struct Okay: Hashable & Sendable { - public init(attributes: NFS3FileAttr?, linkMax: UInt32, nameMax: UInt32, noTrunc: NFS3Bool, chownRestricted: NFS3Bool, caseInsensitive: NFS3Bool, casePreserving: NFS3Bool) { + public init( + attributes: NFS3FileAttr?, + linkMax: UInt32, + nameMax: UInt32, + noTrunc: NFS3Bool, + chownRestricted: NFS3Bool, + caseInsensitive: NFS3Bool, + casePreserving: NFS3Bool + ) { self.attributes = attributes self.linkMax = linkMax self.nameMax = nameMax @@ -70,30 +78,37 @@ extension ByteBuffer { } public mutating func readNFS3ReplyPathConf() throws -> NFS3ReplyPathConf { - return NFS3ReplyPathConf( + NFS3ReplyPathConf( result: try self.readNFS3Result( readOkay: { buffer in let attrs = try buffer.readNFS3Optional { buffer in try buffer.readNFS3FileAttr() } - guard let values = buffer.readMultipleIntegers(as: (UInt32, UInt32, UInt32, UInt32, UInt32, UInt32).self) else { + guard + let values = buffer.readMultipleIntegers( + as: (UInt32, UInt32, UInt32, UInt32, UInt32, UInt32).self + ) + else { throw NFS3Error.illegalRPCTooShort } - return NFS3ReplyPathConf.Okay(attributes: attrs, - linkMax: values.0, - nameMax: values.1, - noTrunc: values.2 == 0 ? false : true, - chownRestricted: values.3 == 0 ? false : true, - caseInsensitive: values.4 == 0 ? false : true, - casePreserving: values.5 == 0 ? false : true) + return NFS3ReplyPathConf.Okay( + attributes: attrs, + linkMax: values.0, + nameMax: values.1, + noTrunc: values.2 == 0 ? false : true, + chownRestricted: values.3 == 0 ? false : true, + caseInsensitive: values.4 == 0 ? false : true, + casePreserving: values.5 == 0 ? false : true + ) }, readFail: { buffer in let attrs = try buffer.readNFS3Optional { buffer in try buffer.readNFS3FileAttr() } return NFS3ReplyPathConf.Fail(attributes: attrs) - }) + } + ) ) } @@ -102,15 +117,16 @@ extension ByteBuffer { switch pathconf.result { case .okay(let pathconf): - bytesWritten += self.writeNFS3Optional(pathconf.attributes, writer: { $0.writeNFS3FileAttr($1) }) - + self.writeMultipleIntegers( - pathconf.linkMax, - pathconf.nameMax, - pathconf.noTrunc ? UInt32(1) : 0, - pathconf.chownRestricted ? UInt32(1) : 0, - pathconf.caseInsensitive ? UInt32(1) : 0, - pathconf.casePreserving ? UInt32(1) : 0 - ) + bytesWritten += + self.writeNFS3Optional(pathconf.attributes, writer: { $0.writeNFS3FileAttr($1) }) + + self.writeMultipleIntegers( + pathconf.linkMax, + pathconf.nameMax, + pathconf.noTrunc ? UInt32(1) : 0, + pathconf.chownRestricted ? UInt32(1) : 0, + pathconf.caseInsensitive ? UInt32(1) : 0, + pathconf.casePreserving ? UInt32(1) : 0 + ) case .fail(_, let fail): bytesWritten += self.writeNFS3Optional(fail.attributes, writer: { $0.writeNFS3FileAttr($1) }) } diff --git a/Sources/NIONFS3/NFSTypes+Read.swift b/Sources/NIONFS3/NFSTypes+Read.swift index c43b3940..5753d8b5 100644 --- a/Sources/NIONFS3/NFSTypes+Read.swift +++ b/Sources/NIONFS3/NFSTypes+Read.swift @@ -64,18 +64,20 @@ extension ByteBuffer { throw NFS3Error.illegalRPCTooShort } - return NFS3CallRead(fileHandle: fileHandle, - offset: .init(rawValue: values.0), - count: .init(rawValue: values.1)) + return NFS3CallRead( + fileHandle: fileHandle, + offset: .init(rawValue: values.0), + count: .init(rawValue: values.1) + ) } @discardableResult public mutating func writeNFS3CallRead(_ call: NFS3CallRead) -> Int { - return self.writeNFS3FileHandle(call.fileHandle) - + self.writeMultipleIntegers(call.offset.rawValue, call.count.rawValue) + self.writeNFS3FileHandle(call.fileHandle) + + self.writeMultipleIntegers(call.offset.rawValue, call.count.rawValue) } public mutating func readNFS3ReplyRead() throws -> NFS3ReplyRead { - return NFS3ReplyRead( + NFS3ReplyRead( result: try self.readNFS3Result( readOkay: { buffer in let attrs = try buffer.readNFS3Optional { buffer in @@ -85,17 +87,20 @@ extension ByteBuffer { throw NFS3Error.illegalRPCTooShort } let bytes = try buffer.readNFS3Blob() - return NFS3ReplyRead.Okay(attributes: attrs, - count: NFS3Count(rawValue: values.0), - eof: values.1 == 0 ? false : true, - data: bytes) + return NFS3ReplyRead.Okay( + attributes: attrs, + count: NFS3Count(rawValue: values.0), + eof: values.1 == 0 ? false : true, + data: bytes + ) }, readFail: { buffer in let attrs = try buffer.readNFS3Optional { buffer in try buffer.readNFS3FileAttr() } return NFS3ReplyRead.Fail(attributes: attrs) - }) + } + ) ) } @@ -124,7 +129,7 @@ extension ByteBuffer { return 0 case .writeBlob(let blob, numberOfFillBytes: let fillBytes): return self.writeImmutableBuffer(blob) - + self.writeRepeatingByte(0x41, count: fillBytes) + + self.writeRepeatingByte(0x41, count: fillBytes) } } } diff --git a/Sources/NIONFS3/NFSTypes+ReadDir.swift b/Sources/NIONFS3/NFSTypes+ReadDir.swift index 4b4f2205..5fa3b588 100644 --- a/Sources/NIONFS3/NFSTypes+ReadDir.swift +++ b/Sources/NIONFS3/NFSTypes+ReadDir.swift @@ -16,8 +16,12 @@ import NIOCore // MARK: - ReadDir public struct NFS3CallReadDir: Hashable & Sendable { - public init(fileHandle: NFS3FileHandle, cookie: NFS3Cookie, cookieVerifier: NFS3CookieVerifier, - maxResultByteCount: NFS3Count) { + public init( + fileHandle: NFS3FileHandle, + cookie: NFS3Cookie, + cookieVerifier: NFS3CookieVerifier, + maxResultByteCount: NFS3Count + ) { self.fileHandle = fileHandle self.cookie = cookie self.cookieVerifier = cookieVerifier @@ -48,7 +52,12 @@ public struct NFS3ReplyReadDir: Hashable & Sendable { } public struct Okay: Hashable & Sendable { - public init(dirAttributes: NFS3FileAttr? = nil, cookieVerifier: NFS3CookieVerifier, entries: [NFS3ReplyReadDir.Entry], eof: NFS3Bool) { + public init( + dirAttributes: NFS3FileAttr? = nil, + cookieVerifier: NFS3CookieVerifier, + entries: [NFS3ReplyReadDir.Entry], + eof: NFS3Bool + ) { self.dirAttributes = dirAttributes self.cookieVerifier = cookieVerifier self.entries = entries @@ -79,19 +88,21 @@ extension ByteBuffer { let cookieVerifier = try self.readNFS3CookieVerifier() let maxResultByteCount = try self.readNFS3Count() - return NFS3CallReadDir(fileHandle: dir, - cookie: cookie, - cookieVerifier: cookieVerifier, - maxResultByteCount: maxResultByteCount) + return NFS3CallReadDir( + fileHandle: dir, + cookie: cookie, + cookieVerifier: cookieVerifier, + maxResultByteCount: maxResultByteCount + ) } @discardableResult public mutating func writeNFS3CallReadDir(_ call: NFS3CallReadDir) -> Int { - return self.writeNFS3FileHandle(call.fileHandle) - + self.writeMultipleIntegers( - call.cookie.rawValue, - call.cookieVerifier.rawValue, - call.maxResultByteCount.rawValue - ) + self.writeNFS3FileHandle(call.fileHandle) + + self.writeMultipleIntegers( + call.cookie.rawValue, + call.cookieVerifier.rawValue, + call.maxResultByteCount.rawValue + ) } private mutating func readReadDirEntry() throws -> NFS3ReplyReadDir.Entry { @@ -99,19 +110,21 @@ extension ByteBuffer { let fileName = try self.readNFS3String() let cookie = try self.readNFS3Cookie() - return NFS3ReplyReadDir.Entry(fileID: fileID, - fileName: fileName, - cookie: cookie) + return NFS3ReplyReadDir.Entry( + fileID: fileID, + fileName: fileName, + cookie: cookie + ) } private mutating func writeReadDirEntry(_ entry: NFS3ReplyReadDir.Entry) -> Int { - return self.writeNFS3FileID(entry.fileID) - + self.writeNFS3String(entry.fileName) - + self.writeNFS3Cookie(entry.cookie) + self.writeNFS3FileID(entry.fileID) + + self.writeNFS3String(entry.fileName) + + self.writeNFS3Cookie(entry.cookie) } public mutating func readNFS3ReplyReadDir() throws -> NFS3ReplyReadDir { - return NFS3ReplyReadDir( + NFS3ReplyReadDir( result: try self.readNFS3Result( readOkay: { buffer in let dirAttributes = try buffer.readNFS3Optional { try $0.readNFS3FileAttr() } @@ -123,16 +136,19 @@ extension ByteBuffer { } let eof = try buffer.readNFS3Bool() - return NFS3ReplyReadDir.Okay(dirAttributes: dirAttributes, - cookieVerifier: cookieVerifier, - entries: entries, - eof: eof) + return NFS3ReplyReadDir.Okay( + dirAttributes: dirAttributes, + cookieVerifier: cookieVerifier, + entries: entries, + eof: eof + ) }, readFail: { buffer in let attrs = try buffer.readNFS3Optional { try $0.readNFS3FileAttr() } return NFS3ReplyReadDir.Fail(dirAttributes: attrs) - }) + } + ) ) } @@ -140,19 +156,23 @@ extension ByteBuffer { var bytesWritten = 0 switch rd.result { case .okay(let result): - bytesWritten += self.writeInteger(NFS3Status.ok.rawValue) - + self.writeNFS3Optional(result.dirAttributes, writer: { $0.writeNFS3FileAttr($1) }) - + self.writeNFS3CookieVerifier(result.cookieVerifier) + bytesWritten += + self.writeInteger(NFS3Status.ok.rawValue) + + self.writeNFS3Optional(result.dirAttributes, writer: { $0.writeNFS3FileAttr($1) }) + + self.writeNFS3CookieVerifier(result.cookieVerifier) for entry in result.entries { - bytesWritten += self.writeInteger(1, as: UInt32.self) - + self.writeReadDirEntry(entry) + bytesWritten += + self.writeInteger(1, as: UInt32.self) + + self.writeReadDirEntry(entry) } - bytesWritten += self.writeInteger(0, as: UInt32.self) - + self.writeInteger(result.eof == true ? 1 : 0, as: UInt32.self) + bytesWritten += + self.writeInteger(0, as: UInt32.self) + + self.writeInteger(result.eof == true ? 1 : 0, as: UInt32.self) case .fail(let status, let fail): precondition(status != .ok) - bytesWritten += self.writeInteger(status.rawValue) - + self.writeNFS3Optional(fail.dirAttributes, writer: { $0.writeNFS3FileAttr($1) }) + bytesWritten += + self.writeInteger(status.rawValue) + + self.writeNFS3Optional(fail.dirAttributes, writer: { $0.writeNFS3FileAttr($1) }) } return bytesWritten } diff --git a/Sources/NIONFS3/NFSTypes+ReadDirPlus.swift b/Sources/NIONFS3/NFSTypes+ReadDirPlus.swift index 17757d39..7e6160ae 100644 --- a/Sources/NIONFS3/NFSTypes+ReadDirPlus.swift +++ b/Sources/NIONFS3/NFSTypes+ReadDirPlus.swift @@ -16,8 +16,13 @@ import NIOCore // MARK: - ReadDirPlus public struct NFS3CallReadDirPlus: Hashable & Sendable { - public init(fileHandle: NFS3FileHandle, cookie: NFS3Cookie, cookieVerifier: NFS3CookieVerifier, - dirCount: NFS3Count, maxCount: NFS3Count) { + public init( + fileHandle: NFS3FileHandle, + cookie: NFS3Cookie, + cookieVerifier: NFS3CookieVerifier, + dirCount: NFS3Count, + maxCount: NFS3Count + ) { self.fileHandle = fileHandle self.cookie = cookie self.cookieVerifier = cookieVerifier @@ -38,7 +43,13 @@ public struct NFS3ReplyReadDirPlus: Hashable & Sendable { } public struct Entry: Hashable & Sendable { - public init(fileID: NFS3FileID, fileName: String, cookie: NFS3Cookie, nameAttributes: NFS3FileAttr? = nil, nameHandle: NFS3FileHandle? = nil) { + public init( + fileID: NFS3FileID, + fileName: String, + cookie: NFS3Cookie, + nameAttributes: NFS3FileAttr? = nil, + nameHandle: NFS3FileHandle? = nil + ) { self.fileID = fileID self.fileName = fileName self.cookie = cookie @@ -54,7 +65,12 @@ public struct NFS3ReplyReadDirPlus: Hashable & Sendable { } public struct Okay: Hashable & Sendable { - public init(dirAttributes: NFS3FileAttr? = nil, cookieVerifier: NFS3CookieVerifier, entries: [NFS3ReplyReadDirPlus.Entry], eof: NFS3Bool) { + public init( + dirAttributes: NFS3FileAttr? = nil, + cookieVerifier: NFS3CookieVerifier, + entries: [NFS3ReplyReadDirPlus.Entry], + eof: NFS3Bool + ) { self.dirAttributes = dirAttributes self.cookieVerifier = cookieVerifier self.entries = entries @@ -86,21 +102,23 @@ extension ByteBuffer { let dirCount = try self.readNFS3Count() let maxCount = try self.readNFS3Count() - return NFS3CallReadDirPlus(fileHandle: dir, - cookie: cookie, - cookieVerifier: cookieVerifier, - dirCount: dirCount, - maxCount: maxCount) + return NFS3CallReadDirPlus( + fileHandle: dir, + cookie: cookie, + cookieVerifier: cookieVerifier, + dirCount: dirCount, + maxCount: maxCount + ) } @discardableResult public mutating func writeNFS3CallReadDirPlus(_ call: NFS3CallReadDirPlus) -> Int { - return self.writeNFS3FileHandle(call.fileHandle) - + self.writeMultipleIntegers( - call.cookie.rawValue, - call.cookieVerifier.rawValue, - call.dirCount.rawValue, - call.maxCount.rawValue - ) + self.writeNFS3FileHandle(call.fileHandle) + + self.writeMultipleIntegers( + call.cookie.rawValue, + call.cookieVerifier.rawValue, + call.dirCount.rawValue, + call.maxCount.rawValue + ) } private mutating func readReadDirPlusEntry() throws -> NFS3ReplyReadDirPlus.Entry { @@ -110,23 +128,25 @@ extension ByteBuffer { let nameAttrs = try self.readNFS3Optional { try $0.readNFS3FileAttr() } let nameHandle = try self.readNFS3Optional { try $0.readNFS3FileHandle() } - return NFS3ReplyReadDirPlus.Entry(fileID: fileID, - fileName: fileName, - cookie: cookie, - nameAttributes: nameAttrs, - nameHandle: nameHandle) + return NFS3ReplyReadDirPlus.Entry( + fileID: fileID, + fileName: fileName, + cookie: cookie, + nameAttributes: nameAttrs, + nameHandle: nameHandle + ) } private mutating func writeReadDirPlusEntry(_ entry: NFS3ReplyReadDirPlus.Entry) -> Int { - return self.writeNFS3FileID(entry.fileID) - + self.writeNFS3String(entry.fileName) - + self.writeNFS3Cookie(entry.cookie) - + self.writeNFS3Optional(entry.nameAttributes, writer: { $0.writeNFS3FileAttr($1) }) - + self.writeNFS3Optional(entry.nameHandle, writer: { $0.writeNFS3FileHandle($1) }) + self.writeNFS3FileID(entry.fileID) + + self.writeNFS3String(entry.fileName) + + self.writeNFS3Cookie(entry.cookie) + + self.writeNFS3Optional(entry.nameAttributes, writer: { $0.writeNFS3FileAttr($1) }) + + self.writeNFS3Optional(entry.nameHandle, writer: { $0.writeNFS3FileHandle($1) }) } public mutating func readNFS3ReplyReadDirPlus() throws -> NFS3ReplyReadDirPlus { - return NFS3ReplyReadDirPlus( + NFS3ReplyReadDirPlus( result: try self.readNFS3Result( readOkay: { buffer in let attrs = try buffer.readNFS3Optional { try $0.readNFS3FileAttr() } @@ -138,16 +158,19 @@ extension ByteBuffer { } let eof = try buffer.readNFS3Bool() - return NFS3ReplyReadDirPlus.Okay(dirAttributes: attrs, - cookieVerifier: cookieVerifier, - entries: entries, - eof: eof) + return NFS3ReplyReadDirPlus.Okay( + dirAttributes: attrs, + cookieVerifier: cookieVerifier, + entries: entries, + eof: eof + ) }, readFail: { buffer in let attrs = try buffer.readNFS3Optional { try $0.readNFS3FileAttr() } return NFS3ReplyReadDirPlus.Fail(dirAttributes: attrs) - }) + } + ) ) } @@ -156,19 +179,23 @@ extension ByteBuffer { switch rdp.result { case .okay(let result): - bytesWritten += self.writeInteger(NFS3Status.ok.rawValue) - + self.writeNFS3Optional(result.dirAttributes, writer: { $0.writeNFS3FileAttr($1) }) - + self.writeNFS3CookieVerifier(result.cookieVerifier) + bytesWritten += + self.writeInteger(NFS3Status.ok.rawValue) + + self.writeNFS3Optional(result.dirAttributes, writer: { $0.writeNFS3FileAttr($1) }) + + self.writeNFS3CookieVerifier(result.cookieVerifier) for entry in result.entries { - bytesWritten += self.writeInteger(1, as: UInt32.self) - + self.writeReadDirPlusEntry(entry) + bytesWritten += + self.writeInteger(1, as: UInt32.self) + + self.writeReadDirPlusEntry(entry) } - bytesWritten += self.writeInteger(0, as: UInt32.self) - + self.writeInteger(result.eof == true ? 1 : 0, as: UInt32.self) + bytesWritten += + self.writeInteger(0, as: UInt32.self) + + self.writeInteger(result.eof == true ? 1 : 0, as: UInt32.self) case .fail(let status, let fail): precondition(status != .ok) - bytesWritten += self.writeInteger(status.rawValue) - + self.writeNFS3Optional(fail.dirAttributes, writer: { $0.writeNFS3FileAttr($1) }) + bytesWritten += + self.writeInteger(status.rawValue) + + self.writeNFS3Optional(fail.dirAttributes, writer: { $0.writeNFS3FileAttr($1) }) } return bytesWritten diff --git a/Sources/NIONFS3/NFSTypes+Readlink.swift b/Sources/NIONFS3/NFSTypes+Readlink.swift index 230195e8..5ff165e0 100644 --- a/Sources/NIONFS3/NFSTypes+Readlink.swift +++ b/Sources/NIONFS3/NFSTypes+Readlink.swift @@ -61,7 +61,7 @@ extension ByteBuffer { } public mutating func readNFS3ReplyReadlink() throws -> NFS3ReplyReadlink { - return NFS3ReplyReadlink( + NFS3ReplyReadlink( result: try self.readNFS3Result( readOkay: { buffer in let attrs = try buffer.readNFS3Optional { try $0.readNFS3FileAttr() } @@ -72,7 +72,9 @@ extension ByteBuffer { readFail: { buffer in let attrs = try buffer.readNFS3Optional { try $0.readNFS3FileAttr() } return NFS3ReplyReadlink.Fail(symlinkAttributes: attrs) - })) + } + ) + ) } @discardableResult public mutating func writeNFS3ReplyReadlink(_ reply: NFS3ReplyReadlink) -> Int { @@ -80,8 +82,9 @@ extension ByteBuffer { switch reply.result { case .okay(let okay): - bytesWritten += self.writeNFS3Optional(okay.symlinkAttributes, writer: { $0.writeNFS3FileAttr($1) }) - + self.writeNFS3String(okay.target) + bytesWritten += + self.writeNFS3Optional(okay.symlinkAttributes, writer: { $0.writeNFS3FileAttr($1) }) + + self.writeNFS3String(okay.target) case .fail(_, let fail): bytesWritten += self.writeNFS3Optional(fail.symlinkAttributes, writer: { $0.writeNFS3FileAttr($1) }) } diff --git a/Sources/NIONFS3/NFSTypes+SetAttr.swift b/Sources/NIONFS3/NFSTypes+SetAttr.swift index 4ee43cfe..110b871e 100644 --- a/Sources/NIONFS3/NFSTypes+SetAttr.swift +++ b/Sources/NIONFS3/NFSTypes+SetAttr.swift @@ -23,7 +23,14 @@ public struct NFS3CallSetattr: Hashable & Sendable { } public struct Attributes: Hashable & Sendable { - public init(mode: NFS3FileMode? = nil, uid: NFS3UID? = nil, gid: NFS3GID? = nil, size: NFS3Size? = nil, atime: NFS3Time? = nil, mtime: NFS3Time? = nil) { + public init( + mode: NFS3FileMode? = nil, + uid: NFS3UID? = nil, + gid: NFS3GID? = nil, + size: NFS3Size? = nil, + atime: NFS3Time? = nil, + mtime: NFS3Time? = nil + ) { self.mode = mode self.uid = uid self.gid = gid @@ -82,12 +89,12 @@ extension ByteBuffer { } private mutating func writeNFS3CallSetattrAttributes(_ attrs: NFS3CallSetattr.Attributes) -> Int { - return self.writeNFS3Optional(attrs.mode, writer: { $0.writeNFS3FileMode($1) }) - + self.writeNFS3Optional(attrs.uid, writer: { $0.writeNFS3UID($1) }) - + self.writeNFS3Optional(attrs.gid, writer: { $0.writeNFS3GID($1) }) - + self.writeNFS3Optional(attrs.size, writer: { $0.writeNFS3Size($1) }) - + self.writeNFS3Optional(attrs.atime, writer: { $0.writeNFS3Time($1) }) - + self.writeNFS3Optional(attrs.mtime, writer: { $0.writeNFS3Time($1) }) + self.writeNFS3Optional(attrs.mode, writer: { $0.writeNFS3FileMode($1) }) + + self.writeNFS3Optional(attrs.uid, writer: { $0.writeNFS3UID($1) }) + + self.writeNFS3Optional(attrs.gid, writer: { $0.writeNFS3GID($1) }) + + self.writeNFS3Optional(attrs.size, writer: { $0.writeNFS3Size($1) }) + + self.writeNFS3Optional(attrs.atime, writer: { $0.writeNFS3Time($1) }) + + self.writeNFS3Optional(attrs.mtime, writer: { $0.writeNFS3Time($1) }) } public mutating func readNFS3CallSetattr() throws -> NFS3CallSetattr { @@ -99,20 +106,22 @@ extension ByteBuffer { } @discardableResult public mutating func writeNFS3CallSetattr(_ call: NFS3CallSetattr) -> Int { - return self.writeNFS3FileHandle(call.object) - + self.writeNFS3CallSetattrAttributes(call.newAttributes) - + self.writeNFS3Optional(call.guard, writer: { $0.writeNFS3Time($1) }) + self.writeNFS3FileHandle(call.object) + + self.writeNFS3CallSetattrAttributes(call.newAttributes) + + self.writeNFS3Optional(call.guard, writer: { $0.writeNFS3Time($1) }) } public mutating func readNFS3ReplySetattr() throws -> NFS3ReplySetattr { - return NFS3ReplySetattr( + NFS3ReplySetattr( result: try self.readNFS3Result( readOkay: { buffer in - return NFS3ReplySetattr.Okay(wcc: try buffer.readNFS3WeakCacheConsistencyData()) + NFS3ReplySetattr.Okay(wcc: try buffer.readNFS3WeakCacheConsistencyData()) }, readFail: { buffer in - return NFS3ReplySetattr.Fail(wcc: try buffer.readNFS3WeakCacheConsistencyData()) - })) + NFS3ReplySetattr.Fail(wcc: try buffer.readNFS3WeakCacheConsistencyData()) + } + ) + ) } @discardableResult public mutating func writeNFS3ReplySetattr(_ reply: NFS3ReplySetattr) -> Int { diff --git a/Sources/NIONFS3/RPCTypes.swift b/Sources/NIONFS3/RPCTypes.swift index 6f4aa364..a84c6a3d 100644 --- a/Sources/NIONFS3/RPCTypes.swift +++ b/Sources/NIONFS3/RPCTypes.swift @@ -71,7 +71,15 @@ public enum RPCMessage: Hashable & Sendable { /// RFC 5531: struct call_body public struct RPCCall: Hashable & Sendable { - public init(xid: UInt32, rpcVersion: UInt32, program: UInt32, programVersion: UInt32, procedure: UInt32, credentials: RPCCredentials, verifier: RPCOpaqueAuth) { + public init( + xid: UInt32, + rpcVersion: UInt32, + program: UInt32, + programVersion: UInt32, + procedure: UInt32, + credentials: RPCCredentials, + verifier: RPCOpaqueAuth + ) { self.xid = xid self.rpcVersion = rpcVersion self.program = program @@ -82,7 +90,7 @@ public struct RPCCall: Hashable & Sendable { } public var xid: UInt32 - public var rpcVersion: UInt32 // must be 2 + public var rpcVersion: UInt32 // must be 2 public var program: UInt32 public var programVersion: UInt32 public var procedure: UInt32 @@ -93,7 +101,7 @@ public struct RPCCall: Hashable & Sendable { extension RPCCall { public var programAndProcedure: RPCNFS3ProcedureID { get { - return RPCNFS3ProcedureID(program: self.program, procedure: self.procedure) + RPCNFS3ProcedureID(program: self.program, procedure: self.procedure) } set { self.program = newValue.program @@ -147,21 +155,21 @@ public struct RPCAcceptedReply: Hashable & Sendable { } public enum RPCAuthStatus: UInt32, Hashable & Sendable { - case ok = 0 /* success */ - case badCredentials = 1 /* bad credential (seal broken) */ - case rejectedCredentials = 2 /* client must begin new session */ - case badVerifier = 3 /* bad verifier (seal broken) */ - case rejectedVerifier = 4 /* verifier expired or replayed */ - case rejectedForSecurityReasons = 5 /* rejected for security reasons */ - case invalidResponseVerifier = 6 /* bogus response verifier */ - case failedForUnknownReason = 7 /* reason unknown */ - case kerberosError = 8 /* kerberos generic error */ - case credentialExpired = 9 /* time of credential expired */ - case ticketFileProblem = 10 /* problem with ticket file */ - case cannotDecodeAuthenticator = 11 /* can't decode authenticator */ - case illegalNetworkAddressInTicket = 12 /* wrong net address in ticket */ - case noCredentialsForUser = 13 /* no credentials for user */ - case problemWithGSSContext = 14 /* problem with context */ + case ok = 0 // success + case badCredentials = 1 // bad credential (seal broken) + case rejectedCredentials = 2 // client must begin new session + case badVerifier = 3 // bad verifier (seal broken) + case rejectedVerifier = 4 // verifier expired or replayed + case rejectedForSecurityReasons = 5 // rejected for security reasons + case invalidResponseVerifier = 6 // bogus response verifier + case failedForUnknownReason = 7 // reason unknown + case kerberosError = 8 // kerberos generic error + case credentialExpired = 9 // time of credential expired + case ticketFileProblem = 10 // problem with ticket file + case cannotDecodeAuthenticator = 11 // can't decode authenticator + case illegalNetworkAddressInTicket = 12 // wrong net address in ticket + case noCredentialsForUser = 13 // no credentials for user + case problemWithGSSContext = 14 // problem with context } public enum RPCRejectedReply: Hashable & Sendable { diff --git a/Sources/NIOSOCKS/Channel Handlers/SOCKSClientHandler.swift b/Sources/NIOSOCKS/Channel Handlers/SOCKSClientHandler.swift index 313bca57..45a6f33e 100644 --- a/Sources/NIOSOCKS/Channel Handlers/SOCKSClientHandler.swift +++ b/Sources/NIOSOCKS/Channel Handlers/SOCKSClientHandler.swift @@ -27,31 +27,31 @@ public final class SOCKSClientHandler: ChannelDuplexHandler { public typealias OutboundIn = ByteBuffer /// Sends `ByteBuffer` to the next outbound stage. public typealias OutboundOut = ByteBuffer - + private let targetAddress: SOCKSAddress - + private var state: ClientStateMachine private var removalToken: ChannelHandlerContext.RemovalToken? private var inboundBuffer: ByteBuffer? - + private var bufferedWrites: MarkedCircularBuffer<(NIOAny, EventLoopPromise?)> = .init(initialCapacity: 8) - + /// Creates a new ``SOCKSClientHandler`` that connects to a server /// and instructs the server to connect to `targetAddress`. /// - parameter targetAddress: The desired end point - note that only IPv4, IPv6, and FQDNs are supported. public init(targetAddress: SOCKSAddress) { - + switch targetAddress { case .address(.unixDomainSocket): preconditionFailure("UNIX domain sockets are not supported.") case .domain, .address(.v4), .address(.v6): break } - + self.state = ClientStateMachine() self.targetAddress = targetAddress } - + public func channelActive(context: ChannelHandlerContext) { self.beginHandshake(context: context) } @@ -61,17 +61,17 @@ public final class SOCKSClientHandler: ChannelDuplexHandler { public func handlerAdded(context: ChannelHandlerContext) { self.beginHandshake(context: context) } - + public func channelRead(context: ChannelHandlerContext, data: NIOAny) { - + // if we've established the connection then forward on the data if self.state.proxyEstablished { context.fireChannelRead(data) return } - + var inboundBuffer = self.unwrapInboundIn(data) - + self.inboundBuffer.setOrWriteBuffer(&inboundBuffer) do { // Safe to bang, `setOrWrite` above means there will @@ -83,7 +83,7 @@ public final class SOCKSClientHandler: ChannelDuplexHandler { context.close(promise: nil) } } - + public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { if self.state.proxyEstablished && self.bufferedWrites.count == 0 { context.write(data, promise: promise) @@ -91,7 +91,7 @@ public final class SOCKSClientHandler: ChannelDuplexHandler { self.bufferedWrites.append((data, promise)) } } - + private func writeBufferedData(context: ChannelHandlerContext) { guard self.state.proxyEstablished else { return @@ -100,14 +100,14 @@ public final class SOCKSClientHandler: ChannelDuplexHandler { let (data, promise) = self.bufferedWrites.removeFirst() context.write(data, promise: promise) } - context.flush() // safe to flush otherwise we wouldn't have the mark - + context.flush() // safe to flush otherwise we wouldn't have the mark + while !self.bufferedWrites.isEmpty { let (data, promise) = self.bufferedWrites.removeFirst() context.write(data, promise: promise) } } - + public func flush(context: ChannelHandlerContext) { self.bufferedWrites.mark() self.writeBufferedData(context: context) @@ -118,7 +118,7 @@ public final class SOCKSClientHandler: ChannelDuplexHandler { extension SOCKSClientHandler: Sendable {} extension SOCKSClientHandler { - + private func beginHandshake(context: ChannelHandlerContext) { guard context.channel.isActive, self.state.shouldBeginHandshake else { return @@ -130,11 +130,11 @@ extension SOCKSClientHandler { context.close(promise: nil) } } - + private func handleAction(_ action: ClientAction, context: ChannelHandlerContext) throws { switch action { case .waitForMoreData: - break // do nothing, we've already buffered the data + break // do nothing, we've already buffered the data case .sendGreeting: try self.handleActionSendClientGreeting(context: context) case .sendRequest: @@ -143,30 +143,30 @@ extension SOCKSClientHandler { self.handleProxyEstablished(context: context) } } - + private func handleActionSendClientGreeting(context: ChannelHandlerContext) throws { - let greeting = ClientGreeting(methods: [.noneRequired]) // no authentication currently supported - let capacity = 3 // [version, #methods, methods...] + let greeting = ClientGreeting(methods: [.noneRequired]) // no authentication currently supported + let capacity = 3 // [version, #methods, methods...] var buffer = context.channel.allocator.buffer(capacity: capacity) buffer.writeClientGreeting(greeting) try self.state.sendClientGreeting(greeting) context.writeAndFlush(self.wrapOutboundOut(buffer), promise: nil) } - + private func handleProxyEstablished(context: ChannelHandlerContext) { context.fireUserInboundEventTriggered(SOCKSProxyEstablishedEvent()) - + self.emptyInboundAndOutboundBuffer(context: context) - + if let removalToken = self.removalToken { context.leavePipeline(removalToken: removalToken) } } - + private func handleActionSendRequest(context: ChannelHandlerContext) throws { let request = SOCKSRequest(command: .connect, addressType: self.targetAddress) try self.state.sendClientRequest(request) - + // the client request is always 6 bytes + the address info // [protocol_version, command, reserved, address type,
, port (2bytes)] let capacity = 6 + self.targetAddress.size @@ -174,7 +174,7 @@ extension SOCKSClientHandler { buffer.writeClientRequest(request) context.writeAndFlush(self.wrapOutboundOut(buffer), promise: nil) } - + private func emptyInboundAndOutboundBuffer(context: ChannelHandlerContext) { if let inboundBuffer = self.inboundBuffer, inboundBuffer.readableBytes > 0 { // after the SOCKS handshake message we already received further bytes. @@ -182,20 +182,20 @@ extension SOCKSClientHandler { self.inboundBuffer = nil context.fireChannelRead(self.wrapInboundOut(inboundBuffer)) } - + // If we have any buffered writes, we must send them before we are removed from the pipeline self.writeBufferedData(context: context) } } extension SOCKSClientHandler: RemovableChannelHandler { - + public func removeHandler(context: ChannelHandlerContext, removalToken: ChannelHandlerContext.RemovalToken) { guard self.state.proxyEstablished else { self.removalToken = removalToken return } - + // We must clear the buffers here before we are removed, since the // handler removal may be triggered as a side effect of the // `SOCKSProxyEstablishedEvent`. In this case we may end up here, @@ -204,7 +204,7 @@ extension SOCKSClientHandler: RemovableChannelHandler { self.emptyInboundAndOutboundBuffer(context: context) context.leavePipeline(removalToken: removalToken) } - + } /// A `Channel` user event that is sent when a SOCKS connection has been established diff --git a/Sources/NIOSOCKS/Channel Handlers/SOCKSServerHandshakeHandler.swift b/Sources/NIOSOCKS/Channel Handlers/SOCKSServerHandshakeHandler.swift index f9b7c30a..b3f4ad6c 100644 --- a/Sources/NIOSOCKS/Channel Handlers/SOCKSServerHandshakeHandler.swift +++ b/Sources/NIOSOCKS/Channel Handlers/SOCKSServerHandshakeHandler.swift @@ -28,27 +28,27 @@ public final class SOCKSServerHandshakeHandler: ChannelDuplexHandler, RemovableC public typealias OutboundIn = ServerMessage /// Passes `ByteBuffer` to the next pipeline stage when sending data. public typealias OutboundOut = ByteBuffer - + var inboundBuffer: ByteBuffer? var stateMachine: ServerStateMachine - + public init() { self.stateMachine = ServerStateMachine() } - + public func channelRead(context: ChannelHandlerContext, data: NIOAny) { - + var message = self.unwrapInboundIn(data) self.inboundBuffer.setOrWriteBuffer(&message) - + if self.stateMachine.proxyEstablished { return } - + do { // safe to bang inbound buffer, it's always written above guard let message = try self.stateMachine.receiveBuffer(&self.inboundBuffer!) else { - return // do nothing, we've buffered the data + return // do nothing, we've buffered the data } context.fireChannelRead(self.wrapInboundOut(message)) } catch { @@ -74,7 +74,7 @@ public final class SOCKSServerHandshakeHandler: ChannelDuplexHandler, RemovableC } context.fireChannelRead(.init(buffer)) } - + public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { do { let message = self.unwrapOutboundIn(data) @@ -88,23 +88,27 @@ public final class SOCKSServerHandshakeHandler: ChannelDuplexHandler, RemovableC outboundBuffer = try self.handleWriteAuthenticationData(data, complete: complete, context: context) } context.write(self.wrapOutboundOut(outboundBuffer), promise: promise) - + } catch { context.fireErrorCaught(error) promise?.fail(error) } } - + private func handleWriteSelectedAuthenticationMethod( - _ method: SelectedAuthenticationMethod, context: ChannelHandlerContext) throws -> ByteBuffer { + _ method: SelectedAuthenticationMethod, + context: ChannelHandlerContext + ) throws -> ByteBuffer { try stateMachine.sendAuthenticationMethod(method) var buffer = context.channel.allocator.buffer(capacity: 16) buffer.writeMethodSelection(method) return buffer } - + private func handleWriteResponse( - _ response: SOCKSResponse, context: ChannelHandlerContext) throws -> ByteBuffer { + _ response: SOCKSResponse, + context: ChannelHandlerContext + ) throws -> ByteBuffer { try stateMachine.sendServerResponse(response) if case .succeeded = response.reply { context.fireUserInboundEventTriggered(SOCKSProxyEstablishedEvent()) @@ -113,13 +117,16 @@ public final class SOCKSServerHandshakeHandler: ChannelDuplexHandler, RemovableC buffer.writeServerResponse(response) return buffer } - + private func handleWriteAuthenticationData( - _ data: ByteBuffer, complete: Bool, context: ChannelHandlerContext) throws -> ByteBuffer { + _ data: ByteBuffer, + complete: Bool, + context: ChannelHandlerContext + ) throws -> ByteBuffer { try self.stateMachine.sendAuthenticationData(data, complete: complete) return data } - + } @available(*, unavailable) diff --git a/Sources/NIOSOCKS/Messages/AuthenticationMethod.swift b/Sources/NIOSOCKS/Messages/AuthenticationMethod.swift index 125da033..353b048e 100644 --- a/Sources/NIOSOCKS/Messages/AuthenticationMethod.swift +++ b/Sources/NIOSOCKS/Messages/AuthenticationMethod.swift @@ -16,24 +16,24 @@ import NIOCore /// The SOCKS authentication method to use, defined in RFC 1928. public struct AuthenticationMethod: Hashable, Sendable { - + /// No authentication required public static let noneRequired = AuthenticationMethod(value: 0x00) - + /// Use GSSAPI public static let gssapi = AuthenticationMethod(value: 0x01) - + /// Username / password authentication public static let usernamePassword = AuthenticationMethod(value: 0x02) - + /// No acceptable authentication methods public static let noneAcceptable = AuthenticationMethod(value: 0xFF) - + /// The method identifier, valid values are in the range 0:255. public var value: UInt8 - + public init(value: UInt8) { self.value = value } - + } diff --git a/Sources/NIOSOCKS/Messages/ClientGreeting.swift b/Sources/NIOSOCKS/Messages/ClientGreeting.swift index 603ae075..85b527be 100644 --- a/Sources/NIOSOCKS/Messages/ClientGreeting.swift +++ b/Sources/NIOSOCKS/Messages/ClientGreeting.swift @@ -18,14 +18,14 @@ import NIOCore /// by providing an array of suggested authentication /// methods. public struct ClientGreeting: Hashable, Sendable { - + /// The protocol version. public let version: UInt8 = 5 - + /// The client-supported authentication methods. /// The SOCKS server will select one to use. public var methods: [AuthenticationMethod] - + /// Creates a new ``ClientGreeting`` /// - parameter methods: The client-supported authentication methods. public init(methods: [AuthenticationMethod]) { @@ -34,9 +34,9 @@ public struct ClientGreeting: Hashable, Sendable { } extension ByteBuffer { - + mutating func readClientGreeting() throws -> ClientGreeting? { - return try self.parseUnwindingIfNeeded { buffer in + try self.parseUnwindingIfNeeded { buffer in guard try buffer.readAndValidateProtocolVersion() != nil, let numMethods = buffer.readInteger(as: UInt8.self), @@ -44,23 +44,23 @@ extension ByteBuffer { else { return nil } - + // safe to bang as we've already checked the buffer size let methods = buffer.readBytes(length: Int(numMethods))!.map { AuthenticationMethod(value: $0) } return .init(methods: methods) } } - + @discardableResult mutating func writeClientGreeting(_ greeting: ClientGreeting) -> Int { var written = 0 written += self.writeInteger(greeting.version) written += self.writeInteger(UInt8(greeting.methods.count)) - + for method in greeting.methods { written += self.writeInteger(method.value) } - + return written } - + } diff --git a/Sources/NIOSOCKS/Messages/Errors.swift b/Sources/NIOSOCKS/Messages/Errors.swift index 169b317f..ddb8a1a9 100644 --- a/Sources/NIOSOCKS/Messages/Errors.swift +++ b/Sources/NIOSOCKS/Messages/Errors.swift @@ -14,21 +14,21 @@ /// Wrapper for SOCKS protcol error. public enum SOCKSError { - + /// The SOCKS client was in a different state to that required. public struct InvalidClientState: Error, Hashable { public init() { - + } } - + /// The SOCKS server was in a different state to that required. public struct InvalidServerState: Error, Hashable { public init() { - + } } - + /// The protocol version was something other than *5*. Note that /// we currently only supported SOCKv5. public struct InvalidProtocolVersion: Error, Hashable { @@ -66,7 +66,7 @@ public enum SOCKSError { /// The client and server were unable to agree on an authentication method. public struct NoValidAuthenticationMethod: Error, Hashable { public init() { - + } } @@ -77,12 +77,12 @@ public enum SOCKSError { self.reply = reply } } - + /// The client or server receieved data when it did not expect to. public struct UnexpectedRead: Error, Hashable { public init() { - + } } - + } diff --git a/Sources/NIOSOCKS/Messages/Helpers.swift b/Sources/NIOSOCKS/Messages/Helpers.swift index b8d544c6..1ff493d3 100644 --- a/Sources/NIOSOCKS/Messages/Helpers.swift +++ b/Sources/NIOSOCKS/Messages/Helpers.swift @@ -15,7 +15,7 @@ import NIOCore extension ByteBuffer { - + mutating func parseUnwindingIfNeeded(_ closure: (inout ByteBuffer) throws -> T?) rethrows -> T? { let save = self do { @@ -29,9 +29,9 @@ extension ByteBuffer { throw error } } - + mutating func readAndValidateProtocolVersion() throws -> UInt8? { - return try self.parseUnwindingIfNeeded { buffer -> UInt8? in + try self.parseUnwindingIfNeeded { buffer -> UInt8? in guard let version = buffer.readInteger(as: UInt8.self) else { return nil } @@ -41,9 +41,9 @@ extension ByteBuffer { return version } } - + mutating func readAndValidateReserved() throws -> UInt8? { - return try self.parseUnwindingIfNeeded { buffer -> UInt8? in + try self.parseUnwindingIfNeeded { buffer -> UInt8? in guard let reserved = buffer.readInteger(as: UInt8.self) else { return nil } @@ -53,5 +53,5 @@ extension ByteBuffer { return reserved } } - + } diff --git a/Sources/NIOSOCKS/Messages/Messages.swift b/Sources/NIOSOCKS/Messages/Messages.swift index d187329b..da8fc71e 100644 --- a/Sources/NIOSOCKS/Messages/Messages.swift +++ b/Sources/NIOSOCKS/Messages/Messages.swift @@ -16,33 +16,33 @@ import NIOCore /// Sent by the client and received by the server. public enum ClientMessage: Hashable, Sendable { - + /// Contains the proposed authentication methods. case greeting(ClientGreeting) - + /// Instructs the server of the target host, and the type of connection. case request(SOCKSRequest) - + /// Used to respond to server authentication challenges case authenticationData(ByteBuffer) } /// Sent by the server and received by the client. public enum ServerMessage: Hashable, Sendable { - + /// Used by the server to instruct the client of the authentication method to use. case selectedAuthenticationMethod(SelectedAuthenticationMethod) - + /// Sent by the server to inform the client that establishing the proxy to the target /// host succeeded or failed. case response(SOCKSResponse) - + /// Used when authenticating to send server challenges to the client. case authenticationData(ByteBuffer, complete: Bool) } extension ByteBuffer { - + @discardableResult mutating func writeServerMessage(_ message: ServerMessage) -> Int { switch message { case .selectedAuthenticationMethod(let method): @@ -53,5 +53,5 @@ extension ByteBuffer { return self.writeBuffer(&buffer) } } - + } diff --git a/Sources/NIOSOCKS/Messages/SOCKSRequest.swift b/Sources/NIOSOCKS/Messages/SOCKSRequest.swift index f17e73f8..afdcc08d 100644 --- a/Sources/NIOSOCKS/Messages/SOCKSRequest.swift +++ b/Sources/NIOSOCKS/Messages/SOCKSRequest.swift @@ -12,6 +12,8 @@ // //===----------------------------------------------------------------------===// +import NIOCore + #if canImport(Darwin) import Darwin #elseif canImport(Musl) @@ -20,23 +22,21 @@ import Musl import Glibc #endif -import NIOCore - // MARK: - ClientRequest /// Instructs the SOCKS proxy server of the target host, /// and how to connect. public struct SOCKSRequest: Hashable, Sendable { - + /// The SOCKS protocol version - we currently only support v5. public let version: UInt8 = 5 - + /// How to connect to the host. public var command: SOCKSCommand - + /// The target host address. public var addressType: SOCKSAddress - + /// Creates a new ``SOCKSRequest``. /// - parameter command: How to connect to the host. /// - parameter addressType: The target host address. @@ -44,11 +44,11 @@ public struct SOCKSRequest: Hashable, Sendable { self.command = command self.addressType = addressType } - + } extension ByteBuffer { - + @discardableResult mutating func writeClientRequest(_ request: SOCKSRequest) -> Int { var written = self.writeInteger(request.version) written += self.writeInteger(request.command.value) @@ -56,9 +56,9 @@ extension ByteBuffer { written += self.writeAddressType(request.addressType) return written } - + @discardableResult mutating func readClientRequest() throws -> SOCKSRequest? { - return try self.parseUnwindingIfNeeded { buffer -> SOCKSRequest? in + try self.parseUnwindingIfNeeded { buffer -> SOCKSRequest? in guard try buffer.readAndValidateProtocolVersion() != nil, let command = buffer.readInteger(as: UInt8.self), @@ -70,7 +70,7 @@ extension ByteBuffer { return .init(command: .init(value: command), addressType: address) } } - + } // MARK: - SOCKSCommand @@ -78,14 +78,14 @@ extension ByteBuffer { /// What type of connection the SOCKS server should establish with /// the target host. public struct SOCKSCommand: Hashable, Sendable { - + /// Typically the primary connection type, suitable for HTTP. public static let connect = SOCKSCommand(value: 0x01) - + /// Used in protocols that require the client to accept connections /// from the server, e.g. FTP. public static let bind = SOCKSCommand(value: 0x02) - + /// Used to establish an association within the UDP relay process to /// handle UDP datagrams. public static let udpAssociate = SOCKSCommand(value: 0x03) @@ -106,11 +106,11 @@ public enum SOCKSAddress: Hashable, Sendable { case address(SocketAddress) /// Host and port case domain(String, port: Int) - + static let ipv4IdentifierByte: UInt8 = 0x01 static let domainIdentifierByte: UInt8 = 0x03 static let ipv6IdentifierByte: UInt8 = 0x04 - + /// How many bytes are needed to represent the address, excluding the port var size: Int { switch self { @@ -130,13 +130,13 @@ public enum SOCKSAddress: Hashable, Sendable { } extension ByteBuffer { - + mutating func readAddressType() throws -> SOCKSAddress? { - return try self.parseUnwindingIfNeeded { buffer in + try self.parseUnwindingIfNeeded { buffer in guard let type = buffer.readInteger(as: UInt8.self) else { return nil } - + switch type { case SOCKSAddress.ipv4IdentifierByte: return try buffer.readIPv4Address() @@ -149,9 +149,9 @@ extension ByteBuffer { } } } - + mutating func readIPv4Address() throws -> SOCKSAddress? { - return try self.parseUnwindingIfNeeded { buffer in + try self.parseUnwindingIfNeeded { buffer in guard let bytes = buffer.readSlice(length: 4), let port = buffer.readPort() @@ -161,9 +161,9 @@ extension ByteBuffer { return .address(try .init(packedIPAddress: bytes, port: port)) } } - + mutating func readIPv6Address() throws -> SOCKSAddress? { - return try self.parseUnwindingIfNeeded { buffer in + try self.parseUnwindingIfNeeded { buffer in guard let bytes = buffer.readSlice(length: 16), let port = buffer.readPort() @@ -173,9 +173,9 @@ extension ByteBuffer { return .address(try .init(packedIPAddress: bytes, port: port)) } } - + mutating func readDomain() -> SOCKSAddress? { - return self.parseUnwindingIfNeeded { buffer in + self.parseUnwindingIfNeeded { buffer in guard let length = buffer.readInteger(as: UInt8.self), let host = buffer.readString(length: Int(length)), @@ -186,14 +186,14 @@ extension ByteBuffer { return .domain(host, port: port) } } - + mutating func readPort() -> Int? { guard let port = self.readInteger(as: UInt16.self) else { return nil } return Int(port) } - + @discardableResult mutating func writeAddressType(_ type: SOCKSAddress) -> Int { switch type { case .address(.v4(let address)): @@ -207,23 +207,23 @@ extension ByteBuffer { case .address(.unixDomainSocket): // enforced in the channel initalisers. fatalError("UNIX domain sockets are not supported") - case .domain(let domain, port: let port): + case .domain(let domain, let port): return self.writeInteger(SOCKSAddress.domainIdentifierByte) + self.writeInteger(UInt8(domain.utf8.count)) + self.writeString(domain) + self.writeInteger(UInt16(port)) } } - + @discardableResult mutating func writeIPv6Address(_ addr: sockaddr_in6) -> Int { - return withUnsafeBytes(of: addr.sin6_addr) { pointer in - return self.writeBytes(pointer) + withUnsafeBytes(of: addr.sin6_addr) { pointer in + self.writeBytes(pointer) } } - + @discardableResult mutating func writeIPv4Address(_ addr: sockaddr_in) -> Int { - return withUnsafeBytes(of: addr.sin_addr) { pointer in - return self.writeBytes(pointer) + withUnsafeBytes(of: addr.sin_addr) { pointer in + self.writeBytes(pointer) } } } diff --git a/Sources/NIOSOCKS/Messages/SOCKSResponse.swift b/Sources/NIOSOCKS/Messages/SOCKSResponse.swift index 67630c04..9b9b19c3 100644 --- a/Sources/NIOSOCKS/Messages/SOCKSResponse.swift +++ b/Sources/NIOSOCKS/Messages/SOCKSResponse.swift @@ -19,17 +19,17 @@ import NIOCore /// The SOCKS Server's response to the client's request /// indicating if the request succeeded or failed. public struct SOCKSResponse: Hashable, Sendable { - + /// The SOCKS protocol version - we currently only support v5. public let version: UInt8 = 5 - + /// The status of the connection - used to check if the request /// succeeded or failed. public var reply: SOCKSServerReply - + /// The host address. public var boundAddress: SOCKSAddress - + /// Creates a new ``SOCKSResponse``. /// - parameter reply: The status of the connection - used to check if the request /// succeeded or failed. @@ -41,9 +41,9 @@ public struct SOCKSResponse: Hashable, Sendable { } extension ByteBuffer { - + mutating func readServerResponse() throws -> SOCKSResponse? { - return try self.parseUnwindingIfNeeded { buffer in + try self.parseUnwindingIfNeeded { buffer in guard try buffer.readAndValidateProtocolVersion() != nil, let reply = buffer.readInteger(as: UInt8.self).map({ SOCKSServerReply(value: $0) }), @@ -55,14 +55,12 @@ extension ByteBuffer { return .init(reply: reply, boundAddress: boundAddress) } } - + @discardableResult mutating func writeServerResponse(_ response: SOCKSResponse) -> Int { - return self.writeInteger(response.version) + - self.writeInteger(response.reply.value) + - self.writeInteger(0, as: UInt8.self) + - self.writeAddressType(response.boundAddress) + self.writeInteger(response.version) + self.writeInteger(response.reply.value) + + self.writeInteger(0, as: UInt8.self) + self.writeAddressType(response.boundAddress) } - + } // MARK: - SOCKSServerReply @@ -70,37 +68,37 @@ extension ByteBuffer { /// Used to indicate if the SOCKS client's connection request succeeded /// or failed. public struct SOCKSServerReply: Hashable, Sendable { - + /// The connection succeeded and data can now be transmitted. public static let succeeded = SOCKSServerReply(value: 0x00) - + /// The SOCKS server encountered an internal failure. public static let serverFailure = SOCKSServerReply(value: 0x01) - + /// The connection to the host was not allowed. public static let notAllowed = SOCKSServerReply(value: 0x02) - + /// The host network is not reachable. public static let networkUnreachable = SOCKSServerReply(value: 0x03) - + /// The target host was not reachable. public static let hostUnreachable = SOCKSServerReply(value: 0x04) - + /// The connection tot he host was refused public static let refused = SOCKSServerReply(value: 0x05) - + /// The host address's TTL has expired. public static let ttlExpired = SOCKSServerReply(value: 0x06) - + /// The provided command is not supported. public static let commandUnsupported = SOCKSServerReply(value: 0x07) - + /// The provided address type is not supported. public static let addressUnsupported = SOCKSServerReply(value: 0x08) - + /// The raw `UInt8` status code. public var value: UInt8 - + /// Creates a new `Reply` from the given raw status code. Common /// statuses have convenience variables. /// - parameter value: The raw `UInt8` code sent by the SOCKS server. diff --git a/Sources/NIOSOCKS/Messages/SelectedAuthenticationMethod.swift b/Sources/NIOSOCKS/Messages/SelectedAuthenticationMethod.swift index 24dba9a2..6ff57580 100644 --- a/Sources/NIOSOCKS/Messages/SelectedAuthenticationMethod.swift +++ b/Sources/NIOSOCKS/Messages/SelectedAuthenticationMethod.swift @@ -18,13 +18,13 @@ import NIOCore /// authentication method it would like to use out of those /// offered. public struct SelectedAuthenticationMethod: Hashable, Sendable { - + /// The SOCKS protocol version - we currently only support v5. public let version: UInt8 = 5 - + /// The server's selected authentication method. public var method: AuthenticationMethod - + /// Creates a new `MethodSelection` wrapping an ``AuthenticationMethod``. /// - parameter method: The selected `AuthenticationMethod`. public init(method: AuthenticationMethod) { @@ -33,9 +33,9 @@ public struct SelectedAuthenticationMethod: Hashable, Sendable { } extension ByteBuffer { - + mutating func readMethodSelection() throws -> SelectedAuthenticationMethod? { - return try self.parseUnwindingIfNeeded { buffer in + try self.parseUnwindingIfNeeded { buffer in guard try buffer.readAndValidateProtocolVersion() != nil, let method = buffer.readInteger(as: UInt8.self) @@ -45,9 +45,9 @@ extension ByteBuffer { return .init(method: .init(value: method)) } } - + @discardableResult mutating func writeMethodSelection(_ method: SelectedAuthenticationMethod) -> Int { - return self.writeInteger(method.version) + self.writeInteger(method.method.value) + self.writeInteger(method.version) + self.writeInteger(method.method.value) } - + } diff --git a/Sources/NIOSOCKS/State/ClientStateMachine.swift b/Sources/NIOSOCKS/State/ClientStateMachine.swift index c04685ae..ca760484 100644 --- a/Sources/NIOSOCKS/State/ClientStateMachine.swift +++ b/Sources/NIOSOCKS/State/ClientStateMachine.swift @@ -34,34 +34,36 @@ enum ClientAction: Hashable { struct ClientStateMachine { private var state: ClientState - + var proxyEstablished: Bool { switch self.state { case .active: return true - case .error, .inactive, .waitingForAuthenticationMethod, .waitingForClientGreeting, .waitingForClientRequest, .waitingForServerResponse: + case .error, .inactive, .waitingForAuthenticationMethod, .waitingForClientGreeting, .waitingForClientRequest, + .waitingForServerResponse: return false } } - - var shouldBeginHandshake: Bool { + + var shouldBeginHandshake: Bool { switch self.state { case .inactive: return true - case .active, .error, .waitingForAuthenticationMethod, .waitingForClientGreeting, .waitingForClientRequest, .waitingForServerResponse: + case .active, .error, .waitingForAuthenticationMethod, .waitingForClientGreeting, .waitingForClientRequest, + .waitingForServerResponse: return false } } - + init() { self.state = .inactive } - + } // MARK: - Incoming extension ClientStateMachine { - + mutating func receiveBuffer(_ buffer: inout ByteBuffer) throws -> ClientAction { do { switch self.state { @@ -83,23 +85,26 @@ extension ClientStateMachine { throw error } } - - mutating func handleSelectedAuthenticationMethod(_ buffer: inout ByteBuffer, greeting: ClientGreeting) throws -> ClientAction? { - return try buffer.parseUnwindingIfNeeded { buffer -> ClientAction? in + + mutating func handleSelectedAuthenticationMethod( + _ buffer: inout ByteBuffer, + greeting: ClientGreeting + ) throws -> ClientAction? { + try buffer.parseUnwindingIfNeeded { buffer -> ClientAction? in guard let selected = try buffer.readMethodSelection() else { return nil } guard greeting.methods.contains(selected.method) else { throw SOCKSError.InvalidAuthenticationSelection(selection: selected.method) } - + // we don't current support any form of authentication return self.authenticate(&buffer, method: selected.method) } } - + mutating func handleServerResponse(_ buffer: inout ByteBuffer, request: SOCKSRequest) throws -> ClientAction? { - return try buffer.parseUnwindingIfNeeded { buffer -> ClientAction? in + try buffer.parseUnwindingIfNeeded { buffer -> ClientAction? in guard let response = try buffer.readServerResponse() else { return nil } @@ -110,22 +115,22 @@ extension ClientStateMachine { return .proxyEstablished } } - + mutating func authenticate(_ buffer: inout ByteBuffer, method: AuthenticationMethod) -> ClientAction { precondition(method == .noneRequired, "No authentication mechanism is supported. Use .noneRequired only.") - + // we don't currently support any authentication // so assume all is fine, and instruct the client // to send the request self.state = .waitingForClientRequest return .sendRequest } - + } // MARK: - Outgoing extension ClientStateMachine { - + mutating func connectionEstablished() throws -> ClientAction { guard self.state == .inactive else { throw SOCKSError.InvalidClientState() @@ -140,12 +145,12 @@ extension ClientStateMachine { } self.state = .waitingForAuthenticationMethod(greeting) } - + mutating func sendClientRequest(_ request: SOCKSRequest) throws { guard self.state == .waitingForClientRequest else { throw SOCKSError.InvalidClientState() } self.state = .waitingForServerResponse(request) } - + } diff --git a/Sources/NIOSOCKS/State/ServerStateMachine.swift b/Sources/NIOSOCKS/State/ServerStateMachine.swift index 69853bcb..8d48ff99 100644 --- a/Sources/NIOSOCKS/State/ServerStateMachine.swift +++ b/Sources/NIOSOCKS/State/ServerStateMachine.swift @@ -26,25 +26,25 @@ enum ServerState: Hashable { } struct ServerStateMachine: Hashable { - + private var state: ServerState private var authenticationMethod: AuthenticationMethod? - + var proxyEstablished: Bool { switch self.state { case .active: return true case .inactive, - .waitingForClientGreeting, - .waitingToSendAuthenticationMethod, - .authenticating, - .waitingForClientRequest, - .waitingToSendResponse, - .error: + .waitingForClientGreeting, + .waitingToSendAuthenticationMethod, + .authenticating, + .waitingForClientRequest, + .waitingToSendResponse, + .error: return false } } - + init() { self.state = .inactive } @@ -52,7 +52,7 @@ struct ServerStateMachine: Hashable { // MARK: - Inbound extension ServerStateMachine { - + mutating func receiveBuffer(_ buffer: inout ByteBuffer) throws -> ClientMessage? { do { switch self.state { @@ -70,9 +70,9 @@ extension ServerStateMachine { throw error } } - - fileprivate mutating func handleClientGreeting(from buffer: inout ByteBuffer) throws -> ClientMessage? { - return try buffer.parseUnwindingIfNeeded { buffer -> ClientMessage? in + + fileprivate mutating func handleClientGreeting(from buffer: inout ByteBuffer) throws -> ClientMessage? { + try buffer.parseUnwindingIfNeeded { buffer -> ClientMessage? in guard let greeting = try buffer.readClientGreeting() else { return nil } @@ -80,9 +80,9 @@ extension ServerStateMachine { return .greeting(greeting) } } - + fileprivate mutating func handleClientRequest(from buffer: inout ByteBuffer) throws -> ClientMessage? { - return try buffer.parseUnwindingIfNeeded { buffer -> ClientMessage? in + try buffer.parseUnwindingIfNeeded { buffer -> ClientMessage? in guard let request = try buffer.readClientRequest() else { return nil } @@ -90,49 +90,49 @@ extension ServerStateMachine { return .request(request) } } - + fileprivate mutating func handleAuthenticationData(from buffer: inout ByteBuffer) -> ClientMessage? { guard let buffer = buffer.readSlice(length: buffer.readableBytes) else { return nil } return .authenticationData(buffer) } - + } // MARK: - Outbound extension ServerStateMachine { - + mutating func connectionEstablished() throws { switch self.state { case .inactive: () case .authenticating, - .waitingForClientGreeting, - .waitingToSendAuthenticationMethod, - .waitingForClientRequest, - .waitingToSendResponse, - .active, - .error: - throw SOCKSError.InvalidServerState() + .waitingForClientGreeting, + .waitingToSendAuthenticationMethod, + .waitingForClientRequest, + .waitingToSendResponse, + .active, + .error: + throw SOCKSError.InvalidServerState() } self.state = .waitingForClientGreeting } - + mutating func sendAuthenticationMethod(_ selected: SelectedAuthenticationMethod) throws { switch self.state { case .waitingToSendAuthenticationMethod: () case .inactive, - .waitingForClientGreeting, - .authenticating, - .waitingForClientRequest, - .waitingToSendResponse, - .active, - .error: - throw SOCKSError.InvalidServerState() + .waitingForClientGreeting, + .authenticating, + .waitingForClientRequest, + .waitingToSendResponse, + .active, + .error: + throw SOCKSError.InvalidServerState() } - + self.authenticationMethod = selected.method if selected.method == .noneRequired { self.state = .waitingForClientRequest @@ -140,28 +140,28 @@ extension ServerStateMachine { self.state = .authenticating } } - + mutating func sendServerResponse(_ response: SOCKSResponse) throws { switch self.state { case .waitingToSendResponse: () case .inactive, - .waitingForClientGreeting, - .waitingToSendAuthenticationMethod, - .waitingForClientRequest, - .authenticating, - .active, - .error: - throw SOCKSError.InvalidServerState() + .waitingForClientGreeting, + .waitingToSendAuthenticationMethod, + .waitingForClientRequest, + .authenticating, + .active, + .error: + throw SOCKSError.InvalidServerState() } - + if response.reply == .succeeded { self.state = .active } else { self.state = .error } } - + mutating func sendAuthenticationData(_ data: ByteBuffer, complete: Bool) throws { switch self.state { case .authenticating: @@ -171,14 +171,14 @@ extension ServerStateMachine { throw SOCKSError.InvalidServerState() } case .inactive, - .waitingForClientGreeting, - .waitingToSendAuthenticationMethod, - .waitingToSendResponse, - .active, - .error: - throw SOCKSError.InvalidServerState() + .waitingForClientGreeting, + .waitingToSendAuthenticationMethod, + .waitingToSendResponse, + .active, + .error: + throw SOCKSError.InvalidServerState() } - + if complete { self.state = .waitingForClientRequest } diff --git a/Sources/NIOSOCKSClient/main.swift b/Sources/NIOSOCKSClient/main.swift index d2d206c9..242d47de 100644 --- a/Sources/NIOSOCKSClient/main.swift +++ b/Sources/NIOSOCKSClient/main.swift @@ -18,11 +18,11 @@ import NIOSOCKS class EchoHandler: ChannelInboundHandler { typealias InboundIn = ByteBuffer - + func channelRead(context: ChannelHandlerContext, data: NIOAny) { context.writeAndFlush(data, promise: nil) } - + } let targetIPAddress = "127.0.0.1" @@ -34,9 +34,9 @@ let bootstrap = ClientBootstrap(group: elg) .channelInitializer { channel in channel.pipeline.addHandlers([ SOCKSClientHandler(targetAddress: targetAddress), - EchoHandler() + EchoHandler(), ]) -} + } let channel = try bootstrap.connect(host: "127.0.0.1", port: 1080).wait() while let string = readLine(strippingNewline: true) { diff --git a/Sources/NIOWritePCAPDemo/main.swift b/Sources/NIOWritePCAPDemo/main.swift index 96a28aa4..01f2ffcc 100644 --- a/Sources/NIOWritePCAPDemo/main.swift +++ b/Sources/NIOWritePCAPDemo/main.swift @@ -13,38 +13,49 @@ //===----------------------------------------------------------------------===// import NIOCore -import NIOPosix import NIOExtras import NIOHTTP1 +import NIOPosix class SendSimpleRequestHandler: ChannelInboundHandler { typealias InboundIn = HTTPClientResponsePart typealias OutboundOut = HTTPClientRequestPart - + private let allDonePromise: EventLoopPromise - + init(allDonePromise: EventLoopPromise) { self.allDonePromise = allDonePromise } - + func channelRead(context: ChannelHandlerContext, data: NIOAny) { if case .body(let body) = self.unwrapInboundIn(data) { self.allDonePromise.succeed(body) } } - + func errorCaught(context: ChannelHandlerContext, error: Error) { self.allDonePromise.fail(error) context.close(promise: nil) } - + func channelActive(context: ChannelHandlerContext) { - let headers = HTTPHeaders([("host", "httpbin.org"), - ("accept", "application/json")]) - context.write(self.wrapOutboundOut(.head(.init(version: .init(major: 1, minor: 1), - method: .GET, - uri: "/delay/0.2", - headers: headers))), promise: nil) + let headers = HTTPHeaders([ + ("host", "httpbin.org"), + ("accept", "application/json"), + ]) + context.write( + self.wrapOutboundOut( + .head( + .init( + version: .init(major: 1, minor: 1), + method: .GET, + uri: "/delay/0.2", + headers: headers + ) + ) + ), + promise: nil + ) context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) } } @@ -66,7 +77,7 @@ defer { let allDonePromise = group.next().makePromise(of: ByteBuffer.self) let connection = try ClientBootstrap(group: group.next()) .channelInitializer { channel in - return channel.pipeline.addHandler(NIOWritePCAPHandler(mode: .client, fileSink: fileSink.write)).flatMap { + channel.pipeline.addHandler(NIOWritePCAPHandler(mode: .client, fileSink: fileSink.write)).flatMap { channel.pipeline.addHTTPClientHandlers() }.flatMap { channel.pipeline.addHandler(SendSimpleRequestHandler(allDonePromise: allDonePromise)) diff --git a/Sources/NIOWritePartialPCAPDemo/main.swift b/Sources/NIOWritePartialPCAPDemo/main.swift index dfba1658..f7b94774 100644 --- a/Sources/NIOWritePartialPCAPDemo/main.swift +++ b/Sources/NIOWritePartialPCAPDemo/main.swift @@ -13,9 +13,9 @@ //===----------------------------------------------------------------------===// import NIOCore -import NIOPosix import NIOExtras import NIOHTTP1 +import NIOPosix /// Trigger recording pcap data when a "precondition failed" is seen. class TriggerPCAPHandler: ChannelInboundHandler { @@ -23,9 +23,9 @@ class TriggerPCAPHandler: ChannelInboundHandler { typealias OutboundOut = HTTPClientRequestPart private let pcapRingBuffer: NIOPCAPRingBuffer - private let sink: (ByteBuffer) -> () + private let sink: (ByteBuffer) -> Void - init(pcapRingBuffer: NIOPCAPRingBuffer, sink: @escaping (ByteBuffer) -> ()) { + init(pcapRingBuffer: NIOPCAPRingBuffer, sink: @escaping (ByteBuffer) -> Void) { self.pcapRingBuffer = pcapRingBuffer self.sink = sink } @@ -55,43 +55,56 @@ class TriggerPCAPHandler: ChannelInboundHandler { class SendSimpleSequenceRequestHandler: ChannelInboundHandler { typealias InboundIn = HTTPClientResponsePart typealias OutboundOut = HTTPClientRequestPart - + private let allDonePromise: EventLoopPromise private var nextRequestNumber = 0 - private var requestsToMake: [HTTPResponseStatus] = [ .ok, .created, .accepted, .nonAuthoritativeInformation, - .noContent, .resetContent, .preconditionFailed, - .partialContent, .multiStatus, .alreadyReported ] + private var requestsToMake: [HTTPResponseStatus] = [ + .ok, .created, .accepted, .nonAuthoritativeInformation, + .noContent, .resetContent, .preconditionFailed, + .partialContent, .multiStatus, .alreadyReported, + ] init(allDonePromise: EventLoopPromise) { self.allDonePromise = allDonePromise } - + func channelRead(context: ChannelHandlerContext, data: NIOAny) { if case .end = self.unwrapInboundIn(data) { self.makeNextRequestOrComplete(context: context) } } - + func errorCaught(context: ChannelHandlerContext, error: Error) { self.allDonePromise.fail(error) context.close(promise: nil) } - + func channelActive(context: ChannelHandlerContext) { self.makeNextRequestOrComplete(context: context) } private func makeNextRequestOrComplete(context: ChannelHandlerContext) { if self.nextRequestNumber < self.requestsToMake.count { - let headers = HTTPHeaders([("host", "httpbin.org"), - ("accept", "application/json")]) + let headers = HTTPHeaders([ + ("host", "httpbin.org"), + ("accept", "application/json"), + ]) let currentStatus = self.requestsToMake[self.nextRequestNumber].code self.nextRequestNumber += 1 - context.write(self.wrapOutboundOut(.head(.init(version: .init(major: 1, minor: 1), - method: .GET, - uri: "/status/\(currentStatus)", - headers: headers))), promise: nil) + context.write( + self.wrapOutboundOut( + .head( + .init( + version: .init(major: 1, minor: 1), + method: .GET, + uri: "/status/\(currentStatus)", + headers: headers + ) + ) + ), + promise: nil + ) context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) } else { self.allDonePromise.succeed(()) @@ -117,10 +130,16 @@ let allDonePromise = group.next().makePromise(of: Void.self) let maximumFragments = 4 let connection = try ClientBootstrap(group: group.next()) .channelInitializer { channel in - let pcapRingBuffer = NIOPCAPRingBuffer(maximumFragments: maximumFragments, - maximumBytes: 1_000_000) - return channel.pipeline.addHandler(NIOWritePCAPHandler(mode: .client, - fileSink: pcapRingBuffer.addFragment)).flatMap { + let pcapRingBuffer = NIOPCAPRingBuffer( + maximumFragments: maximumFragments, + maximumBytes: 1_000_000 + ) + return channel.pipeline.addHandler( + NIOWritePCAPHandler( + mode: .client, + fileSink: pcapRingBuffer.addFragment + ) + ).flatMap { channel.pipeline.addHTTPClientHandlers() }.flatMap { channel.pipeline.addHandler(TriggerPCAPHandler(pcapRingBuffer: pcapRingBuffer, sink: fileSink.write)) diff --git a/Tests/NIOExtrasTests/DebugInboundEventsHandlerTest.swift b/Tests/NIOExtrasTests/DebugInboundEventsHandlerTest.swift index fe169c2f..a60316a7 100644 --- a/Tests/NIOExtrasTests/DebugInboundEventsHandlerTest.swift +++ b/Tests/NIOExtrasTests/DebugInboundEventsHandlerTest.swift @@ -12,17 +12,17 @@ // //===----------------------------------------------------------------------===// -import XCTest import NIOCore import NIOEmbedded import NIOExtras +import XCTest class DebugInboundEventsHandlerTest: XCTestCase { - + private var channel: EmbeddedChannel! private var lastEvent: DebugInboundEventsHandler.Event! private var handlerUnderTest: DebugInboundEventsHandler! - + override func setUp() { super.setUp() channel = EmbeddedChannel() @@ -31,61 +31,61 @@ class DebugInboundEventsHandlerTest: XCTestCase { } try? channel.pipeline.addHandler(handlerUnderTest).wait() } - + override func tearDown() { channel = nil lastEvent = nil handlerUnderTest = nil super.tearDown() } - + func testRegistered() { channel.pipeline.register(promise: nil) XCTAssertEqual(lastEvent, .registered) } - + func testUnregistered() { channel.pipeline.fireChannelUnregistered() XCTAssertEqual(lastEvent, .unregistered) } - + func testActive() { channel.pipeline.fireChannelActive() XCTAssertEqual(lastEvent, .active) } - + func testInactive() { channel.pipeline.fireChannelInactive() XCTAssertEqual(lastEvent, .inactive) } - + func testReadComplete() { channel.pipeline.fireChannelReadComplete() XCTAssertEqual(lastEvent, .readComplete) } - + func testWritabilityChanged() { channel.pipeline.fireChannelWritabilityChanged() XCTAssertEqual(lastEvent, .writabilityChanged(isWritable: true)) } - + func testUserInboundEvent() { let eventString = "new user inbound event" channel.pipeline.fireUserInboundEventTriggered(eventString) XCTAssertEqual(lastEvent, .userInboundEventTriggered(event: eventString)) } - + func testErrorCaught() { struct E: Error { var localizedDescription: String { - return "desc" + "desc" } } let error = E() channel.pipeline.fireErrorCaught(error) XCTAssertEqual(lastEvent, .errorCaught(error)) } - + func testRead() { let messageString = "message" var expectedBuffer = ByteBufferAllocator().buffer(capacity: messageString.count) @@ -94,7 +94,7 @@ class DebugInboundEventsHandlerTest: XCTestCase { channel.pipeline.fireChannelRead(nioAny) XCTAssertEqual(lastEvent, .read(data: nioAny)) } - + } extension DebugInboundEventsHandler.Event { @@ -125,8 +125,7 @@ extension DebugInboundEventsHandler.Event { } #if compiler(>=6.0) -extension DebugInboundEventsHandler.Event: @retroactive Equatable { } +extension DebugInboundEventsHandler.Event: @retroactive Equatable {} #else -extension DebugInboundEventsHandler.Event: Equatable { } +extension DebugInboundEventsHandler.Event: Equatable {} #endif - diff --git a/Tests/NIOExtrasTests/DebugOutboundEventsHandlerTest.swift b/Tests/NIOExtrasTests/DebugOutboundEventsHandlerTest.swift index cf44789c..cc4c2c53 100644 --- a/Tests/NIOExtrasTests/DebugOutboundEventsHandlerTest.swift +++ b/Tests/NIOExtrasTests/DebugOutboundEventsHandlerTest.swift @@ -12,17 +12,17 @@ // //===----------------------------------------------------------------------===// -import XCTest import NIOCore import NIOEmbedded import NIOExtras +import XCTest class DebugOutboundEventsHandlerTest: XCTestCase { private var channel: EmbeddedChannel! private var lastEvent: DebugOutboundEventsHandler.Event! private var handlerUnderTest: DebugOutboundEventsHandler! - + override func setUp() { super.setUp() channel = EmbeddedChannel() @@ -31,7 +31,7 @@ class DebugOutboundEventsHandlerTest: XCTestCase { } try? channel.pipeline.addHandler(handlerUnderTest).wait() } - + override func tearDown() { channel = nil lastEvent = nil @@ -43,40 +43,40 @@ class DebugOutboundEventsHandlerTest: XCTestCase { channel.pipeline.register(promise: nil) XCTAssertEqual(lastEvent, .register) } - + func testBind() throws { let address = try SocketAddress(unixDomainSocketPath: "path") channel.bind(to: address, promise: nil) XCTAssertEqual(lastEvent, .bind(address: address)) } - + func testConnect() throws { let address = try SocketAddress(unixDomainSocketPath: "path") channel.connect(to: address, promise: nil) XCTAssertEqual(lastEvent, .connect(address: address)) } - + func testWrite() { let data = NIOAny(" 1 2 3 ") channel.write(data, promise: nil) XCTAssertEqual(lastEvent, .write(data: data)) } - + func testFlush() { channel.flush() XCTAssertEqual(lastEvent, .flush) } - + func testRead() { channel.read() XCTAssertEqual(lastEvent, .read) } - + func testClose() { channel.close(mode: .all, promise: nil) XCTAssertEqual(lastEvent, .close(mode: .all)) } - + func testTriggerUserOutboundEvent() { let event = "user event" channel.triggerUserOutboundEvent(event, promise: nil) @@ -111,8 +111,7 @@ extension DebugOutboundEventsHandler.Event { } #if compiler(>=6.0) -extension DebugOutboundEventsHandler.Event: @retroactive Equatable { } +extension DebugOutboundEventsHandler.Event: @retroactive Equatable {} #else -extension DebugOutboundEventsHandler.Event: Equatable { } +extension DebugOutboundEventsHandler.Event: Equatable {} #endif - diff --git a/Tests/NIOExtrasTests/FixedLengthFrameDecoderTest.swift b/Tests/NIOExtrasTests/FixedLengthFrameDecoderTest.swift index 0a4627ed..592879a3 100644 --- a/Tests/NIOExtrasTests/FixedLengthFrameDecoderTest.swift +++ b/Tests/NIOExtrasTests/FixedLengthFrameDecoderTest.swift @@ -12,11 +12,11 @@ // //===----------------------------------------------------------------------===// -import XCTest import NIOCore import NIOEmbedded import NIOExtras import NIOTestUtils +import XCTest class FixedLengthFrameDecoderTest: XCTestCase { public func testDecodeIfFewerBytesAreSent() throws { @@ -30,9 +30,12 @@ class FixedLengthFrameDecoderTest: XCTestCase { XCTAssertTrue(try channel.writeInbound(buffer).isEmpty) XCTAssertTrue(try channel.writeInbound(buffer).isFull) - XCTAssertEqual("xxxxxxxx", try (channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { - String(decoding: $0, as: Unicode.UTF8.self) - }) + XCTAssertEqual( + "xxxxxxxx", + try (channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { + String(decoding: $0, as: Unicode.UTF8.self) + } + ) XCTAssertTrue(try channel.finish().isClean) } @@ -46,13 +49,19 @@ class FixedLengthFrameDecoderTest: XCTestCase { buffer.writeString("xxxxxxxxaaaaaaaabbb") XCTAssertTrue(try channel.writeInbound(buffer).isFull) - XCTAssertEqual("xxxxxxxx", try (channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { - String(decoding: $0, as: Unicode.UTF8.self) - }) + XCTAssertEqual( + "xxxxxxxx", + try (channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { + String(decoding: $0, as: Unicode.UTF8.self) + } + ) - XCTAssertEqual("aaaaaaaa", try (channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { - String(decoding: $0, as: Unicode.UTF8.self) - }) + XCTAssertEqual( + "aaaaaaaa", + try (channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { + String(decoding: $0, as: Unicode.UTF8.self) + } + ) XCTAssertNoThrow(XCTAssertNil(try channel.readInbound(as: ByteBuffer.self))) XCTAssertThrowsError(try channel.finish()) { error in @@ -75,9 +84,12 @@ class FixedLengthFrameDecoderTest: XCTestCase { buffer.writeString("xxxxxxxxxxxxxxx") XCTAssertTrue(try channel.writeInbound(buffer).isFull) - XCTAssertEqual("xxxxxxxx", try (channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { - String(decoding: $0, as: Unicode.UTF8.self) - }) + XCTAssertEqual( + "xxxxxxxx", + try (channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { + String(decoding: $0, as: Unicode.UTF8.self) + } + ) let removeFuture = channel.pipeline.removeHandler(handler) (channel.eventLoop as! EmbeddedEventLoop).run() @@ -106,9 +118,12 @@ class FixedLengthFrameDecoderTest: XCTestCase { buffer.writeString("xxxxxxxx") XCTAssertTrue(try channel.writeInbound(buffer).isFull) - XCTAssertEqual("xxxxxxxx", try (channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { - String(decoding: $0, as: Unicode.UTF8.self) - }) + XCTAssertEqual( + "xxxxxxxx", + try (channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { + String(decoding: $0, as: Unicode.UTF8.self) + } + ) let removeFuture = channel.pipeline.removeHandler(handler) (channel.eventLoop as! EmbeddedEventLoop).run() @@ -118,7 +133,7 @@ class FixedLengthFrameDecoderTest: XCTestCase { } func testBasicValidation() { - for length in 1 ... 20 { + for length in 1...20 { let inputs = [ String(decoding: Array(repeating: UInt8(ascii: "a"), count: length), as: Unicode.UTF8.self), String(decoding: Array(repeating: UInt8(ascii: "b"), count: length), as: Unicode.UTF8.self), @@ -130,9 +145,11 @@ class FixedLengthFrameDecoderTest: XCTestCase { return buffer } let inputOutputPairs: [(String, [ByteBuffer])] = inputs.map { ($0, [byteBuffer($0)]) } - XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder(stringInputOutputPairs: inputOutputPairs) { - FixedLengthFrameDecoder(frameLength: length) - }) + XCTAssertNoThrow( + try ByteToMessageDecoderVerifier.verifyDecoder(stringInputOutputPairs: inputOutputPairs) { + FixedLengthFrameDecoder(frameLength: length) + } + ) } } } diff --git a/Tests/NIOExtrasTests/HTTP1ProxyConnectHandlerTests.swift b/Tests/NIOExtrasTests/HTTP1ProxyConnectHandlerTests.swift index f2457d18..e65517b6 100644 --- a/Tests/NIOExtrasTests/HTTP1ProxyConnectHandlerTests.swift +++ b/Tests/NIOExtrasTests/HTTP1ProxyConnectHandlerTests.swift @@ -12,12 +12,13 @@ // //===----------------------------------------------------------------------===// -@testable import NIOExtras import NIOCore import NIOEmbedded import NIOHTTP1 import XCTest +@testable import NIOExtras + class HTTP1ProxyConnectHandlerTests: XCTestCase { func testProxyConnectWithoutAuthorizationSuccess() throws { let embedded = EmbeddedChannel() @@ -60,7 +61,7 @@ class HTTP1ProxyConnectHandlerTests: XCTestCase { let proxyConnectHandler = NIOHTTP1ProxyConnectHandler( targetHost: "swift.org", targetPort: 443, - headers: ["proxy-authorization" : "Basic abc123"], + headers: ["proxy-authorization": "Basic abc123"], deadline: .now() + .seconds(10), promise: promise ) @@ -213,13 +214,20 @@ class HTTP1ProxyConnectHandlerTests: XCTestCase { // write a request to be buffered inside the ProxyConnectHandler // it will be unbuffered when the handler completes and removes itself - let requestHead = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "http://apple.com") + let requestHead = HTTPRequestHead( + version: HTTPVersion(major: 1, minor: 1), + method: .GET, + uri: "http://apple.com" + ) var promises: [EventLoopPromise] = [] promises.append(embedded.eventLoop.makePromise()) embedded.pipeline.write(NIOAny(HTTPClientRequestPart.head(requestHead)), promise: promises.last) promises.append(embedded.eventLoop.makePromise()) - embedded.pipeline.write(NIOAny(HTTPClientRequestPart.body(.byteBuffer(ByteBuffer(string: "Test")))), promise: promises.last) + embedded.pipeline.write( + NIOAny(HTTPClientRequestPart.body(.byteBuffer(ByteBuffer(string: "Test")))), + promise: promises.last + ) promises.append(embedded.eventLoop.makePromise()) embedded.pipeline.write(NIOAny(HTTPClientRequestPart.end(nil)), promise: promises.last) @@ -286,7 +294,10 @@ class HTTP1ProxyConnectHandlerTests: XCTestCase { embedded.pipeline.write(NIOAny(HTTPClientRequestPart.head(requestHead)), promise: promises.last) promises.append(embedded.eventLoop.makePromise()) - embedded.pipeline.write(NIOAny(HTTPClientRequestPart.body(.byteBuffer(ByteBuffer(string: "Test")))), promise: promises.last) + embedded.pipeline.write( + NIOAny(HTTPClientRequestPart.body(.byteBuffer(ByteBuffer(string: "Test")))), + promise: promises.last + ) promises.append(embedded.eventLoop.makePromise()) embedded.pipeline.write(NIOAny(HTTPClientRequestPart.end(nil)), promise: promises.last) diff --git a/Tests/NIOExtrasTests/JSONRPCFramingContentLengthHeaderDecoderTests.swift b/Tests/NIOExtrasTests/JSONRPCFramingContentLengthHeaderDecoderTests.swift index d390acd6..7c548bac 100644 --- a/Tests/NIOExtrasTests/JSONRPCFramingContentLengthHeaderDecoderTests.swift +++ b/Tests/NIOExtrasTests/JSONRPCFramingContentLengthHeaderDecoderTests.swift @@ -12,20 +12,23 @@ // //===----------------------------------------------------------------------===// -import XCTest - import NIOCore import NIOEmbedded import NIOExtras +import XCTest final class JSONRPCFramingContentLengthHeaderDecoderTests: XCTestCase { - private var channel: EmbeddedChannel! // not a real network connection + private var channel: EmbeddedChannel! // not a real network connection override func setUp() { self.channel = EmbeddedChannel() // let's add the framing handler to the pipeline as that's what we're testing here. - XCTAssertNoThrow(try self.channel.pipeline.addHandler(ByteToMessageHandler(NIOJSONRPCFraming.ContentLengthHeaderFrameDecoder())).wait()) + XCTAssertNoThrow( + try self.channel.pipeline.addHandler( + ByteToMessageHandler(NIOJSONRPCFraming.ContentLengthHeaderFrameDecoder()) + ).wait() + ) // this pretends to connect the channel to this IP address. XCTAssertNoThrow(self.channel.connect(to: try .init(ipAddress: "1.2.3.4", port: 5678))) } @@ -51,7 +54,7 @@ final class JSONRPCFramingContentLengthHeaderDecoderTests: XCTestCase { } private func readInboundString() throws -> String? { - return try self.channel.readInbound(as: ByteBuffer.self).map { + try self.channel.readInbound(as: ByteBuffer.self).map { String(decoding: $0.readableBytesView, as: Unicode.UTF8.self) } } @@ -76,10 +79,11 @@ final class JSONRPCFramingContentLengthHeaderDecoderTests: XCTestCase { func testTechnicallyInvalidButWeAreNicePeople() { // this writes a bunch of messages that are technically not okay, but we're fine with them - let coupleOfMessages = "Content-Length:1\r\n\r\nX" + // space after colon missing - /* */ "Content-Length : 1\r\n\r\nX" + // extra space before colon - /* */ " Content-Length: 1\r\n\r\nX" + // extra space at the beginning of the header - /* */ "Content-Length: 1\n\r\nX" // \r missing + let coupleOfMessages = + "Content-Length:1\r\n\r\nX" // space after colon missing + + "Content-Length : 1\r\n\r\nX" // extra space before colon + + " Content-Length: 1\r\n\r\nX" // extra space at the beginning of the header + + "Content-Length: 1\n\r\nX" // \r missing XCTAssertNoThrow(try self.channel.writeInbound(self.buffer(string: coupleOfMessages))) @@ -101,11 +105,12 @@ final class JSONRPCFramingContentLengthHeaderDecoderTests: XCTestCase { func testDripAndMassFeedMessages() { let messagesAndExpectedOutput: [(String, String)] = - [ ("Content-Length: 1\r\n\r\n1", "1"), - ("Content-Length: 0\r\n\r\n", ""), - ("foo: bar\r\nContent-Length: 7\r\nbuz: qux\r\n\r\nqwerasd", "qwerasd"), - ("content-lengTH: 1 \r\n\r\nX", "X") - ] + [ + ("Content-Length: 1\r\n\r\n1", "1"), + ("Content-Length: 0\r\n\r\n", ""), + ("foo: bar\r\nContent-Length: 7\r\nbuz: qux\r\n\r\nqwerasd", "qwerasd"), + ("content-lengTH: 1 \r\n\r\nX", "X"), + ] // drip feed (byte by byte) for (message, expected) in messagesAndExpectedOutput { @@ -124,7 +129,7 @@ final class JSONRPCFramingContentLengthHeaderDecoderTests: XCTestCase { XCTAssertNoThrow(try self.channel.writeInbound(self.buffer(string: everything + everything + everything))) for _ in 0..<3 { - for expected in messagesAndExpectedOutput.map({$0.1}) { + for expected in messagesAndExpectedOutput.map({ $0.1 }) { XCTAssertNoThrow(try XCTAssertEqual(expected, self.readInboundString())) } } @@ -144,7 +149,7 @@ final class JSONRPCFramingContentLengthHeaderDecoderTests: XCTestCase { } func testErrorNotEnoughDataAtEOF() { - let s = "Content-Length: 4\r\n\r\n123" // only three bytes payload, not 4 + let s = "Content-Length: 4\r\n\r\n123" // only three bytes payload, not 4 XCTAssertNoThrow(try self.channel.writeInbound(self.buffer(string: s))) XCTAssertNoThrow(try XCTAssertNil(self.channel.readInbound())) diff --git a/Tests/NIOExtrasTests/JSONRPCFramingContentLengthHeaderEncoderTests.swift b/Tests/NIOExtrasTests/JSONRPCFramingContentLengthHeaderEncoderTests.swift index d4491334..68326052 100644 --- a/Tests/NIOExtrasTests/JSONRPCFramingContentLengthHeaderEncoderTests.swift +++ b/Tests/NIOExtrasTests/JSONRPCFramingContentLengthHeaderEncoderTests.swift @@ -12,22 +12,27 @@ // //===----------------------------------------------------------------------===// -import XCTest - import NIOCore import NIOEmbedded import NIOExtras +import XCTest final class JSONRPCFramingContentLengthHeaderEncoderTests: XCTestCase { - private var channel: EmbeddedChannel! // not a real network connection + private var channel: EmbeddedChannel! // not a real network connection override func setUp() { self.channel = EmbeddedChannel() // let's add the framing handler to the pipeline as that's what we're testing here. - XCTAssertNoThrow(try self.channel.pipeline.addHandler(NIOJSONRPCFraming.ContentLengthHeaderFrameEncoder()).wait()) + XCTAssertNoThrow( + try self.channel.pipeline.addHandler(NIOJSONRPCFraming.ContentLengthHeaderFrameEncoder()).wait() + ) // let's also add the decoder so we can round-trip - XCTAssertNoThrow(try self.channel.pipeline.addHandler(ByteToMessageHandler(NIOJSONRPCFraming.ContentLengthHeaderFrameDecoder())).wait()) + XCTAssertNoThrow( + try self.channel.pipeline.addHandler( + ByteToMessageHandler(NIOJSONRPCFraming.ContentLengthHeaderFrameDecoder()) + ).wait() + ) // this pretends to connect the channel to this IP address. XCTAssertNoThrow(self.channel.connect(to: try .init(ipAddress: "1.2.3.4", port: 5678))) } @@ -41,15 +46,19 @@ final class JSONRPCFramingContentLengthHeaderEncoderTests: XCTestCase { } private func readOutboundString() throws -> String? { - return try self.channel.readOutbound(as: ByteBuffer.self).map { + try self.channel.readOutbound(as: ByteBuffer.self).map { String(decoding: $0.readableBytesView, as: Unicode.UTF8.self) } } func testEmptyMessage() { XCTAssertNoThrow(try self.channel.writeOutbound(self.channel.allocator.buffer(capacity: 0))) - XCTAssertNoThrow(XCTAssertEqual("Content-Length: 0\r\n\r\n", - try self.readOutboundString())) + XCTAssertNoThrow( + XCTAssertEqual( + "Content-Length: 0\r\n\r\n", + try self.readOutboundString() + ) + ) XCTAssertNoThrow(XCTAssertNil(try self.readOutboundString())) } @@ -57,15 +66,21 @@ final class JSONRPCFramingContentLengthHeaderEncoderTests: XCTestCase { var buffer = self.channel.allocator.buffer(capacity: 8) buffer.writeString("01234567") XCTAssertNoThrow(try self.channel.writeOutbound(buffer)) - XCTAssertNoThrow(try { - while let encoded = try self.channel.readOutbound(as: ByteBuffer.self) { - // round trip it back - try self.channel.writeInbound(encoded) - } - }()) - XCTAssertNoThrow(XCTAssertEqual("01234567", - try self.channel.readInbound(as: ByteBuffer.self).map { - String(decoding: $0.readableBytesView, as: Unicode.UTF8.self) - })) + XCTAssertNoThrow( + try { + while let encoded = try self.channel.readOutbound(as: ByteBuffer.self) { + // round trip it back + try self.channel.writeInbound(encoded) + } + }() + ) + XCTAssertNoThrow( + XCTAssertEqual( + "01234567", + try self.channel.readInbound(as: ByteBuffer.self).map { + String(decoding: $0.readableBytesView, as: Unicode.UTF8.self) + } + ) + ) } } diff --git a/Tests/NIOExtrasTests/LengthFieldBasedFrameDecoderTest.swift b/Tests/NIOExtrasTests/LengthFieldBasedFrameDecoderTest.swift index caf26f02..4647efbe 100644 --- a/Tests/NIOExtrasTests/LengthFieldBasedFrameDecoderTest.swift +++ b/Tests/NIOExtrasTests/LengthFieldBasedFrameDecoderTest.swift @@ -12,19 +12,20 @@ // //===----------------------------------------------------------------------===// -import XCTest import NIOCore import NIOEmbedded -@testable import NIOExtras import NIOTestUtils +import XCTest + +@testable import NIOExtras private let standardDataString = "abcde" class LengthFieldBasedFrameDecoderTest: XCTestCase { - + private var channel: EmbeddedChannel! private var decoderUnderTest: ByteToMessageHandler! - + override func setUp() { self.channel = EmbeddedChannel() } @@ -51,218 +52,295 @@ class LengthFieldBasedFrameDecoderTest: XCTestCase { UInt32(UInt8.max) + 1, UInt32(UInt16.max) + 1, ] - + for input in inputs { var buffer = ByteBuffer() buffer.write24UInt(input, endianness: .big) XCTAssertEqual(buffer.read24UInt(endianness: .big), input) - + buffer.write24UInt(input, endianness: .little) XCTAssertEqual(buffer.read24UInt(endianness: .little), input) } } func testDecodeWithUInt8HeaderWithData() throws { - - self.decoderUnderTest = .init(LengthFieldBasedFrameDecoder(lengthFieldLength: .one, - lengthFieldEndianness: .little)) + + self.decoderUnderTest = .init( + LengthFieldBasedFrameDecoder( + lengthFieldLength: .one, + lengthFieldEndianness: .little + ) + ) XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.decoderUnderTest).wait()) let dataBytes: [UInt8] = [10, 20, 30, 40] let dataBytesLengthHeader = UInt8(dataBytes.count) - + var buffer = self.channel.allocator.buffer(capacity: 5) buffer.writeBytes([dataBytesLengthHeader]) buffer.writeBytes(dataBytes) - + XCTAssertTrue(try self.channel.writeInbound(buffer).isFull) - - XCTAssertNoThrow(XCTAssertEqual(dataBytes, - try self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView.map { - $0 - })) + + XCTAssertNoThrow( + XCTAssertEqual( + dataBytes, + try self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView.map { + $0 + } + ) + ) XCTAssertTrue(try self.channel.finish().isClean) } - + func testDecodeWithUInt16HeaderWithString() throws { - - self.decoderUnderTest = .init(LengthFieldBasedFrameDecoder(lengthFieldLength: .two, - lengthFieldEndianness: .little)) + + self.decoderUnderTest = .init( + LengthFieldBasedFrameDecoder( + lengthFieldLength: .two, + lengthFieldEndianness: .little + ) + ) XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.decoderUnderTest).wait()) let dataLength: UInt16 = 5 - - var buffer = self.channel.allocator.buffer(capacity: 7) // 2 byte header + 5 character string + + var buffer = self.channel.allocator.buffer(capacity: 7) // 2 byte header + 5 character string buffer.writeInteger(dataLength, endianness: .little, as: UInt16.self) buffer.writeString(standardDataString) - + XCTAssertTrue(try self.channel.writeInbound(buffer).isFull) - - XCTAssertNoThrow(XCTAssertEqual(standardDataString, - try (self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { - String(decoding: $0, as: Unicode.UTF8.self) - })) + + XCTAssertNoThrow( + XCTAssertEqual( + standardDataString, + try (self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { + String(decoding: $0, as: Unicode.UTF8.self) + } + ) + ) XCTAssertTrue(try self.channel.finish().isClean) } - + func testDecodeWithUInt24HeaderWithString() throws { - - self.decoderUnderTest = .init(LengthFieldBasedFrameDecoder(lengthFieldBitLength: .threeBytes, - lengthFieldEndianness: .big)) + + self.decoderUnderTest = .init( + LengthFieldBasedFrameDecoder( + lengthFieldBitLength: .threeBytes, + lengthFieldEndianness: .big + ) + ) XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.decoderUnderTest).wait()) - var buffer = self.channel.allocator.buffer(capacity: 8) // 3 byte header + 5 character string + var buffer = self.channel.allocator.buffer(capacity: 8) // 3 byte header + 5 character string buffer.writeBytes([0, 0, 5]) buffer.writeString(standardDataString) - + XCTAssertTrue(try self.channel.writeInbound(buffer).isFull) - - XCTAssertNoThrow(XCTAssertEqual(standardDataString, - try (self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { - String(decoding: $0, as: Unicode.UTF8.self) - })) + + XCTAssertNoThrow( + XCTAssertEqual( + standardDataString, + try (self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { + String(decoding: $0, as: Unicode.UTF8.self) + } + ) + ) XCTAssertTrue(try self.channel.finish().isClean) } - + func testDecodeWithUInt32HeaderWithString() throws { - - self.decoderUnderTest = .init(LengthFieldBasedFrameDecoder(lengthFieldLength: .four, - lengthFieldEndianness: .little)) + + self.decoderUnderTest = .init( + LengthFieldBasedFrameDecoder( + lengthFieldLength: .four, + lengthFieldEndianness: .little + ) + ) XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.decoderUnderTest).wait()) let dataLength: UInt32 = 5 - - var buffer = self.channel.allocator.buffer(capacity: 9) // 4 byte header + 5 character string + + var buffer = self.channel.allocator.buffer(capacity: 9) // 4 byte header + 5 character string buffer.writeInteger(dataLength, endianness: .little, as: UInt32.self) buffer.writeString(standardDataString) - + XCTAssertTrue(try self.channel.writeInbound(buffer).isFull) - - XCTAssertNoThrow(XCTAssertEqual(standardDataString, - try (self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { - String(decoding: $0, as: Unicode.UTF8.self) - })) + + XCTAssertNoThrow( + XCTAssertEqual( + standardDataString, + try (self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { + String(decoding: $0, as: Unicode.UTF8.self) + } + ) + ) XCTAssertTrue(try self.channel.finish().isClean) } - + func testDecodeWithUInt64HeaderWithString() throws { - - self.decoderUnderTest = .init(LengthFieldBasedFrameDecoder(lengthFieldLength: .eight, - lengthFieldEndianness: .little)) + + self.decoderUnderTest = .init( + LengthFieldBasedFrameDecoder( + lengthFieldLength: .eight, + lengthFieldEndianness: .little + ) + ) XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.decoderUnderTest).wait()) let dataLength: UInt64 = 5 - - var buffer = self.channel.allocator.buffer(capacity: 13) // 8 byte header + 5 character string + + var buffer = self.channel.allocator.buffer(capacity: 13) // 8 byte header + 5 character string buffer.writeInteger(dataLength, endianness: .little, as: UInt64.self) buffer.writeString(standardDataString) - + XCTAssertTrue(try self.channel.writeInbound(buffer).isFull) - - XCTAssertNoThrow(XCTAssertEqual(standardDataString, - try (self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { - String(decoding: $0, as: Unicode.UTF8.self) - })) + + XCTAssertNoThrow( + XCTAssertEqual( + standardDataString, + try (self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { + String(decoding: $0, as: Unicode.UTF8.self) + } + ) + ) XCTAssertTrue(try self.channel.finish().isClean) } - + func testDecodeWithInt64HeaderWithString() throws { - - self.decoderUnderTest = .init(LengthFieldBasedFrameDecoder(lengthFieldLength: .eight, - lengthFieldEndianness: .little)) + + self.decoderUnderTest = .init( + LengthFieldBasedFrameDecoder( + lengthFieldLength: .eight, + lengthFieldEndianness: .little + ) + ) XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.decoderUnderTest).wait()) let dataLength: Int64 = 5 - - var buffer = self.channel.allocator.buffer(capacity: 13) // 8 byte header + 5 character string + + var buffer = self.channel.allocator.buffer(capacity: 13) // 8 byte header + 5 character string buffer.writeInteger(dataLength, endianness: .little, as: Int64.self) buffer.writeString(standardDataString) - + XCTAssertTrue(try self.channel.writeInbound(buffer).isFull) - - XCTAssertNoThrow(XCTAssertEqual(standardDataString, - try (self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { - String(decoding: $0, as: Unicode.UTF8.self) - })) + + XCTAssertNoThrow( + XCTAssertEqual( + standardDataString, + try (self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { + String(decoding: $0, as: Unicode.UTF8.self) + } + ) + ) XCTAssertTrue(try self.channel.finish().isClean) } - + func testDecodeWithInt64HeaderStringBigEndian() throws { - - self.decoderUnderTest = .init(LengthFieldBasedFrameDecoder(lengthFieldLength: .eight, - lengthFieldEndianness: .big)) + + self.decoderUnderTest = .init( + LengthFieldBasedFrameDecoder( + lengthFieldLength: .eight, + lengthFieldEndianness: .big + ) + ) XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.decoderUnderTest).wait()) let dataLength: Int64 = 5 - - var buffer = self.channel.allocator.buffer(capacity: 13) // 8 byte header + 5 character string + + var buffer = self.channel.allocator.buffer(capacity: 13) // 8 byte header + 5 character string buffer.writeInteger(dataLength, endianness: .big, as: Int64.self) buffer.writeString(standardDataString) - + XCTAssertTrue(try self.channel.writeInbound(buffer).isFull) - - XCTAssertNoThrow(XCTAssertEqual(standardDataString, - try (self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { - String(decoding: $0, as: Unicode.UTF8.self) - })) + + XCTAssertNoThrow( + XCTAssertEqual( + standardDataString, + try (self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { + String(decoding: $0, as: Unicode.UTF8.self) + } + ) + ) XCTAssertTrue(try self.channel.finish().isClean) } - + func testDecodeWithInt64HeaderStringDefaultingToBigEndian() throws { - + self.decoderUnderTest = .init(LengthFieldBasedFrameDecoder(lengthFieldLength: .eight)) XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.decoderUnderTest).wait()) let dataLength: Int64 = 5 - - var buffer = self.channel.allocator.buffer(capacity: 13) // 8 byte header + 5 character string + + var buffer = self.channel.allocator.buffer(capacity: 13) // 8 byte header + 5 character string buffer.writeInteger(dataLength, endianness: .big, as: Int64.self) buffer.writeString(standardDataString) - + XCTAssertTrue(try self.channel.writeInbound(buffer).isFull) - - XCTAssertNoThrow(XCTAssertEqual(standardDataString, - try (self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { - String(decoding: $0, as: Unicode.UTF8.self) - })) + + XCTAssertNoThrow( + XCTAssertEqual( + standardDataString, + try (self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { + String(decoding: $0, as: Unicode.UTF8.self) + } + ) + ) XCTAssertTrue(try self.channel.finish().isClean) } - + func testDecodeWithUInt8HeaderTwoFrames() throws { - - self.decoderUnderTest = .init(LengthFieldBasedFrameDecoder(lengthFieldLength: .one, - lengthFieldEndianness: .little)) + + self.decoderUnderTest = .init( + LengthFieldBasedFrameDecoder( + lengthFieldLength: .one, + lengthFieldEndianness: .little + ) + ) XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.decoderUnderTest).wait()) let firstFrameDataLength: UInt8 = 5 let secondFrameDataLength: UInt8 = 3 let secondFrameString = "123" - - var buffer = self.channel.allocator.buffer(capacity: 10) // 1 byte header + 5 character string + 1 byte header + 3 character string + + // 1 byte header + 5 character string + 1 byte header + 3 character string + var buffer = self.channel.allocator.buffer(capacity: 10) buffer.writeInteger(firstFrameDataLength, endianness: .little, as: UInt8.self) buffer.writeString(standardDataString) buffer.writeInteger(secondFrameDataLength, endianness: .little, as: UInt8.self) buffer.writeString(secondFrameString) - - XCTAssertTrue(try self.channel.writeInbound(buffer).isFull) - XCTAssertNoThrow(XCTAssertEqual(standardDataString, - try (self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { - String(decoding: $0, as: Unicode.UTF8.self) - })) - XCTAssertNoThrow(XCTAssertEqual(secondFrameString, - try (self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { - String(decoding: $0, as: Unicode.UTF8.self) - })) + XCTAssertTrue(try self.channel.writeInbound(buffer).isFull) + XCTAssertNoThrow( + XCTAssertEqual( + standardDataString, + try (self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { + String(decoding: $0, as: Unicode.UTF8.self) + } + ) + ) + + XCTAssertNoThrow( + XCTAssertEqual( + secondFrameString, + try (self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { + String(decoding: $0, as: Unicode.UTF8.self) + } + ) + ) XCTAssertTrue(try self.channel.finish().isClean) } - + func testDecodeWithUInt8HeaderFrameSplitIncomingData() throws { - - self.decoderUnderTest = .init(LengthFieldBasedFrameDecoder(lengthFieldLength: .two, - lengthFieldEndianness: .little)) + + self.decoderUnderTest = .init( + LengthFieldBasedFrameDecoder( + lengthFieldLength: .two, + lengthFieldEndianness: .little + ) + ) XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.decoderUnderTest).wait()) let frameDataLength: UInt16 = 5 @@ -270,31 +348,31 @@ class LengthFieldBasedFrameDecoderTest: XCTestCase { // Write and try to read both bytes of the data individually let frameDataLengthFirstByte: UInt8 = UInt8(frameDataLength) let frameDataLengthSecondByte: UInt8 = 0 - - var firstBuffer = self.channel.allocator.buffer(capacity: 1) // Byte 1 of 2 byte header header + + var firstBuffer = self.channel.allocator.buffer(capacity: 1) // Byte 1 of 2 byte header header firstBuffer.writeInteger(frameDataLengthFirstByte, endianness: .little, as: UInt8.self) - + XCTAssertTrue(try self.channel.writeInbound(firstBuffer).isEmpty) - + // Read should fail because there is not yet enough data. XCTAssertNoThrow(XCTAssertNil(try self.channel.readInbound())) - - var secondBuffer = self.channel.allocator.buffer(capacity: 1) // Byte 2 of 2 byte header header + + var secondBuffer = self.channel.allocator.buffer(capacity: 1) // Byte 2 of 2 byte header header secondBuffer.writeInteger(frameDataLengthSecondByte, endianness: .little, as: UInt8.self) - + XCTAssertTrue(try self.channel.writeInbound(secondBuffer).isEmpty) - + // Read should fail because there is not yet enough data. XCTAssertNoThrow(XCTAssertNil(try self.channel.readInbound())) - + // Write and try to read each byte of the data individually for (index, character) in standardDataString.enumerated() { - + var characterBuffer = self.channel.allocator.buffer(capacity: 1) characterBuffer.writeString(String(character)) - + if index < standardDataString.count - 1 { - + XCTAssertTrue(try self.channel.writeInbound(characterBuffer).isEmpty) // Read should fail because there is not yet enough data. XCTAssertNoThrow(XCTAssertNil(try self.channel.readInbound())) @@ -302,135 +380,170 @@ class LengthFieldBasedFrameDecoderTest: XCTestCase { XCTAssertTrue(try self.channel.writeInbound(characterBuffer).isFull) } } - - XCTAssertNoThrow(XCTAssertEqual(standardDataString, - try (self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { - String(decoding: $0, as: Unicode.UTF8.self) - })) + + XCTAssertNoThrow( + XCTAssertEqual( + standardDataString, + try (self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { + String(decoding: $0, as: Unicode.UTF8.self) + } + ) + ) XCTAssertTrue(try self.channel.finish().isClean) } - + func testEmptyBuffer() throws { - - self.decoderUnderTest = .init(LengthFieldBasedFrameDecoder(lengthFieldLength: .one, - lengthFieldEndianness: .little)) + + self.decoderUnderTest = .init( + LengthFieldBasedFrameDecoder( + lengthFieldLength: .one, + lengthFieldEndianness: .little + ) + ) XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.decoderUnderTest).wait()) let buffer = self.channel.allocator.buffer(capacity: 1) XCTAssertTrue(try self.channel.writeInbound(buffer).isEmpty) XCTAssertTrue(try self.channel.finish().isClean) } - + func testDecodeWithUInt16HeaderWithPartialHeader() throws { - - self.decoderUnderTest = .init(LengthFieldBasedFrameDecoder(lengthFieldLength: .two, - lengthFieldEndianness: .little)) + + self.decoderUnderTest = .init( + LengthFieldBasedFrameDecoder( + lengthFieldLength: .two, + lengthFieldEndianness: .little + ) + ) XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.decoderUnderTest).wait()) - - let dataLength: UInt8 = 5 // 8 byte is only half the length required - - var buffer = self.channel.allocator.buffer(capacity: 7) // 2 byte header + 5 character string + + let dataLength: UInt8 = 5 // 8 byte is only half the length required + + var buffer = self.channel.allocator.buffer(capacity: 7) // 2 byte header + 5 character string buffer.writeInteger(dataLength, endianness: .little, as: UInt8.self) - + XCTAssertTrue(try self.channel.writeInbound(buffer).isEmpty) XCTAssertThrowsError(try channel.finish()) { error in if let error = error as? NIOExtrasErrors.LeftOverBytesError { - XCTAssertEqual(1 /* just the one byte of the length that arrived */, error.leftOverBytes.readableBytes) + XCTAssertEqual(1, error.leftOverBytes.readableBytes) // just the one byte of the length that arrived } else { XCTFail("unexpected error: \(error)") } } } - + func testDecodeWithUInt16HeaderWithPartialBody() throws { - - self.decoderUnderTest = .init(LengthFieldBasedFrameDecoder(lengthFieldLength: .two, - lengthFieldEndianness: .little)) + + self.decoderUnderTest = .init( + LengthFieldBasedFrameDecoder( + lengthFieldLength: .two, + lengthFieldEndianness: .little + ) + ) XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.decoderUnderTest).wait()) let dataLength: UInt16 = 7 - - var buffer = self.channel.allocator.buffer(capacity: 9) // 2 byte header + 7 character string + + var buffer = self.channel.allocator.buffer(capacity: 9) // 2 byte header + 7 character string buffer.writeInteger(dataLength, endianness: .little, as: UInt16.self) - buffer.writeString(standardDataString) // 2 bytes short of the 7 required. - + buffer.writeString(standardDataString) // 2 bytes short of the 7 required. + XCTAssertTrue(try self.channel.writeInbound(buffer).isEmpty) XCTAssertThrowsError(try channel.finish()) { error in if let error = error as? NIOExtrasErrors.LeftOverBytesError { - XCTAssertEqual(Int(dataLength) - 2 /* we're 2 bytes short of the required 7 */, - error.leftOverBytes.readableBytes) + XCTAssertEqual( + Int(dataLength) - 2, // we're 2 bytes short of the required 7 + error.leftOverBytes.readableBytes + ) } else { XCTFail("unexpected error: \(error)") } } } - + func testRemoveHandlerWhenBufferIsEmpty() throws { - - self.decoderUnderTest = .init(LengthFieldBasedFrameDecoder(lengthFieldLength: .eight, - lengthFieldEndianness: .little)) + + self.decoderUnderTest = .init( + LengthFieldBasedFrameDecoder( + lengthFieldLength: .eight, + lengthFieldEndianness: .little + ) + ) try? self.channel.pipeline.addHandler(self.decoderUnderTest).wait() - + let dataLength: Int64 = 5 - - var buffer = self.channel.allocator.buffer(capacity: 13) // 8 byte header + 5 character string + + var buffer = self.channel.allocator.buffer(capacity: 13) // 8 byte header + 5 character string buffer.writeInteger(dataLength, endianness: .little, as: Int64.self) buffer.writeString(standardDataString) - + XCTAssertTrue(try self.channel.writeInbound(buffer).isFull) - + let removeFuture = self.channel.pipeline.removeHandler(self.decoderUnderTest) (channel.eventLoop as! EmbeddedEventLoop).run() XCTAssertNoThrow(try removeFuture.wait()) - - XCTAssertNoThrow(XCTAssertEqual(standardDataString, - try (self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { - String(decoding: $0, as: Unicode.UTF8.self) - })) + XCTAssertNoThrow( + XCTAssertEqual( + standardDataString, + try (self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { + String(decoding: $0, as: Unicode.UTF8.self) + } + ) + ) XCTAssertNoThrow(try self.channel.throwIfErrorCaught()) XCTAssertTrue(try self.channel.finish().isClean) } - + func testRemoveHandlerWhenBufferIsNotEmpty() throws { - - self.decoderUnderTest = .init(LengthFieldBasedFrameDecoder(lengthFieldLength: .eight, - lengthFieldEndianness: .little)) + + self.decoderUnderTest = .init( + LengthFieldBasedFrameDecoder( + lengthFieldLength: .eight, + lengthFieldEndianness: .little + ) + ) try? self.channel.pipeline.addHandler(self.decoderUnderTest).wait() - + let extraUnusedDataString = "fghi" let dataLength: Int64 = 5 - var buffer = self.channel.allocator.buffer(capacity: 17) // 8 byte header + 5 character string + 4 unused + var buffer = self.channel.allocator.buffer(capacity: 17) // 8 byte header + 5 character string + 4 unused buffer.writeInteger(dataLength, endianness: .little, as: Int64.self) buffer.writeString(standardDataString + extraUnusedDataString) - + XCTAssertTrue(try channel.writeInbound(buffer).isFull) - + let removeFuture = self.channel.pipeline.removeHandler(self.decoderUnderTest) (channel.eventLoop as! EmbeddedEventLoop).run() XCTAssertNoThrow(try removeFuture.wait()) - + XCTAssertThrowsError(try self.channel.throwIfErrorCaught()) { error in guard let error = error as? NIOExtrasErrors.LeftOverBytesError else { XCTFail() return } - + var expectedBuffer = self.channel.allocator.buffer(capacity: 7) expectedBuffer.writeString(extraUnusedDataString) XCTAssertEqual(error.leftOverBytes, expectedBuffer) } - - XCTAssertNoThrow(XCTAssertEqual(standardDataString, - try (self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { - String(decoding: $0, as: Unicode.UTF8.self) - })) + + XCTAssertNoThrow( + XCTAssertEqual( + standardDataString, + try (self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { + String(decoding: $0, as: Unicode.UTF8.self) + } + ) + ) XCTAssertTrue(try self.channel.finish().isClean) } func testCloseInChannelRead() { - let channel = EmbeddedChannel(handler: ByteToMessageHandler(LengthFieldBasedFrameDecoder(lengthFieldLength: .four))) + let channel = EmbeddedChannel( + handler: ByteToMessageHandler(LengthFieldBasedFrameDecoder(lengthFieldLength: .four)) + ) class CloseInReadHandler: ChannelInboundHandler { typealias InboundIn = ByteBuffer @@ -457,41 +570,69 @@ class LengthFieldBasedFrameDecoderTest: XCTestCase { func testBasicVerification() { let inputs: [(NIOLengthFieldBitLength, [(Int, String)])] = [ - (.oneByte, [ - (6, "abcdef"), - (0, ""), - (9, "123456789"), - (Int(UInt8.max), - String(decoding: Array(repeating: UInt8(ascii: "X"), count: Int(UInt8.max)), as: Unicode.UTF8.self)), - ]), - (.twoBytes, [ - (1, "a"), - (0, ""), - (9, "123456789"), - (307, - String(decoding: Array(repeating: UInt8(ascii: "X"), count: 307), as: Unicode.UTF8.self)), - ]), - (.threeBytes, [ - (1, "a"), - (0, ""), - (9, "123456789"), - (307, - String(decoding: Array(repeating: UInt8(ascii: "X"), count: 307), as: Unicode.UTF8.self)), - ]), - (.fourBytes, [ - (1, "a"), - (0, ""), - (3, "333"), - (307, - String(decoding: Array(repeating: UInt8(ascii: "X"), count: 307), as: Unicode.UTF8.self)), - ]), - (.eightBytes, [ - (1, "a"), - (0, ""), - (4, "aaaa"), - (307, - String(decoding: Array(repeating: UInt8(ascii: "X"), count: 307), as: Unicode.UTF8.self)), - ]), + ( + .oneByte, + [ + (6, "abcdef"), + (0, ""), + (9, "123456789"), + ( + Int(UInt8.max), + String( + decoding: Array(repeating: UInt8(ascii: "X"), count: Int(UInt8.max)), + as: Unicode.UTF8.self + ) + ), + ] + ), + ( + .twoBytes, + [ + (1, "a"), + (0, ""), + (9, "123456789"), + ( + 307, + String(decoding: Array(repeating: UInt8(ascii: "X"), count: 307), as: Unicode.UTF8.self) + ), + ] + ), + ( + .threeBytes, + [ + (1, "a"), + (0, ""), + (9, "123456789"), + ( + 307, + String(decoding: Array(repeating: UInt8(ascii: "X"), count: 307), as: Unicode.UTF8.self) + ), + ] + ), + ( + .fourBytes, + [ + (1, "a"), + (0, ""), + (3, "333"), + ( + 307, + String(decoding: Array(repeating: UInt8(ascii: "X"), count: 307), as: Unicode.UTF8.self) + ), + ] + ), + ( + .eightBytes, + [ + (1, "a"), + (0, ""), + (4, "aaaa"), + ( + 307, + String(decoding: Array(repeating: UInt8(ascii: "X"), count: 307), as: Unicode.UTF8.self) + ), + ] + ), ] for input in inputs { @@ -509,63 +650,81 @@ class LengthFieldBasedFrameDecoderTest: XCTestCase { let bytes = byteBuffer(length: input.0, string: input.1) return (bytes, [bytes.getSlice(at: bytes.readerIndex + lenBytes.length, length: input.0)!]) } - XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder(inputOutputPairs: inputOutputPairs) { - LengthFieldBasedFrameDecoder(lengthFieldBitLength: lenBytes) - }) + XCTAssertNoThrow( + try ByteToMessageDecoderVerifier.verifyDecoder(inputOutputPairs: inputOutputPairs) { + LengthFieldBasedFrameDecoder(lengthFieldBitLength: lenBytes) + } + ) } } func testMaximumAllowedLengthWith32BitFieldLength() throws { - self.decoderUnderTest = .init(LengthFieldBasedFrameDecoder(lengthFieldLength: .four, - lengthFieldEndianness: .little)) + self.decoderUnderTest = .init( + LengthFieldBasedFrameDecoder( + lengthFieldLength: .four, + lengthFieldEndianness: .little + ) + ) XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.decoderUnderTest).wait()) let dataLength = UInt32(Int32.max) - - var buffer = self.channel.allocator.buffer(capacity: 4) // 4 byte header + + var buffer = self.channel.allocator.buffer(capacity: 4) // 4 byte header buffer.writeInteger(dataLength, endianness: .little, as: UInt32.self) buffer.writeString(standardDataString) - + XCTAssertNoThrow(try self.channel.writeInbound(buffer)) } - + func testMaliciousLengthWith32BitFieldLength() throws { - self.decoderUnderTest = .init(LengthFieldBasedFrameDecoder(lengthFieldLength: .four, - lengthFieldEndianness: .little)) + self.decoderUnderTest = .init( + LengthFieldBasedFrameDecoder( + lengthFieldLength: .four, + lengthFieldEndianness: .little + ) + ) XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.decoderUnderTest).wait()) let dataLength = UInt32(Int32.max) + 1 - - var buffer = self.channel.allocator.buffer(capacity: 4) // 4 byte header + + var buffer = self.channel.allocator.buffer(capacity: 4) // 4 byte header buffer.writeInteger(dataLength, endianness: .little, as: UInt32.self) buffer.writeString(standardDataString) - + XCTAssertThrowsError(try self.channel.writeInbound(buffer)) } - + func testMaximumAllowedLengthWith64BitFieldLength() throws { - self.decoderUnderTest = .init(LengthFieldBasedFrameDecoder(lengthFieldLength: .eight, - lengthFieldEndianness: .little)) + self.decoderUnderTest = .init( + LengthFieldBasedFrameDecoder( + lengthFieldLength: .eight, + lengthFieldEndianness: .little + ) + ) XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.decoderUnderTest).wait()) let dataLength = UInt64(Int32.max) - - var buffer = self.channel.allocator.buffer(capacity: 8) // 8 byte header + + var buffer = self.channel.allocator.buffer(capacity: 8) // 8 byte header buffer.writeInteger(dataLength, endianness: .little, as: UInt64.self) buffer.writeString(standardDataString) - + XCTAssertNoThrow(try self.channel.writeInbound(buffer)) } - + func testMaliciousLengthWith64BitFieldLength() { - self.decoderUnderTest = .init(LengthFieldBasedFrameDecoder(lengthFieldLength: .eight, - lengthFieldEndianness: .little)) + self.decoderUnderTest = .init( + LengthFieldBasedFrameDecoder( + lengthFieldLength: .eight, + lengthFieldEndianness: .little + ) + ) XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.decoderUnderTest).wait()) let dataLength = UInt64(Int32.max) + 1 - - var buffer = self.channel.allocator.buffer(capacity: 8) // 8 byte header + + var buffer = self.channel.allocator.buffer(capacity: 8) // 8 byte header buffer.writeInteger(dataLength, endianness: .little, as: UInt64.self) - + XCTAssertThrowsError(try self.channel.writeInbound(buffer)) } } diff --git a/Tests/NIOExtrasTests/LengthFieldPrependerTest.swift b/Tests/NIOExtrasTests/LengthFieldPrependerTest.swift index 8820fb80..aace8c04 100644 --- a/Tests/NIOExtrasTests/LengthFieldPrependerTest.swift +++ b/Tests/NIOExtrasTests/LengthFieldPrependerTest.swift @@ -12,9 +12,10 @@ // //===----------------------------------------------------------------------===// -import XCTest import NIOCore import NIOEmbedded +import XCTest + @testable import NIOExtras private let standardDataString = "abcde" @@ -23,7 +24,7 @@ private let standardDataStringCount = standardDataString.utf8.count class LengthFieldPrependerTest: XCTestCase { private var channel: EmbeddedChannel! private var encoderUnderTest: LengthFieldPrepender! - + override func setUp() { self.channel = EmbeddedChannel() } @@ -32,429 +33,449 @@ class LengthFieldPrependerTest: XCTestCase { buffer.write24UInt(5, endianness: .little) XCTAssertEqual(Array(buffer.readableBytesView), [5, 0, 0]) XCTAssertEqual(buffer.read24UInt(endianness: .little), 5) - + buffer.write24UInt(5, endianness: .big) XCTAssertEqual(Array(buffer.readableBytesView), [0, 0, 5]) XCTAssertEqual(buffer.read24UInt(endianness: .big), 5) } func testEncodeWithUInt8HeaderWithData() throws { - - self.encoderUnderTest = LengthFieldPrepender(lengthFieldLength: .one, - lengthFieldEndianness: .little) + + self.encoderUnderTest = LengthFieldPrepender( + lengthFieldLength: .one, + lengthFieldEndianness: .little + ) XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.encoderUnderTest).wait()) - + let dataBytes: [UInt8] = [10, 20, 30, 40] - + var buffer = self.channel.allocator.buffer(capacity: dataBytes.count) buffer.writeBytes(dataBytes) - + XCTAssertNoThrow(try self.channel.writeAndFlush(buffer).wait()) - + if case .some(.byteBuffer(var headerBuffer)) = try self.channel.readOutbound(as: IOData.self) { - + let outputData = headerBuffer.readBytes(length: headerBuffer.readableBytes) - XCTAssertEqual([UInt8(dataBytes.count)], outputData) - + XCTAssertEqual([UInt8(dataBytes.count)], outputData) + } else { XCTFail("couldn't read ByteBuffer from channel") } - + if case .some(.byteBuffer(var outputBuffer)) = try self.channel.readOutbound(as: IOData.self) { - + let outputData = outputBuffer.readBytes(length: outputBuffer.readableBytes) - XCTAssertEqual(dataBytes, outputData) - + XCTAssertEqual(dataBytes, outputData) + } else { XCTFail("couldn't read ByteBuffer from channel") } - + XCTAssertNoThrow(XCTAssertNil(try self.channel.readOutbound())) XCTAssertTrue(try self.channel.finish().isClean) } - + func testEncodeWithUInt16HeaderWithString() throws { - + let endianness: Endianness = .little - - self.encoderUnderTest = LengthFieldPrepender(lengthFieldLength: .two, - lengthFieldEndianness: endianness) - + + self.encoderUnderTest = LengthFieldPrepender( + lengthFieldLength: .two, + lengthFieldEndianness: endianness + ) + XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.encoderUnderTest).wait()) - + var buffer = self.channel.allocator.buffer(capacity: standardDataStringCount) buffer.writeString(standardDataString) - - XCTAssertNoThrow(try self.channel.writeAndFlush(buffer).wait()) - + + XCTAssertNoThrow(try self.channel.writeAndFlush(buffer).wait()) + if case .some(.byteBuffer(var outputBuffer)) = try self.channel.readOutbound(as: IOData.self) { - + let sizeInHeader = outputBuffer.readInteger(endianness: endianness, as: UInt16.self).map({ Int($0) }) XCTAssertEqual(standardDataStringCount, sizeInHeader) - + let additionalData = outputBuffer.readBytes(length: 1) XCTAssertNil(additionalData) - + } else { XCTFail("couldn't read ByteBuffer from channel") } - + if case .some(.byteBuffer(var outputBuffer)) = try self.channel.readOutbound(as: IOData.self) { - + let bodyString = outputBuffer.readString(length: standardDataStringCount) XCTAssertEqual(standardDataString, bodyString) - + let additionalData = outputBuffer.readBytes(length: 1) XCTAssertNil(additionalData) - + } else { XCTFail("couldn't read ByteBuffer from channel") } - + XCTAssertNoThrow(XCTAssertNil(try self.channel.readOutbound())) XCTAssertTrue(try self.channel.finish().isClean) } - + func testEncodeWithUInt24HeaderWithString() throws { - + let endianness: Endianness = .little - - self.encoderUnderTest = LengthFieldPrepender(lengthFieldBitLength: .threeBytes, - lengthFieldEndianness: endianness) - + + self.encoderUnderTest = LengthFieldPrepender( + lengthFieldBitLength: .threeBytes, + lengthFieldEndianness: endianness + ) + XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.encoderUnderTest).wait()) - + var buffer = self.channel.allocator.buffer(capacity: standardDataStringCount) buffer.writeString(standardDataString) - - XCTAssertNoThrow(try self.channel.writeAndFlush(buffer).wait()) - + + XCTAssertNoThrow(try self.channel.writeAndFlush(buffer).wait()) + if case .some(.byteBuffer(var outputBuffer)) = try self.channel.readOutbound(as: IOData.self) { - + let sizeInHeader = outputBuffer.read24UInt(endianness: endianness).map({ Int($0) }) XCTAssertEqual(standardDataStringCount, sizeInHeader) - + let additionalData = outputBuffer.readBytes(length: 1) XCTAssertNil(additionalData) - + } else { XCTFail("couldn't read ByteBuffer from channel") } - + if case .some(.byteBuffer(var outputBuffer)) = try self.channel.readOutbound(as: IOData.self) { - + let bodyString = outputBuffer.readString(length: standardDataStringCount) XCTAssertEqual(standardDataString, bodyString) - + let additionalData = outputBuffer.readBytes(length: 1) XCTAssertNil(additionalData) - + } else { XCTFail("couldn't read ByteBuffer from channel") } - + XCTAssertNoThrow(XCTAssertNil(try self.channel.readOutbound())) XCTAssertTrue(try self.channel.finish().isClean) } - + func testEncodeWithUInt32HeaderWithString() throws { - + let endianness: Endianness = .little - - self.encoderUnderTest = LengthFieldPrepender(lengthFieldLength: .four, - lengthFieldEndianness: endianness) - + + self.encoderUnderTest = LengthFieldPrepender( + lengthFieldLength: .four, + lengthFieldEndianness: endianness + ) + XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.encoderUnderTest).wait()) - + var buffer = self.channel.allocator.buffer(capacity: standardDataStringCount) buffer.writeString(standardDataString) - - XCTAssertNoThrow(try self.channel.writeAndFlush(buffer).wait()) - + + XCTAssertNoThrow(try self.channel.writeAndFlush(buffer).wait()) + if case .some(.byteBuffer(var outputBuffer)) = try self.channel.readOutbound(as: IOData.self) { - + let sizeInHeader = outputBuffer.readInteger(endianness: endianness, as: UInt32.self).map({ Int($0) }) XCTAssertEqual(standardDataStringCount, sizeInHeader) - + let additionalData = outputBuffer.readBytes(length: 1) XCTAssertNil(additionalData) - + } else { XCTFail("couldn't read ByteBuffer from channel") } - + if case .some(.byteBuffer(var outputBuffer)) = try self.channel.readOutbound(as: IOData.self) { - + let bodyString = outputBuffer.readString(length: standardDataStringCount) XCTAssertEqual(standardDataString, bodyString) - + let additionalData = outputBuffer.readBytes(length: 1) XCTAssertNil(additionalData) - + } else { XCTFail("couldn't read ByteBuffer from channel") } - + XCTAssertNoThrow(XCTAssertNil(try self.channel.readOutbound())) XCTAssertTrue(try self.channel.finish().isClean) } - + func testEncodeWithUInt64HeaderWithString() throws { - + let endianness: Endianness = .little - - self.encoderUnderTest = LengthFieldPrepender(lengthFieldLength: .eight, - lengthFieldEndianness: endianness) - + + self.encoderUnderTest = LengthFieldPrepender( + lengthFieldLength: .eight, + lengthFieldEndianness: endianness + ) + XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.encoderUnderTest).wait()) - + var buffer = self.channel.allocator.buffer(capacity: standardDataStringCount) buffer.writeString(standardDataString) - + XCTAssertNoThrow(try self.channel.writeAndFlush(buffer).wait()) - + if case .some(.byteBuffer(var outputBuffer)) = try self.channel.readOutbound(as: IOData.self) { - + let sizeInHeader = outputBuffer.readInteger(endianness: endianness, as: UInt64.self).map({ Int($0) }) XCTAssertEqual(standardDataStringCount, sizeInHeader) - + let additionalData = outputBuffer.readBytes(length: 1) XCTAssertNil(additionalData) - + } else { XCTFail("couldn't read ByteBuffer from channel") } - + if case .some(.byteBuffer(var outputBuffer)) = try self.channel.readOutbound(as: IOData.self) { - + let bodyString = outputBuffer.readString(length: outputBuffer.readableBytes) XCTAssertEqual(standardDataString, bodyString) - + } else { XCTFail("couldn't read ByteBuffer from channel") } - + XCTAssertNoThrow(XCTAssertNil(try self.channel.readOutbound())) XCTAssertTrue(try self.channel.finish().isClean) } - + func testEncodeWithInt64HeaderWithString() throws { - + let endianness: Endianness = .little - - self.encoderUnderTest = LengthFieldPrepender(lengthFieldLength: .eight, - lengthFieldEndianness: endianness) - + + self.encoderUnderTest = LengthFieldPrepender( + lengthFieldLength: .eight, + lengthFieldEndianness: endianness + ) + XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.encoderUnderTest).wait()) - + var buffer = self.channel.allocator.buffer(capacity: standardDataStringCount) buffer.writeString(standardDataString) - + XCTAssertNoThrow(try self.channel.writeAndFlush(buffer).wait()) - + if case .some(.byteBuffer(var outputBuffer)) = try self.channel.readOutbound(as: IOData.self) { - + let sizeInHeader = outputBuffer.readInteger(endianness: endianness, as: Int64.self).map({ Int($0) }) XCTAssertEqual(standardDataStringCount, sizeInHeader) - + let additionalData = outputBuffer.readBytes(length: 1) XCTAssertNil(additionalData) - + } else { XCTFail("couldn't read ByteBuffer from channel") } - + if case .some(.byteBuffer(var outputBuffer)) = try self.channel.readOutbound(as: IOData.self) { let bodyString = outputBuffer.readString(length: standardDataStringCount) XCTAssertEqual(standardDataString, bodyString) - + let additionalData = outputBuffer.readBytes(length: 1) XCTAssertNil(additionalData) - + } else { XCTFail("couldn't read ByteBuffer from channel") } - + XCTAssertNoThrow(XCTAssertNil(try self.channel.readOutbound())) XCTAssertTrue(try self.channel.finish().isClean) } - + func testEncodeWithUInt64HeaderStringBigEndian() throws { - + let endianness: Endianness = .big - - self.encoderUnderTest = LengthFieldPrepender(lengthFieldLength: .eight, - lengthFieldEndianness: endianness) - + + self.encoderUnderTest = LengthFieldPrepender( + lengthFieldLength: .eight, + lengthFieldEndianness: endianness + ) + XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.encoderUnderTest).wait()) - + var buffer = self.channel.allocator.buffer(capacity: standardDataStringCount) buffer.writeString(standardDataString) - + XCTAssertNoThrow(try self.channel.writeAndFlush(buffer).wait()) - + if case .some(.byteBuffer(var outputBuffer)) = try self.channel.readOutbound(as: IOData.self) { - + let sizeInHeader = outputBuffer.readInteger(endianness: endianness, as: UInt64.self).map({ Int($0) }) XCTAssertEqual(standardDataStringCount, sizeInHeader) - + let additionalData = outputBuffer.readBytes(length: 1) XCTAssertNil(additionalData) - + } else { XCTFail("couldn't read ByteBuffer from channel") } - + if case .some(.byteBuffer(var outputBuffer)) = try self.channel.readOutbound(as: IOData.self) { let bodyString = outputBuffer.readString(length: standardDataStringCount) XCTAssertEqual(standardDataString, bodyString) - + let additionalData = outputBuffer.readBytes(length: 1) XCTAssertNil(additionalData) - + } else { XCTFail("couldn't read ByteBuffer from channel") } - + XCTAssertNoThrow(XCTAssertNil(try self.channel.readOutbound())) XCTAssertTrue(try self.channel.finish().isClean) } - + func testEncodeWithInt64HeaderStringDefaultingToBigEndian() throws { - + self.encoderUnderTest = LengthFieldPrepender(lengthFieldLength: .eight) - + XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.encoderUnderTest).wait()) - + var buffer = self.channel.allocator.buffer(capacity: standardDataStringCount) buffer.writeString(standardDataString) - + XCTAssertNoThrow(try self.channel.writeAndFlush(buffer).wait()) - + if case .some(.byteBuffer(var outputBuffer)) = try self.channel.readOutbound(as: IOData.self) { - + let sizeInHeader = outputBuffer.readInteger(endianness: .big, as: UInt64.self).map({ Int($0) }) XCTAssertEqual(standardDataStringCount, sizeInHeader) - + let additionalData = outputBuffer.readBytes(length: 1) XCTAssertNil(additionalData) - + } else { XCTFail("couldn't read ByteBuffer from channel") } - + if case .some(.byteBuffer(var outputBuffer)) = try self.channel.readOutbound(as: IOData.self) { - + let bodyString = outputBuffer.readString(length: standardDataStringCount) XCTAssertEqual(standardDataString, bodyString) - + let additionalData = outputBuffer.readBytes(length: 1) XCTAssertNil(additionalData) - + } else { XCTFail("couldn't read ByteBuffer from channel") } - + XCTAssertNoThrow(XCTAssertNil(try self.channel.readOutbound())) XCTAssertTrue(try self.channel.finish().isClean) } - + func testEmptyBuffer() throws { - + let endianness: Endianness = .little - - self.encoderUnderTest = LengthFieldPrepender(lengthFieldLength: .eight, - lengthFieldEndianness: endianness) - + + self.encoderUnderTest = LengthFieldPrepender( + lengthFieldLength: .eight, + lengthFieldEndianness: endianness + ) + XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.encoderUnderTest).wait()) - + let buffer = self.channel.allocator.buffer(capacity: 0) - + XCTAssertNoThrow(try self.channel.writeAndFlush(buffer).wait()) - + if case .some(.byteBuffer(var outputBuffer)) = try self.channel.readOutbound(as: IOData.self) { - + let sizeInHeader = outputBuffer.readInteger(endianness: endianness, as: UInt64.self).map({ Int($0) }) XCTAssertEqual(0, sizeInHeader) - + let additionalData = outputBuffer.readBytes(length: 1) XCTAssertNil(additionalData) - + } else { XCTFail("couldn't read ByteBuffer from channel") } - + // Check that if there is any more buffer it has a zero size. if case .some(.byteBuffer(let outputBuffer)) = try self.channel.readOutbound(as: IOData.self) { XCTAssertEqual(0, outputBuffer.readableBytes) } - + XCTAssertNoThrow(XCTAssertNil(try self.channel.readOutbound())) XCTAssertTrue(try self.channel.finish().isClean) } - + func testLargeBuffer() throws { - + let endianness: Endianness = .little - - self.encoderUnderTest = LengthFieldPrepender(lengthFieldLength: .eight, - lengthFieldEndianness: endianness) - + + self.encoderUnderTest = LengthFieldPrepender( + lengthFieldLength: .eight, + lengthFieldEndianness: endianness + ) + XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.encoderUnderTest).wait()) - - let contents = Array(repeating: 200, count: 514) - + + let contents = [UInt8](repeating: 200, count: 514) + var buffer = self.channel.allocator.buffer(capacity: contents.count) buffer.writeBytes(contents) - + XCTAssertNoThrow(try self.channel.writeAndFlush(buffer).wait()) - + if case .some(.byteBuffer(var outputBuffer)) = try self.channel.readOutbound(as: IOData.self) { - + let sizeInHeader = outputBuffer.readInteger(endianness: endianness, as: UInt64.self).map({ Int($0) }) XCTAssertEqual(contents.count, sizeInHeader) - + let additionalData = outputBuffer.readBytes(length: 1) XCTAssertNil(additionalData) - + } else { XCTFail("couldn't read ByteBuffer from channel") } - + if case .some(.byteBuffer(var outputBuffer)) = try self.channel.readOutbound(as: IOData.self) { let bodyData = outputBuffer.readBytes(length: outputBuffer.readableBytes) XCTAssertEqual(contents, bodyData) - + } else { XCTFail("couldn't read ByteBuffer from channel") } - + XCTAssertNoThrow(XCTAssertNil(try self.channel.readOutbound())) XCTAssertTrue(try self.channel.finish().isClean) } - + func testTooLargeForLengthField() throws { - + let endianness: Endianness = .little // One byte has maximum integer description of 256 - self.encoderUnderTest = LengthFieldPrepender(lengthFieldLength: .one, - lengthFieldEndianness: endianness) - + self.encoderUnderTest = LengthFieldPrepender( + lengthFieldLength: .one, + lengthFieldEndianness: endianness + ) + XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.encoderUnderTest).wait()) - - let contents = Array(repeating: 200, count: 300) - + + let contents = [UInt8](repeating: 200, count: 300) + var buffer = self.channel.allocator.buffer(capacity: contents.count) buffer.writeBytes(contents) - - XCTAssertThrowsError(try self.channel.writeAndFlush(buffer).wait() ) { error in + + XCTAssertThrowsError(try self.channel.writeAndFlush(buffer).wait()) { error in XCTAssertEqual(.messageDataTooLongForLengthField, error as? LengthFieldPrependerError) } - + XCTAssertNoThrow(XCTAssertNil(try self.channel.readOutbound())) XCTAssertTrue(try self.channel.finish().isClean) } diff --git a/Tests/NIOExtrasTests/LineBasedFrameDecoderTest.swift b/Tests/NIOExtrasTests/LineBasedFrameDecoderTest.swift index 7b134706..8923f60f 100644 --- a/Tests/NIOExtrasTests/LineBasedFrameDecoderTest.swift +++ b/Tests/NIOExtrasTests/LineBasedFrameDecoderTest.swift @@ -12,11 +12,12 @@ // //===----------------------------------------------------------------------===// -import XCTest -@testable import NIOCore // to inspect the cumulationBuffer import NIOEmbedded import NIOExtras import NIOTestUtils +import XCTest + +@testable import NIOCore // to inspect the cumulationBuffer class LineBasedFrameDecoderTest: XCTestCase { private var channel: EmbeddedChannel! @@ -39,9 +40,9 @@ class LineBasedFrameDecoderTest: XCTestCase { func testDecodeOneCharacterAtATime() throws { let message = "abcdefghij\r" // we write one character at a time - try message.forEach { + for character in message { var buffer = self.channel.allocator.buffer(capacity: 1) - buffer.writeString("\($0)") + buffer.writeString("\(character)") XCTAssertTrue(try self.channel.writeInbound(buffer).isEmpty) } // let's add `\n` @@ -49,10 +50,14 @@ class LineBasedFrameDecoderTest: XCTestCase { buffer.writeString("\n") XCTAssertTrue(try self.channel.writeInbound(buffer).isFull) - XCTAssertNoThrow(XCTAssertEqual("abcdefghij", - (try self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { - String(decoding: $0[0..<10], as: Unicode.UTF8.self) - })) + XCTAssertNoThrow( + XCTAssertEqual( + "abcdefghij", + (try self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { + String(decoding: $0[0..<10], as: Unicode.UTF8.self) + } + ) + ) XCTAssertTrue(try self.channel.finish().isClean) } @@ -165,15 +170,25 @@ class LineBasedFrameDecoderTest: XCTestCase { buffer.writeString("a\nbb\nccc\ndddd\neeeee\nffffff\nXXX") XCTAssertNoThrow(try self.channel.writeInbound(buffer)) for s in ["a", "bb", "ccc", "dddd", "eeeee", "ffffff"] { - XCTAssertNoThrow(XCTAssertEqual(s, - (try self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { - String(decoding: $0, as: Unicode.UTF8.self) - })) + XCTAssertNoThrow( + XCTAssertEqual( + s, + (try self.channel.readInbound(as: ByteBuffer.self)?.readableBytesView).map { + String(decoding: $0, as: Unicode.UTF8.self) + } + ) + ) } XCTAssertNoThrow(XCTAssertNil(try self.channel.readInbound(as: ByteBuffer.self))) - XCTAssertNoThrow(try XCTAssertEqual("XXX", - String(decoding: receivedLeftOversPromise.futureResult.wait().readableBytesView, - as: UTF8.self))) + XCTAssertNoThrow( + try XCTAssertEqual( + "XXX", + String( + decoding: receivedLeftOversPromise.futureResult.wait().readableBytesView, + as: UTF8.self + ) + ) + ) } func testDripFedCRLN() { @@ -203,11 +218,16 @@ class LineBasedFrameDecoderTest: XCTestCase { ("a\r\n", [byteBuffer("a")]), ("a\n", [byteBuffer("a")]), ("a\rb\n", [byteBuffer("a\rb")]), - ("Content-Length: 17\r\nConnection: close\r\n\r\n", [byteBuffer("Content-Length: 17"), - byteBuffer("Connection: close"), - byteBuffer("")]) + ( + "Content-Length: 17\r\nConnection: close\r\n\r\n", + [ + byteBuffer("Content-Length: 17"), + byteBuffer("Connection: close"), + byteBuffer(""), + ] + ), ]) { - return LineBasedFrameDecoder() + LineBasedFrameDecoder() } } catch { print(error) @@ -219,32 +239,41 @@ class LineBasedFrameDecoderTest: XCTestCase { let decoder = LineBasedFrameDecoder() let b2mp = NIOSingleStepByteToMessageProcessor(decoder) var callCount = 0 - XCTAssertNoThrow(try b2mp.process(buffer: ByteBuffer(string: "1\n\n2\n3\n")) { line in - callCount += 1 - switch callCount { - case 1: - XCTAssertEqual(ByteBuffer(string: "1"), line) - case 2: - XCTAssertEqual(ByteBuffer(string: ""), line) - case 3: - XCTAssertEqual(ByteBuffer(string: "2"), line) - case 4: - XCTAssertEqual(ByteBuffer(string: "3"), line) - default: - XCTFail("not expecting call no \(callCount)") + XCTAssertNoThrow( + try b2mp.process(buffer: ByteBuffer(string: "1\n\n2\n3\n")) { line in + callCount += 1 + switch callCount { + case 1: + XCTAssertEqual(ByteBuffer(string: "1"), line) + case 2: + XCTAssertEqual(ByteBuffer(string: ""), line) + case 3: + XCTAssertEqual(ByteBuffer(string: "2"), line) + case 4: + XCTAssertEqual(ByteBuffer(string: "3"), line) + default: + XCTFail("not expecting call no \(callCount)") + } } - }) + ) } func testBasicSingleStepNoNewlineComingButEOF() { let decoder = LineBasedFrameDecoder() let b2mp = NIOSingleStepByteToMessageProcessor(decoder) - XCTAssertNoThrow(try b2mp.process(buffer: ByteBuffer(string: "new newline eva\r")) { line in - XCTFail("not taking calls") - }) - XCTAssertThrowsError(try b2mp.finishProcessing(seenEOF: true, { line in - XCTFail("not taking calls") - })) { error in + XCTAssertNoThrow( + try b2mp.process(buffer: ByteBuffer(string: "new newline eva\r")) { line in + XCTFail("not taking calls") + } + ) + XCTAssertThrowsError( + try b2mp.finishProcessing( + seenEOF: true, + { line in + XCTFail("not taking calls") + } + ) + ) { error in if let error = error as? NIOExtrasErrors.LeftOverBytesError { XCTAssertEqual(ByteBuffer(string: "new newline eva\r"), error.leftOverBytes) } else { @@ -256,12 +285,19 @@ class LineBasedFrameDecoderTest: XCTestCase { func testBasicSingleStepNoNewlineOrEOFComing() { let decoder = LineBasedFrameDecoder() let b2mp = NIOSingleStepByteToMessageProcessor(decoder) - XCTAssertNoThrow(try b2mp.process(buffer: ByteBuffer(string: "new newline eva\r")) { line in - XCTFail("not taking calls") - }) - XCTAssertThrowsError(try b2mp.finishProcessing(seenEOF: false, { line in - XCTFail("not taking calls") - })) { error in + XCTAssertNoThrow( + try b2mp.process(buffer: ByteBuffer(string: "new newline eva\r")) { line in + XCTFail("not taking calls") + } + ) + XCTAssertThrowsError( + try b2mp.finishProcessing( + seenEOF: false, + { line in + XCTFail("not taking calls") + } + ) + ) { error in if let error = error as? NIOExtrasErrors.LeftOverBytesError { XCTAssertEqual(ByteBuffer(string: "new newline eva\r"), error.leftOverBytes) } else { @@ -274,21 +310,23 @@ class LineBasedFrameDecoderTest: XCTestCase { let decoder = LineBasedFrameDecoder() let b2mp = NIOSingleStepByteToMessageProcessor(decoder) var callCount = 0 - XCTAssertNoThrow(try b2mp.process(buffer: ByteBuffer(string: "1\n\n2\n3\n")) { line in - callCount += 1 - switch callCount { - case 1: - XCTAssertEqual(ByteBuffer(string: "1"), line) - XCTAssertNoThrow(try b2mp.finishProcessing(seenEOF: true) { _ in } ) - case 2: - XCTAssertEqual(ByteBuffer(string: ""), line) - case 3: - XCTAssertEqual(ByteBuffer(string: "2"), line) - case 4: - XCTAssertEqual(ByteBuffer(string: "3"), line) - default: - XCTFail("not expecting call no \(callCount)") + XCTAssertNoThrow( + try b2mp.process(buffer: ByteBuffer(string: "1\n\n2\n3\n")) { line in + callCount += 1 + switch callCount { + case 1: + XCTAssertEqual(ByteBuffer(string: "1"), line) + XCTAssertNoThrow(try b2mp.finishProcessing(seenEOF: true) { _ in }) + case 2: + XCTAssertEqual(ByteBuffer(string: ""), line) + case 3: + XCTAssertEqual(ByteBuffer(string: "2"), line) + case 4: + XCTAssertEqual(ByteBuffer(string: "3"), line) + default: + XCTFail("not expecting call no \(callCount)") + } } - }) + ) } } diff --git a/Tests/NIOExtrasTests/PCAPRingBufferTest.swift b/Tests/NIOExtrasTests/PCAPRingBufferTest.swift index 5f3fed3a..7534fadc 100644 --- a/Tests/NIOExtrasTests/PCAPRingBufferTest.swift +++ b/Tests/NIOExtrasTests/PCAPRingBufferTest.swift @@ -12,15 +12,15 @@ // //===----------------------------------------------------------------------===// -import XCTest - import NIOCore import NIOEmbedded +import XCTest + @testable import NIOExtras class PCAPRingBufferTest: XCTestCase { private func dataForTests() -> [ByteBuffer] { - return [ + [ ByteBuffer(repeating: 100, count: 100), ByteBuffer(repeating: 50, count: 50), ByteBuffer(repeating: 150, count: 150), @@ -45,7 +45,7 @@ class PCAPRingBufferTest: XCTestCase { } func testNotLimited() { - let ringBuffer = NIOPCAPRingBuffer(maximumFragments: 1000, maximumBytes: 1000000) + let ringBuffer = NIOPCAPRingBuffer(maximumFragments: 1000, maximumBytes: 1_000_000) var totalBytes = 0 for fragment in dataForTests() { ringBuffer.addFragment(fragment) @@ -54,16 +54,16 @@ class PCAPRingBufferTest: XCTestCase { let emitted = PCAPRingBufferTest.captureBytes(ringBuffer: ringBuffer) XCTAssertEqual(emitted.readableBytes, totalBytes) } - + func testFragmentLimit() { - let ringBuffer = NIOPCAPRingBuffer(maximumFragments: 3, maximumBytes: 1000000) + let ringBuffer = NIOPCAPRingBuffer(maximumFragments: 3, maximumBytes: 1_000_000) for fragment in dataForTests() { ringBuffer.addFragment(fragment) } let emitted = PCAPRingBufferTest.captureBytes(ringBuffer: ringBuffer) XCTAssertEqual(emitted.readableBytes, 25 + 75 + 120) } - + func testByteLimit() { let expectedData = 150 + 25 + 75 + 120 let ringBuffer = NIOPCAPRingBuffer(maximumBytes: expectedData + 10) @@ -83,7 +83,7 @@ class PCAPRingBufferTest: XCTestCase { let emitted = PCAPRingBufferTest.captureBytes(ringBuffer: ringBuffer) XCTAssertEqual(emitted.readableBytes, expectedData) } - + func testExtremeByteLimit() { let ringBuffer = NIOPCAPRingBuffer(maximumFragments: 1000, maximumBytes: 10) for fragment in dataForTests() { @@ -92,15 +92,15 @@ class PCAPRingBufferTest: XCTestCase { let emitted = PCAPRingBufferTest.captureBytes(ringBuffer: ringBuffer) XCTAssertEqual(emitted.readableBytes, 0) } - + func testUnusedBuffer() { let ringBuffer = NIOPCAPRingBuffer(maximumFragments: 1000, maximumBytes: 1000) let emitted = PCAPRingBufferTest.captureBytes(ringBuffer: ringBuffer) XCTAssertEqual(emitted.readableBytes, 0) } - + func testDoubleEmitZero() { - let ringBuffer = NIOPCAPRingBuffer(maximumFragments: 1000, maximumBytes: 1000000) + let ringBuffer = NIOPCAPRingBuffer(maximumFragments: 1000, maximumBytes: 1_000_000) for fragment in dataForTests() { ringBuffer.addFragment(fragment) } @@ -108,55 +108,62 @@ class PCAPRingBufferTest: XCTestCase { let emitted2 = PCAPRingBufferTest.captureBytes(ringBuffer: ringBuffer) XCTAssertEqual(emitted2.readableBytes, 0) } - + func testDoubleEmitSome() { - let ringBuffer = NIOPCAPRingBuffer(maximumFragments: 1000, maximumBytes: 1000000) + let ringBuffer = NIOPCAPRingBuffer(maximumFragments: 1000, maximumBytes: 1_000_000) for fragment in dataForTests() { ringBuffer.addFragment(fragment) } _ = PCAPRingBufferTest.captureBytes(ringBuffer: ringBuffer) - + ringBuffer.addFragment(ByteBuffer(repeating: 75, count: 75)) let emitted2 = PCAPRingBufferTest.captureBytes(ringBuffer: ringBuffer) XCTAssertEqual(emitted2.readableBytes, 75) } - + func testAsHandlerSink() { let fragmentsToRecord = 4 let channel = EmbeddedChannel() let ringBuffer = NIOPCAPRingBuffer(maximumFragments: .init(fragmentsToRecord), maximumBytes: 1_000_000) - XCTAssertNoThrow(try channel.pipeline.addHandler( - NIOWritePCAPHandler(mode: .client, - fakeLocalAddress: nil, - fakeRemoteAddress: nil, - fileSink: { ringBuffer.addFragment($0) })).wait()) + XCTAssertNoThrow( + try channel.pipeline.addHandler( + NIOWritePCAPHandler( + mode: .client, + fakeLocalAddress: nil, + fakeRemoteAddress: nil, + fileSink: { ringBuffer.addFragment($0) } + ) + ).wait() + ) channel.localAddress = try! SocketAddress(ipAddress: "255.255.255.254", port: Int(UInt16.max) - 1) XCTAssertNoThrow(try channel.connect(to: .init(ipAddress: "1.2.3.4", port: 5678)).wait()) for data in dataForTests() { XCTAssertNoThrow(try channel.writeAndFlush(data).wait()) } XCTAssertNoThrow(try channel.throwIfErrorCaught()) - - XCTAssertNoThrow(try { - // See what we've got - hopefully 5 data packets. - var capturedData = PCAPRingBufferTest.captureBytes(ringBuffer: ringBuffer) - let data = dataForTests() - for expectedData in data[(data.count - fragmentsToRecord)...] { - var packet = capturedData.readPCAPRecord() - let tcpPayloadBytes = try packet?.payload.readTCPIPv4()?.tcpPayload.readableBytes - XCTAssertEqual(tcpPayloadBytes, expectedData.readableBytes) - } - }()) + + XCTAssertNoThrow( + try { + // See what we've got - hopefully 5 data packets. + var capturedData = PCAPRingBufferTest.captureBytes(ringBuffer: ringBuffer) + let data = dataForTests() + for expectedData in data[(data.count - fragmentsToRecord)...] { + var packet = capturedData.readPCAPRecord() + let tcpPayloadBytes = try packet?.payload.readTCPIPv4()?.tcpPayload.readableBytes + XCTAssertEqual(tcpPayloadBytes, expectedData.readableBytes) + } + }() + ) } - class TriggerOnCumulativeSizeHandler : ChannelOutboundHandler { + class TriggerOnCumulativeSizeHandler: ChannelOutboundHandler { typealias OutboundIn = ByteBuffer private var bytesUntilTrigger: Int private var pcapRingBuffer: NIOPCAPRingBuffer - private var sink: (ByteBuffer) -> () + private var sink: (ByteBuffer) -> Void - init(triggerBytes: Int, pcapRingBuffer: NIOPCAPRingBuffer, sink: @escaping (ByteBuffer) -> ()) { + init(triggerBytes: Int, pcapRingBuffer: NIOPCAPRingBuffer, sink: @escaping (ByteBuffer) -> Void) { self.bytesUntilTrigger = triggerBytes self.pcapRingBuffer = pcapRingBuffer self.sink = sink @@ -183,33 +190,40 @@ class PCAPRingBufferTest: XCTestCase { func testRecordedBytes(buffer: ByteBuffer) { var capturedData = buffer - XCTAssertNoThrow(try { - testTriggered = true - // See what we've got. - let data = dataForTests() - for expectedData in data[(triggerEndIndex - maximumFragments).. EventLoopFuture in - let triggerHandler = TriggerOnCumulativeSizeHandler(triggerBytes: trigger, - pcapRingBuffer: pcapRingBuffer, - sink: testRecordedBytes) + let triggerHandler = TriggerOnCumulativeSizeHandler( + triggerBytes: trigger, + pcapRingBuffer: pcapRingBuffer, + sink: testRecordedBytes + ) return channel.pipeline.addHandler(triggerHandler, name: "trigger") } XCTAssertNoThrow(try addHandlers.wait()) @@ -220,6 +234,6 @@ class PCAPRingBufferTest: XCTestCase { XCTAssertNoThrow(try channel.writeAndFlush(data).wait()) } XCTAssertNoThrow(try channel.throwIfErrorCaught()) - XCTAssert(testTriggered) // Just to make sure something actually happened. + XCTAssert(testTriggered) // Just to make sure something actually happened. } } diff --git a/Tests/NIOExtrasTests/QuiescingHelperTest.swift b/Tests/NIOExtrasTests/QuiescingHelperTest.swift index ac62e78b..68c9cf4d 100644 --- a/Tests/NIOExtrasTests/QuiescingHelperTest.swift +++ b/Tests/NIOExtrasTests/QuiescingHelperTest.swift @@ -14,11 +14,12 @@ import NIOCore import NIOEmbedded -@testable import NIOExtras import NIOPosix import NIOTestUtils import XCTest +@testable import NIOExtras + private final class WaitForQuiesceUserEvent: ChannelInboundHandler { typealias InboundIn = Never private let promise: EventLoopPromise @@ -87,7 +88,7 @@ public class QuiescingHelperTest: XCTestCase { XCTAssertTrue(childChannels.allSatisfy { $0.isActive }) // now close all the child channels - childChannels.forEach { $0.close(promise: nil) } + for childChannel in childChannels { childChannel.close(promise: nil) } el.run() XCTAssertTrue(childChannels.allSatisfy { !$0.isActive }) diff --git a/Tests/NIOExtrasTests/RequestResponseHandlerTest.swift b/Tests/NIOExtrasTests/RequestResponseHandlerTest.swift index caf0f43f..20ed1bae 100644 --- a/Tests/NIOExtrasTests/RequestResponseHandlerTest.swift +++ b/Tests/NIOExtrasTests/RequestResponseHandlerTest.swift @@ -12,10 +12,10 @@ // //===----------------------------------------------------------------------===// -import XCTest import NIOCore import NIOEmbedded import NIOExtras +import XCTest class RequestResponseHandlerTest: XCTestCase { private var eventLoop: EmbeddedEventLoop! @@ -143,7 +143,6 @@ class RequestResponseHandlerTest: XCTestCase { // we'll also fire a second error through the pipeline that shouldn't do anything self.channel.pipeline.fireErrorCaught(DummyError2()) - // and just after the error, the response arrives too (but too late) XCTAssertNoThrow(try self.channel.writeInbound(())) diff --git a/Tests/NIOExtrasTests/RequestResponseWithIDHandlerTest.swift b/Tests/NIOExtrasTests/RequestResponseWithIDHandlerTest.swift index 51e6a359..0bec0506 100644 --- a/Tests/NIOExtrasTests/RequestResponseWithIDHandlerTest.swift +++ b/Tests/NIOExtrasTests/RequestResponseWithIDHandlerTest.swift @@ -12,10 +12,10 @@ // //===----------------------------------------------------------------------===// -import XCTest import NIOCore import NIOEmbedded import NIOExtras +import XCTest class RequestResponseWithIDHandlerTest: XCTestCase { private var eventLoop: EmbeddedEventLoop! @@ -41,7 +41,11 @@ class RequestResponseWithIDHandlerTest: XCTestCase { } func testSimpleRequestWorks() { - XCTAssertNoThrow(try self.channel.pipeline.addHandler(NIORequestResponseWithIDHandler, ValueWithRequestID>()).wait()) + XCTAssertNoThrow( + try self.channel.pipeline.addHandler( + NIORequestResponseWithIDHandler, ValueWithRequestID>() + ).wait() + ) self.buffer.writeString("hello") // pretend to connect to the EmbeddedChannel knows it's supposed to be active @@ -49,13 +53,21 @@ class RequestResponseWithIDHandlerTest: XCTestCase { let p: EventLoopPromise> = self.channel.eventLoop.makePromise() // write request - XCTAssertNoThrow(try self.channel.writeOutbound((ValueWithRequestID(requestID: 1, value: IOData.byteBuffer(self.buffer)), - p))) + XCTAssertNoThrow( + try self.channel.writeOutbound( + ( + ValueWithRequestID(requestID: 1, value: IOData.byteBuffer(self.buffer)), + p + ) + ) + ) // write response XCTAssertNoThrow(try self.channel.writeInbound(ValueWithRequestID(requestID: 1, value: "okay"))) // verify request was forwarded - XCTAssertEqual(ValueWithRequestID(requestID: 1, value: IOData.byteBuffer(self.buffer)), - try self.channel.readOutbound()) + XCTAssertEqual( + ValueWithRequestID(requestID: 1, value: IOData.byteBuffer(self.buffer)), + try self.channel.readOutbound() + ) // verify response was not forwarded XCTAssertEqual(nil, try self.channel.readInbound(as: ValueWithRequestID.self)) // verify the promise got succeeded with the response @@ -64,7 +76,11 @@ class RequestResponseWithIDHandlerTest: XCTestCase { func testEnqueingMultipleRequestsWorks() throws { struct DummyError: Error {} - XCTAssertNoThrow(try self.channel.pipeline.addHandler(NIORequestResponseWithIDHandler, ValueWithRequestID>()).wait()) + XCTAssertNoThrow( + try self.channel.pipeline.addHandler( + NIORequestResponseWithIDHandler, ValueWithRequestID>() + ).wait() + ) var futures: [EventLoopFuture>] = [] // pretend to connect to the EmbeddedChannel knows it's supposed to be active @@ -78,8 +94,16 @@ class RequestResponseWithIDHandlerTest: XCTestCase { futures.append(p.futureResult) // write request - XCTAssertNoThrow(try self.channel.writeOutbound((ValueWithRequestID(requestID: reqId, - value: IOData.byteBuffer(self.buffer)), p))) + XCTAssertNoThrow( + try self.channel.writeOutbound( + ( + ValueWithRequestID( + requestID: reqId, + value: IOData.byteBuffer(self.buffer) + ), p + ) + ) + ) } // let's have 3 successful responses @@ -99,8 +123,12 @@ class RequestResponseWithIDHandlerTest: XCTestCase { default: XCTFail("could not find request") } - XCTAssertNoThrow(XCTAssertEqual(ValueWithRequestID(requestID: reqIdExpected, value: reqIdExpected), - try futures[reqIdExpected].wait())) + XCTAssertNoThrow( + XCTAssertEqual( + ValueWithRequestID(requestID: reqIdExpected, value: reqIdExpected), + try futures[reqIdExpected].wait() + ) + ) } // validate the Channel is active @@ -122,13 +150,25 @@ class RequestResponseWithIDHandlerTest: XCTestCase { func testRequestsEnqueuedAfterErrorAreFailed() { struct DummyError: Error {} - XCTAssertNoThrow(try self.channel.pipeline.addHandler(NIORequestResponseWithIDHandler, ValueWithRequestID>()).wait()) + XCTAssertNoThrow( + try self.channel.pipeline.addHandler( + NIORequestResponseWithIDHandler, ValueWithRequestID>() + ).wait() + ) self.channel.pipeline.fireErrorCaught(DummyError()) let p: EventLoopPromise> = self.eventLoop.makePromise() - XCTAssertThrowsError(try self.channel.writeOutbound((ValueWithRequestID(requestID: 1, - value: IOData.byteBuffer(self.buffer)), p))) { error in + XCTAssertThrowsError( + try self.channel.writeOutbound( + ( + ValueWithRequestID( + requestID: 1, + value: IOData.byteBuffer(self.buffer) + ), p + ) + ) + ) { error in XCTAssertNotNil(error as? DummyError) } XCTAssertThrowsError(try p.futureResult.wait()) { error in @@ -140,12 +180,22 @@ class RequestResponseWithIDHandlerTest: XCTestCase { struct DummyError1: Error {} struct DummyError2: Error {} - XCTAssertNoThrow(try self.channel.pipeline.addHandler(NIORequestResponseWithIDHandler, ValueWithRequestID>()).wait()) + XCTAssertNoThrow( + try self.channel.pipeline.addHandler( + NIORequestResponseWithIDHandler, ValueWithRequestID>() + ).wait() + ) let p: EventLoopPromise> = self.eventLoop.makePromise() // right now, everything's still okay so the enqueued request won't immediately be failed - XCTAssertNoThrow(try self.channel.writeOutbound((ValueWithRequestID(requestID: 1, value: IOData.byteBuffer(self.buffer)), - p))) + XCTAssertNoThrow( + try self.channel.writeOutbound( + ( + ValueWithRequestID(requestID: 1, value: IOData.byteBuffer(self.buffer)), + p + ) + ) + ) // but whilst we're waiting for the response, an error turns up self.channel.pipeline.fireErrorCaught(DummyError1()) @@ -153,7 +203,6 @@ class RequestResponseWithIDHandlerTest: XCTestCase { // we'll also fire a second error through the pipeline that shouldn't do anything self.channel.pipeline.fireErrorCaught(DummyError2()) - // and just after the error, the response arrives too (but too late) XCTAssertNoThrow(try self.channel.writeInbound(())) @@ -163,7 +212,11 @@ class RequestResponseWithIDHandlerTest: XCTestCase { } func testClosedConnectionFailsOutstandingPromises() { - XCTAssertNoThrow(try self.channel.pipeline.addHandler(NIORequestResponseWithIDHandler, ValueWithRequestID>()).wait()) + XCTAssertNoThrow( + try self.channel.pipeline.addHandler( + NIORequestResponseWithIDHandler, ValueWithRequestID>() + ).wait() + ) let promise = self.eventLoop.makePromise(of: ValueWithRequestID.self) XCTAssertNoThrow(try self.channel.writeOutbound((ValueWithRequestID(requestID: 1, value: "Hello!"), promise))) @@ -175,7 +228,11 @@ class RequestResponseWithIDHandlerTest: XCTestCase { } func testOutOfOrderResponsesWork() { - XCTAssertNoThrow(try self.channel.pipeline.addHandler(NIORequestResponseWithIDHandler, ValueWithRequestID>()).wait()) + XCTAssertNoThrow( + try self.channel.pipeline.addHandler( + NIORequestResponseWithIDHandler, ValueWithRequestID>() + ).wait() + ) self.buffer.writeString("hello") // pretend to connect to the EmbeddedChannel knows it's supposed to be active @@ -201,7 +258,11 @@ class RequestResponseWithIDHandlerTest: XCTestCase { } func testErrorOnResponseForNonExistantRequest() { - XCTAssertNoThrow(try self.channel.pipeline.addHandler(NIORequestResponseWithIDHandler, ValueWithRequestID>()).wait()) + XCTAssertNoThrow( + try self.channel.pipeline.addHandler( + NIORequestResponseWithIDHandler, ValueWithRequestID>() + ).wait() + ) self.buffer.writeString("hello") // pretend to connect to the EmbeddedChannel knows it's supposed to be active @@ -212,7 +273,8 @@ class RequestResponseWithIDHandlerTest: XCTestCase { // write request XCTAssertNoThrow(try self.channel.writeOutbound((ValueWithRequestID(requestID: 1, value: "1"), p1))) // write wrong response - XCTAssertThrowsError(try self.channel.writeInbound(ValueWithRequestID(requestID: 2, value: "okay 2"))) { error in + XCTAssertThrowsError(try self.channel.writeInbound(ValueWithRequestID(requestID: 2, value: "okay 2"))) { + error in guard let error = error as? NIOExtrasErrors.ResponseForInvalidRequest> else { XCTFail("wrong error") return @@ -236,10 +298,15 @@ class RequestResponseWithIDHandlerTest: XCTestCase { func channelInactive(context: ChannelHandlerContext) { let responsePromise = context.eventLoop.makePromise(of: ValueWithRequestID.self) let writePromise = context.eventLoop.makePromise(of: Void.self) - context.writeAndFlush(self.wrapOutboundOut( - (ValueWithRequestID(requestID: 1, value: IOData.byteBuffer(ByteBuffer(string: "hi"))), responsePromise) + context.writeAndFlush( + self.wrapOutboundOut( + ( + ValueWithRequestID(requestID: 1, value: IOData.byteBuffer(ByteBuffer(string: "hi"))), + responsePromise + ) ), - promise: writePromise) + promise: writePromise + ) var writePromiseCompleted = false defer { XCTAssertTrue(writePromiseCompleted) @@ -269,10 +336,12 @@ class RequestResponseWithIDHandlerTest: XCTestCase { } } - XCTAssertNoThrow(try self.channel.pipeline.addHandlers( - NIORequestResponseWithIDHandler, ValueWithRequestID>(), - EmitRequestOnInactiveHandler() - ).wait()) + XCTAssertNoThrow( + try self.channel.pipeline.addHandlers( + NIORequestResponseWithIDHandler, ValueWithRequestID>(), + EmitRequestOnInactiveHandler() + ).wait() + ) self.buffer.writeString("hello") // pretend to connect to the EmbeddedChannel knows it's supposed to be active @@ -280,14 +349,22 @@ class RequestResponseWithIDHandlerTest: XCTestCase { let p: EventLoopPromise> = self.channel.eventLoop.makePromise() // write request - XCTAssertNoThrow(try self.channel.writeOutbound((ValueWithRequestID(requestID: 1, value: IOData.byteBuffer(self.buffer)), - p))) + XCTAssertNoThrow( + try self.channel.writeOutbound( + ( + ValueWithRequestID(requestID: 1, value: IOData.byteBuffer(self.buffer)), + p + ) + ) + ) // write response XCTAssertNoThrow(try self.channel.writeInbound(ValueWithRequestID(requestID: 1, value: "okay"))) // verify request was forwarded - XCTAssertEqual(ValueWithRequestID(requestID: 1, value: IOData.byteBuffer(self.buffer)), - try self.channel.readOutbound()) + XCTAssertEqual( + ValueWithRequestID(requestID: 1, value: IOData.byteBuffer(self.buffer)), + try self.channel.readOutbound() + ) // verify the promise got succeeded with the response XCTAssertEqual(ValueWithRequestID(requestID: 1, value: "okay"), try p.futureResult.wait()) diff --git a/Tests/NIOExtrasTests/SynchronizedFileSinkTests.swift b/Tests/NIOExtrasTests/SynchronizedFileSinkTests.swift index d42395ae..9a9ce277 100644 --- a/Tests/NIOExtrasTests/SynchronizedFileSinkTests.swift +++ b/Tests/NIOExtrasTests/SynchronizedFileSinkTests.swift @@ -13,23 +13,29 @@ //===----------------------------------------------------------------------===// import Foundation -import XCTest - import NIOCore import NIOEmbedded +import XCTest + @testable import NIOExtras final class SynchronizedFileSinkTests: XCTestCase { func testSimpleFileSink() throws { try withTemporaryFile { file, path in - let sink = try NIOWritePCAPHandler.SynchronizedFileSink.fileSinkWritingToFile(path: path, errorHandler: { XCTFail("Caught error \($0)") }) + let sink = try NIOWritePCAPHandler.SynchronizedFileSink.fileSinkWritingToFile( + path: path, + errorHandler: { XCTFail("Caught error \($0)") } + ) sink.write(buffer: ByteBuffer(string: "Hello, ")) sink.write(buffer: ByteBuffer(string: "world!")) try sink.syncClose() let data = try Data(contentsOf: URL(fileURLWithPath: path)) - XCTAssertEqual(data, Data(NIOWritePCAPHandler.pcapFileHeader.readableBytesView) + Data("Hello, world!".utf8)) + XCTAssertEqual( + data, + Data(NIOWritePCAPHandler.pcapFileHeader.readableBytesView) + Data("Hello, world!".utf8) + ) } } @@ -37,20 +43,29 @@ final class SynchronizedFileSinkTests: XCTestCase { guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { try await withTemporaryFile { file, path in - let sink = try NIOWritePCAPHandler.SynchronizedFileSink.fileSinkWritingToFile(path: path, errorHandler: { XCTFail("Caught error \($0)") }) + let sink = try NIOWritePCAPHandler.SynchronizedFileSink.fileSinkWritingToFile( + path: path, + errorHandler: { XCTFail("Caught error \($0)") } + ) sink.write(buffer: ByteBuffer(string: "Hello, ")) sink.write(buffer: ByteBuffer(string: "world!")) try await sink.close() let data = try Data(contentsOf: URL(fileURLWithPath: path)) - XCTAssertEqual(data, Data(NIOWritePCAPHandler.pcapFileHeader.readableBytesView) + Data("Hello, world!".utf8)) + XCTAssertEqual( + data, + Data(NIOWritePCAPHandler.pcapFileHeader.readableBytesView) + Data("Hello, world!".utf8) + ) } } } } -fileprivate func withTemporaryFile(content: String? = nil, _ body: (NIOCore.NIOFileHandle, String) throws -> T) throws -> T { +private func withTemporaryFile( + content: String? = nil, + _ body: (NIOCore.NIOFileHandle, String) throws -> T +) throws -> T { let temporaryFilePath = "\(temporaryDirectory)/nio_extras_\(UUID())" FileManager.default.createFile(atPath: temporaryFilePath, contents: content?.data(using: .utf8)) defer { @@ -66,7 +81,10 @@ fileprivate func withTemporaryFile(content: String? = nil, _ body: (NIOCore.N } @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -fileprivate func withTemporaryFile(content: String? = nil, _ body: (NIOCore.NIOFileHandle, String) async throws -> T) async throws -> T { +private func withTemporaryFile( + content: String? = nil, + _ body: (NIOCore.NIOFileHandle, String) async throws -> T +) async throws -> T { let temporaryFilePath = "\(temporaryDirectory)/nio_extras_\(UUID())" FileManager.default.createFile(atPath: temporaryFilePath, contents: content?.data(using: .utf8)) defer { @@ -81,16 +99,16 @@ fileprivate func withTemporaryFile(content: String? = nil, _ body: (NIOCore.N return try await body(fileHandle, temporaryFilePath) } -fileprivate var temporaryDirectory: String { -#if os(Linux) +private var temporaryDirectory: String { + #if os(Linux) return "/tmp" -#else + #else if #available(macOS 10.12, iOS 10, tvOS 10, watchOS 3, *) { return FileManager.default.temporaryDirectory.path } else { return "/tmp" } -#endif // os + #endif // os } extension XCTestCase { @@ -117,7 +135,7 @@ extension XCTestCase { try await operation() } catch { XCTFail("Error thrown while executing \(function): \(error)", file: file, line: line) - Thread.callStackSymbols.forEach { print($0) } + for callStack in Thread.callStackSymbols { print(callStack) } } expectation.fulfill() } diff --git a/Tests/NIOExtrasTests/WritePCAPHandlerTest.swift b/Tests/NIOExtrasTests/WritePCAPHandlerTest.swift index ab873841..7d2173c1 100644 --- a/Tests/NIOExtrasTests/WritePCAPHandlerTest.swift +++ b/Tests/NIOExtrasTests/WritePCAPHandlerTest.swift @@ -13,10 +13,10 @@ //===----------------------------------------------------------------------===// import Foundation -import XCTest - import NIOCore import NIOEmbedded +import XCTest + @testable import NIOExtras class WritePCAPHandlerTest: XCTestCase { @@ -24,67 +24,85 @@ class WritePCAPHandlerTest: XCTestCase { private var channel: EmbeddedChannel! private var scratchBuffer: ByteBuffer! private var testAddressA: SocketAddress.IPv6Address! - + private var _mode: NIOWritePCAPHandler.Mode = .client var mode: NIOWritePCAPHandler.Mode { get { - return self._mode + self._mode } set { - self.channel = EmbeddedChannel(handler: NIOWritePCAPHandler(mode: newValue, - fakeLocalAddress: nil, - fakeRemoteAddress: nil, - fileSink: { - self.accumulatedPackets.append($0) - })) + self.channel = EmbeddedChannel( + handler: NIOWritePCAPHandler( + mode: newValue, + fakeLocalAddress: nil, + fakeRemoteAddress: nil, + fileSink: { + self.accumulatedPackets.append($0) + } + ) + ) self._mode = newValue } } - + override func setUp() { self.accumulatedPackets = [] self.channel = EmbeddedChannel() - XCTAssertNoThrow(try self.channel.pipeline.addHandler(NIOWritePCAPHandler(mode: .client, - fakeLocalAddress: nil, - fakeRemoteAddress: nil, - fileSink: { - self.accumulatedPackets.append($0) - }), name: "NIOWritePCAPHandler").wait()) + XCTAssertNoThrow( + try self.channel.pipeline.addHandler( + NIOWritePCAPHandler( + mode: .client, + fakeLocalAddress: nil, + fakeRemoteAddress: nil, + fileSink: { + self.accumulatedPackets.append($0) + } + ), + name: "NIOWritePCAPHandler" + ).wait() + ) self.scratchBuffer = self.channel.allocator.buffer(capacity: 128) } - + override func tearDown() { self.accumulatedPackets = nil self.channel = nil self.scratchBuffer = nil } - - func assertEqual(expectedAddress: SocketAddress?, - actualIPv4Address: in_addr, - actualPort: UInt16, - file: StaticString = #filePath, - line: UInt = #line) { + + func assertEqual( + expectedAddress: SocketAddress?, + actualIPv4Address: in_addr, + actualPort: UInt16, + file: StaticString = #filePath, + line: UInt = #line + ) { guard let port = expectedAddress?.port else { XCTFail("expected address nil or has no port", file: (file), line: line) return } switch expectedAddress { case .some(.v4(let expectedAddress)): - XCTAssertEqual(expectedAddress.address.sin_addr.s_addr, - actualIPv4Address.s_addr, - "IP addresses don't match", - file: (file), line: line) + XCTAssertEqual( + expectedAddress.address.sin_addr.s_addr, + actualIPv4Address.s_addr, + "IP addresses don't match", + file: (file), + line: line + ) XCTAssertEqual(port, Int(actualPort), "ports don't match", file: (file), line: line) default: XCTFail("expected address not an IPv4 address", file: (file), line: line) } } - func assertEqual(expectedAddress: SocketAddress?, - actualIPv6Address: in6_addr, - actualPort: UInt16, - file: StaticString = #filePath, - line: UInt = #line) { + func assertEqual( + expectedAddress: SocketAddress?, + actualIPv6Address: in6_addr, + actualPort: UInt16, + file: StaticString = #filePath, + line: UInt = #line + ) { guard let port = expectedAddress?.port else { XCTFail("expected address nil or has no port", file: (file), line: line) return @@ -109,52 +127,66 @@ class WritePCAPHandlerTest: XCTestCase { self.channel.localAddress = try! SocketAddress(ipAddress: "255.255.255.254", port: Int(UInt16.max) - 1) XCTAssertNoThrow(try self.channel.connect(to: .init(ipAddress: "1.2.3.4", port: 5678)).wait()) XCTAssertNoThrow(try self.channel.throwIfErrorCaught()) - XCTAssertEqual(1, self.accumulatedPackets.count) // the WritePCAPHandler will batch all into one write + XCTAssertEqual(1, self.accumulatedPackets.count) // the WritePCAPHandler will batch all into one write var buffer = self.accumulatedPackets.first - let records = [buffer?.readPCAPRecord() /* SYN */, - buffer?.readPCAPRecord() /* SYN+ACK */, - buffer?.readPCAPRecord() /* ACK */] + let records = [ + buffer?.readPCAPRecord(), // SYN + buffer?.readPCAPRecord(), // SYN+ACK + buffer?.readPCAPRecord(), // ACK + ] var ipPackets: [TCPIPv4Packet] = [] for var record in records { - XCTAssertNotNil(record) // we must have been able to parse a record - XCTAssertGreaterThan(record?.payload.readableBytes ?? -1, 0) // there must be some TCP/IP packet in there - XCTAssertEqual(2, record?.pcapProtocolID) // 2 is IPv4 + XCTAssertNotNil(record) // we must have been able to parse a record + XCTAssertGreaterThan(record?.payload.readableBytes ?? -1, 0) // there must be some TCP/IP packet in there + XCTAssertEqual(2, record?.pcapProtocolID) // 2 is IPv4 if let ipPacket = try record?.payload.readTCPIPv4() { ipPackets.append(ipPacket) XCTAssertEqual(0, ipPacket.tcpPayload.readableBytes) - XCTAssertEqual(40, ipPacket.wholeIPPacketLength) // in IPv4 it's payload + IP + TCP header + XCTAssertEqual(40, ipPacket.wholeIPPacketLength) // in IPv4 it's payload + IP + TCP header } } XCTAssertEqual(3, ipPackets.count) - + // SYN, local should be source, remote is destination - self.assertEqual(expectedAddress: self.channel?.localAddress, - actualIPv4Address: ipPackets[0].src, - actualPort: ipPackets[0].tcpHeader.srcPort) - self.assertEqual(expectedAddress: self.channel?.remoteAddress, - actualIPv4Address: ipPackets[0].dst, - actualPort: ipPackets[0].tcpHeader.dstPort) + self.assertEqual( + expectedAddress: self.channel?.localAddress, + actualIPv4Address: ipPackets[0].src, + actualPort: ipPackets[0].tcpHeader.srcPort + ) + self.assertEqual( + expectedAddress: self.channel?.remoteAddress, + actualIPv4Address: ipPackets[0].dst, + actualPort: ipPackets[0].tcpHeader.dstPort + ) XCTAssertEqual([.syn], ipPackets[0].tcpHeader.flags) - + // SYN+ACK, local should be destination, remote should be source - self.assertEqual(expectedAddress: self.channel?.remoteAddress, - actualIPv4Address: ipPackets[1].src, - actualPort: ipPackets[1].tcpHeader.srcPort) - self.assertEqual(expectedAddress: self.channel?.localAddress, - actualIPv4Address: ipPackets[1].dst, - actualPort: ipPackets[1].tcpHeader.dstPort) + self.assertEqual( + expectedAddress: self.channel?.remoteAddress, + actualIPv4Address: ipPackets[1].src, + actualPort: ipPackets[1].tcpHeader.srcPort + ) + self.assertEqual( + expectedAddress: self.channel?.localAddress, + actualIPv4Address: ipPackets[1].dst, + actualPort: ipPackets[1].tcpHeader.dstPort + ) XCTAssertEqual([.syn, .ack], ipPackets[1].tcpHeader.flags) - + // ACK - self.assertEqual(expectedAddress: self.channel?.localAddress, - actualIPv4Address: ipPackets[0].src, - actualPort: ipPackets[0].tcpHeader.srcPort) - self.assertEqual(expectedAddress: self.channel?.remoteAddress, - actualIPv4Address: ipPackets[0].dst, - actualPort: ipPackets[0].tcpHeader.dstPort) + self.assertEqual( + expectedAddress: self.channel?.localAddress, + actualIPv4Address: ipPackets[0].src, + actualPort: ipPackets[0].tcpHeader.srcPort + ) + self.assertEqual( + expectedAddress: self.channel?.remoteAddress, + actualIPv4Address: ipPackets[0].dst, + actualPort: ipPackets[0].tcpHeader.dstPort + ) XCTAssertEqual([.ack], ipPackets[2].tcpHeader.flags) - XCTAssertEqual(0, buffer?.readableBytes) // there shouldn't be anything else left + XCTAssertEqual(0, buffer?.readableBytes) // there shouldn't be anything else left } func testConnectIssuesThreePacketsForIPv6() throws { @@ -162,216 +194,272 @@ class WritePCAPHandlerTest: XCTestCase { self.channel.localAddress = try! SocketAddress(ipAddress: "1:2:3:4:5:6:7:8", port: Int(UInt16.max) - 1) XCTAssertNoThrow(try self.channel.connect(to: .init(ipAddress: "::1", port: 5678)).wait()) XCTAssertNoThrow(try self.channel.throwIfErrorCaught()) - XCTAssertEqual(1, self.accumulatedPackets.count) // the WritePCAPHandler will batch all into one write + XCTAssertEqual(1, self.accumulatedPackets.count) // the WritePCAPHandler will batch all into one write var buffer = self.accumulatedPackets.first - let records = [buffer?.readPCAPRecord() /* SYN */, - buffer?.readPCAPRecord() /* SYN+ACK */, - buffer?.readPCAPRecord() /* ACK */] + let records = [ + buffer?.readPCAPRecord(), // SYN + buffer?.readPCAPRecord(), // SYN+ACK + buffer?.readPCAPRecord(), // ACK + ] var ipPackets: [TCPIPv6Packet] = [] for var record in records { - XCTAssertNotNil(record) // we must have been able to parse a record - XCTAssertGreaterThan(record?.payload.readableBytes ?? -1, 0) // there must be some TCP/IP packet in there - XCTAssertEqual(24, record?.pcapProtocolID) // 24 is IPv6 + XCTAssertNotNil(record) // we must have been able to parse a record + XCTAssertGreaterThan(record?.payload.readableBytes ?? -1, 0) // there must be some TCP/IP packet in there + XCTAssertEqual(24, record?.pcapProtocolID) // 24 is IPv6 if let ipPacket = try record?.payload.readTCPIPv6() { ipPackets.append(ipPacket) XCTAssertEqual(0, ipPacket.tcpPayload.readableBytes) - XCTAssertEqual(20, ipPacket.payloadLength) // in IPv6 it's just the payload, ie. payload + TCP header + XCTAssertEqual(20, ipPacket.payloadLength) // in IPv6 it's just the payload, ie. payload + TCP header } } XCTAssertEqual(3, ipPackets.count) - + // SYN, local should be source, remote is destination - self.assertEqual(expectedAddress: self.channel?.localAddress, - actualIPv6Address: ipPackets[0].src, - actualPort: ipPackets[0].tcpHeader.srcPort) - self.assertEqual(expectedAddress: self.channel?.remoteAddress, - actualIPv6Address: ipPackets[0].dst, - actualPort: ipPackets[0].tcpHeader.dstPort) + self.assertEqual( + expectedAddress: self.channel?.localAddress, + actualIPv6Address: ipPackets[0].src, + actualPort: ipPackets[0].tcpHeader.srcPort + ) + self.assertEqual( + expectedAddress: self.channel?.remoteAddress, + actualIPv6Address: ipPackets[0].dst, + actualPort: ipPackets[0].tcpHeader.dstPort + ) XCTAssertEqual([.syn], ipPackets[0].tcpHeader.flags) - + // SYN+ACK, local should be destination, remote should be source - self.assertEqual(expectedAddress: self.channel?.remoteAddress, - actualIPv6Address: ipPackets[1].src, - actualPort: ipPackets[1].tcpHeader.srcPort) - self.assertEqual(expectedAddress: self.channel?.localAddress, - actualIPv6Address: ipPackets[1].dst, - actualPort: ipPackets[1].tcpHeader.dstPort) + self.assertEqual( + expectedAddress: self.channel?.remoteAddress, + actualIPv6Address: ipPackets[1].src, + actualPort: ipPackets[1].tcpHeader.srcPort + ) + self.assertEqual( + expectedAddress: self.channel?.localAddress, + actualIPv6Address: ipPackets[1].dst, + actualPort: ipPackets[1].tcpHeader.dstPort + ) XCTAssertEqual([.syn, .ack], ipPackets[1].tcpHeader.flags) - + // ACK - self.assertEqual(expectedAddress: self.channel?.localAddress, - actualIPv6Address: ipPackets[0].src, - actualPort: ipPackets[0].tcpHeader.srcPort) - self.assertEqual(expectedAddress: self.channel?.remoteAddress, - actualIPv6Address: ipPackets[0].dst, - actualPort: ipPackets[0].tcpHeader.dstPort) + self.assertEqual( + expectedAddress: self.channel?.localAddress, + actualIPv6Address: ipPackets[0].src, + actualPort: ipPackets[0].tcpHeader.srcPort + ) + self.assertEqual( + expectedAddress: self.channel?.remoteAddress, + actualIPv6Address: ipPackets[0].dst, + actualPort: ipPackets[0].tcpHeader.dstPort + ) XCTAssertEqual([.ack], ipPackets[2].tcpHeader.flags) - - XCTAssertEqual(0, buffer?.readableBytes) // there shouldn't be anything else left + + XCTAssertEqual(0, buffer?.readableBytes) // there shouldn't be anything else left } - + func testAcceptConnectionFromRemote() throws { self.mode = .server - + XCTAssertEqual([], self.accumulatedPackets) self.channel.remoteAddress = try! SocketAddress(ipAddress: "1.2.3.4", port: 5678) self.channel.localAddress = try! SocketAddress(ipAddress: "255.255.255.254", port: Int(UInt16.max) - 1) channel.pipeline.fireChannelActive() XCTAssertNoThrow(try self.channel.throwIfErrorCaught()) - XCTAssertEqual(1, self.accumulatedPackets.count) // the WritePCAPHandler will batch all into one write + XCTAssertEqual(1, self.accumulatedPackets.count) // the WritePCAPHandler will batch all into one write var buffer = self.accumulatedPackets.first - let records = [buffer?.readPCAPRecord() /* SYN */, - buffer?.readPCAPRecord() /* SYN+ACK */, - buffer?.readPCAPRecord() /* ACK */] + let records = [ + buffer?.readPCAPRecord(), // SYN + buffer?.readPCAPRecord(), // SYN+ACK + buffer?.readPCAPRecord(), // ACK + ] var ipPackets: [TCPIPv4Packet] = [] for var record in records { - XCTAssertNotNil(record) // we must have been able to parse a record - XCTAssertGreaterThan(record?.payload.readableBytes ?? -1, 0) // there must be some TCP/IP packet in there - XCTAssertEqual(2, record?.pcapProtocolID) // 2 is IPv4 + XCTAssertNotNil(record) // we must have been able to parse a record + XCTAssertGreaterThan(record?.payload.readableBytes ?? -1, 0) // there must be some TCP/IP packet in there + XCTAssertEqual(2, record?.pcapProtocolID) // 2 is IPv4 if let ipPacket = try record?.payload.readTCPIPv4() { ipPackets.append(ipPacket) XCTAssertEqual(0, ipPacket.tcpPayload.readableBytes) } } XCTAssertEqual(3, ipPackets.count) - + // SYN, local should be dst, remote is src - self.assertEqual(expectedAddress: self.channel?.remoteAddress, - actualIPv4Address: ipPackets[0].src, - actualPort: ipPackets[0].tcpHeader.srcPort) - self.assertEqual(expectedAddress: self.channel?.localAddress, - actualIPv4Address: ipPackets[0].dst, - actualPort: ipPackets[0].tcpHeader.dstPort) + self.assertEqual( + expectedAddress: self.channel?.remoteAddress, + actualIPv4Address: ipPackets[0].src, + actualPort: ipPackets[0].tcpHeader.srcPort + ) + self.assertEqual( + expectedAddress: self.channel?.localAddress, + actualIPv4Address: ipPackets[0].dst, + actualPort: ipPackets[0].tcpHeader.dstPort + ) XCTAssertEqual([.syn], ipPackets[0].tcpHeader.flags) - + // SYN+ACK, local should be src, remote should be dst - self.assertEqual(expectedAddress: self.channel?.localAddress, - actualIPv4Address: ipPackets[1].src, - actualPort: ipPackets[1].tcpHeader.srcPort) - self.assertEqual(expectedAddress: self.channel?.remoteAddress, - actualIPv4Address: ipPackets[1].dst, - actualPort: ipPackets[1].tcpHeader.dstPort) + self.assertEqual( + expectedAddress: self.channel?.localAddress, + actualIPv4Address: ipPackets[1].src, + actualPort: ipPackets[1].tcpHeader.srcPort + ) + self.assertEqual( + expectedAddress: self.channel?.remoteAddress, + actualIPv4Address: ipPackets[1].dst, + actualPort: ipPackets[1].tcpHeader.dstPort + ) XCTAssertEqual([.syn, .ack], ipPackets[1].tcpHeader.flags) - + // ACK - self.assertEqual(expectedAddress: self.channel?.remoteAddress, - actualIPv4Address: ipPackets[0].src, - actualPort: ipPackets[0].tcpHeader.srcPort) - self.assertEqual(expectedAddress: self.channel?.localAddress, - actualIPv4Address: ipPackets[0].dst, - actualPort: ipPackets[0].tcpHeader.dstPort) + self.assertEqual( + expectedAddress: self.channel?.remoteAddress, + actualIPv4Address: ipPackets[0].src, + actualPort: ipPackets[0].tcpHeader.srcPort + ) + self.assertEqual( + expectedAddress: self.channel?.localAddress, + actualIPv4Address: ipPackets[0].dst, + actualPort: ipPackets[0].tcpHeader.dstPort + ) XCTAssertEqual([.ack], ipPackets[2].tcpHeader.flags) - - XCTAssertEqual(0, buffer?.readableBytes) // there shouldn't be anything else left + + XCTAssertEqual(0, buffer?.readableBytes) // there shouldn't be anything else left } - + func testCloseOriginatingFromLocal() throws { self.channel.localAddress = try! SocketAddress(ipAddress: "1.1.1.1", port: 1) self.channel.remoteAddress = try! SocketAddress(ipAddress: "2.2.2.2", port: 2) XCTAssertNoThrow(try self.channel.close().wait()) - - XCTAssertEqual(1, self.accumulatedPackets.count) // we're batching again. - + + XCTAssertEqual(1, self.accumulatedPackets.count) // we're batching again. + var buffer = self.accumulatedPackets.first - let records = [buffer?.readPCAPRecord() /* FIN */, - buffer?.readPCAPRecord() /* FIN+ACK */, - buffer?.readPCAPRecord() /* ACK */] - XCTAssertEqual(0, buffer?.readableBytes) // nothing left + let records = [ + buffer?.readPCAPRecord(), // FIN + buffer?.readPCAPRecord(), // FIN+ACK + buffer?.readPCAPRecord(), // ACK + ] + XCTAssertEqual(0, buffer?.readableBytes) // nothing left var ipPackets: [TCPIPv4Packet] = [] for var record in records { - XCTAssertNotNil(record) // we must have been able to parse a record - XCTAssertGreaterThan(record?.payload.readableBytes ?? -1, 0) // there must be some TCP/IP packet in there + XCTAssertNotNil(record) // we must have been able to parse a record + XCTAssertGreaterThan(record?.payload.readableBytes ?? -1, 0) // there must be some TCP/IP packet in there if let ipPacket = try record?.payload.readTCPIPv4() { ipPackets.append(ipPacket) XCTAssertEqual(0, ipPacket.tcpPayload.readableBytes) } } - + // FIN, local should be source, remote is destination - self.assertEqual(expectedAddress: self.channel?.localAddress, - actualIPv4Address: ipPackets[0].src, - actualPort: ipPackets[0].tcpHeader.srcPort) - self.assertEqual(expectedAddress: self.channel?.remoteAddress, - actualIPv4Address: ipPackets[0].dst, - actualPort: ipPackets[0].tcpHeader.dstPort) + self.assertEqual( + expectedAddress: self.channel?.localAddress, + actualIPv4Address: ipPackets[0].src, + actualPort: ipPackets[0].tcpHeader.srcPort + ) + self.assertEqual( + expectedAddress: self.channel?.remoteAddress, + actualIPv4Address: ipPackets[0].dst, + actualPort: ipPackets[0].tcpHeader.dstPort + ) XCTAssertEqual([.fin], ipPackets[0].tcpHeader.flags) - + // FIN+ACK, local should be destination, remote should be source - self.assertEqual(expectedAddress: self.channel?.remoteAddress, - actualIPv4Address: ipPackets[1].src, - actualPort: ipPackets[1].tcpHeader.srcPort) - self.assertEqual(expectedAddress: self.channel?.localAddress, - actualIPv4Address: ipPackets[1].dst, - actualPort: ipPackets[1].tcpHeader.dstPort) + self.assertEqual( + expectedAddress: self.channel?.remoteAddress, + actualIPv4Address: ipPackets[1].src, + actualPort: ipPackets[1].tcpHeader.srcPort + ) + self.assertEqual( + expectedAddress: self.channel?.localAddress, + actualIPv4Address: ipPackets[1].dst, + actualPort: ipPackets[1].tcpHeader.dstPort + ) XCTAssertEqual([.fin, .ack], ipPackets[1].tcpHeader.flags) - + // ACK - self.assertEqual(expectedAddress: self.channel?.localAddress, - actualIPv4Address: ipPackets[0].src, - actualPort: ipPackets[0].tcpHeader.srcPort) - self.assertEqual(expectedAddress: self.channel?.remoteAddress, - actualIPv4Address: ipPackets[0].dst, - actualPort: ipPackets[0].tcpHeader.dstPort) + self.assertEqual( + expectedAddress: self.channel?.localAddress, + actualIPv4Address: ipPackets[0].src, + actualPort: ipPackets[0].tcpHeader.srcPort + ) + self.assertEqual( + expectedAddress: self.channel?.remoteAddress, + actualIPv4Address: ipPackets[0].dst, + actualPort: ipPackets[0].tcpHeader.dstPort + ) XCTAssertEqual([.ack], ipPackets[2].tcpHeader.flags) } - + func testCloseOriginatingFromRemote() throws { self.channel.localAddress = try! SocketAddress(ipAddress: "1.1.1.1", port: 1) self.channel.remoteAddress = try! SocketAddress(ipAddress: "2.2.2.2", port: 2) self.channel.pipeline.fireChannelInactive() - - XCTAssertEqual(1, self.accumulatedPackets.count) // we're batching again. - + + XCTAssertEqual(1, self.accumulatedPackets.count) // we're batching again. + var buffer = self.accumulatedPackets.first - let records = [buffer?.readPCAPRecord() /* FIN */, - buffer?.readPCAPRecord() /* FIN+ACK */, - buffer?.readPCAPRecord() /* ACK */] - XCTAssertEqual(0, buffer?.readableBytes) // nothing left + let records = [ + buffer?.readPCAPRecord(), // FIN + buffer?.readPCAPRecord(), // FIN+ACK + buffer?.readPCAPRecord(), // ACK + ] + XCTAssertEqual(0, buffer?.readableBytes) // nothing left var ipPackets: [TCPIPv4Packet] = [] for var record in records { - XCTAssertNotNil(record) // we must have been able to parse a record - XCTAssertGreaterThan(record?.payload.readableBytes ?? -1, 0) // there must be some TCP/IP packet in there + XCTAssertNotNil(record) // we must have been able to parse a record + XCTAssertGreaterThan(record?.payload.readableBytes ?? -1, 0) // there must be some TCP/IP packet in there if let ipPacket = try record?.payload.readTCPIPv4() { ipPackets.append(ipPacket) XCTAssertEqual(0, ipPacket.tcpPayload.readableBytes) } } - + // FIN, local should be dst, remote is src - self.assertEqual(expectedAddress: self.channel?.remoteAddress, - actualIPv4Address: ipPackets[0].src, - actualPort: ipPackets[0].tcpHeader.srcPort) - self.assertEqual(expectedAddress: self.channel?.localAddress, - actualIPv4Address: ipPackets[0].dst, - actualPort: ipPackets[0].tcpHeader.dstPort) + self.assertEqual( + expectedAddress: self.channel?.remoteAddress, + actualIPv4Address: ipPackets[0].src, + actualPort: ipPackets[0].tcpHeader.srcPort + ) + self.assertEqual( + expectedAddress: self.channel?.localAddress, + actualIPv4Address: ipPackets[0].dst, + actualPort: ipPackets[0].tcpHeader.dstPort + ) XCTAssertEqual([.fin], ipPackets[0].tcpHeader.flags) - + // FIN+ACK, local should be src, remote should be dst - self.assertEqual(expectedAddress: self.channel?.localAddress, - actualIPv4Address: ipPackets[1].src, - actualPort: ipPackets[1].tcpHeader.srcPort) - self.assertEqual(expectedAddress: self.channel?.remoteAddress, - actualIPv4Address: ipPackets[1].dst, - actualPort: ipPackets[1].tcpHeader.dstPort) + self.assertEqual( + expectedAddress: self.channel?.localAddress, + actualIPv4Address: ipPackets[1].src, + actualPort: ipPackets[1].tcpHeader.srcPort + ) + self.assertEqual( + expectedAddress: self.channel?.remoteAddress, + actualIPv4Address: ipPackets[1].dst, + actualPort: ipPackets[1].tcpHeader.dstPort + ) XCTAssertEqual([.fin, .ack], ipPackets[1].tcpHeader.flags) - + // ACK - self.assertEqual(expectedAddress: self.channel?.remoteAddress, - actualIPv4Address: ipPackets[0].src, - actualPort: ipPackets[0].tcpHeader.srcPort) - self.assertEqual(expectedAddress: self.channel?.localAddress, - actualIPv4Address: ipPackets[0].dst, - actualPort: ipPackets[0].tcpHeader.dstPort) + self.assertEqual( + expectedAddress: self.channel?.remoteAddress, + actualIPv4Address: ipPackets[0].src, + actualPort: ipPackets[0].tcpHeader.srcPort + ) + self.assertEqual( + expectedAddress: self.channel?.localAddress, + actualIPv4Address: ipPackets[0].dst, + actualPort: ipPackets[0].tcpHeader.dstPort + ) XCTAssertEqual([.ack], ipPackets[2].tcpHeader.flags) } - + func testInboundData() throws { self.channel.localAddress = try! SocketAddress(ipAddress: "1.2.3.4", port: 1111) self.channel.remoteAddress = try! SocketAddress(ipAddress: "9.8.7.6", port: 2222) self.scratchBuffer.writeStaticString("hello") XCTAssertNoThrow(try self.channel.writeInbound(self.scratchBuffer)) XCTAssertEqual(1, self.accumulatedPackets.count) - + guard var packetBytes = self.accumulatedPackets.first else { XCTFail("couldn't read bytes of first packet") return @@ -380,7 +468,7 @@ class WritePCAPHandlerTest: XCTestCase { XCTFail("couldn't read payload from PCAP record") return } - XCTAssertEqual(0, packetBytes.readableBytes) // check nothing is left over + XCTAssertEqual(0, packetBytes.readableBytes) // check nothing is left over guard let tcpIPPacket = try payload.payload.readTCPIPv4() else { XCTFail("couldn't read TCP/IPv4 packet") return @@ -389,14 +477,14 @@ class WritePCAPHandlerTest: XCTestCase { XCTAssertEqual(2222, tcpIPPacket.tcpHeader.srcPort) XCTAssertEqual("hello", String(decoding: tcpIPPacket.tcpPayload.readableBytesView, as: Unicode.UTF8.self)) } - + func testOutboundData() throws { self.channel.localAddress = try! SocketAddress(ipAddress: "1.2.3.4", port: 1111) self.channel.remoteAddress = try! SocketAddress(ipAddress: "9.8.7.6", port: 2222) self.scratchBuffer.writeStaticString("hello") XCTAssertNoThrow(try self.channel.writeOutbound(self.scratchBuffer)) XCTAssertEqual(1, self.accumulatedPackets.count) - + guard var packetBytes = self.accumulatedPackets.first else { XCTFail("couldn't read bytes of first packet") return @@ -405,7 +493,7 @@ class WritePCAPHandlerTest: XCTestCase { XCTFail("couldn't read payload from PCAP record") return } - XCTAssertEqual(0, packetBytes.readableBytes) // check nothing is left over + XCTAssertEqual(0, packetBytes.readableBytes) // check nothing is left over guard let tcpIPPacket = try payload.payload.readTCPIPv4() else { XCTFail("couldn't read TCP/IPv4 packet") return @@ -414,7 +502,7 @@ class WritePCAPHandlerTest: XCTestCase { XCTAssertEqual(1111, tcpIPPacket.tcpHeader.srcPort) XCTAssertEqual("hello", String(decoding: tcpIPPacket.tcpPayload.readableBytesView, as: Unicode.UTF8.self)) } - + func testOversizedInboundDataComesAsTwoPacketsIPv4() throws { self.channel.localAddress = try! SocketAddress(ipAddress: "1.2.3.4", port: 1111) self.channel.remoteAddress = try! SocketAddress(ipAddress: "9.8.7.6", port: 2222) @@ -422,7 +510,7 @@ class WritePCAPHandlerTest: XCTestCase { self.scratchBuffer.writeString(expectedData) XCTAssertNoThrow(try self.channel.writeInbound(self.scratchBuffer)) XCTAssertEqual(1, self.accumulatedPackets.count) - + guard var packetBytes = self.accumulatedPackets.first else { XCTFail("couldn't read bytes of first packet") return @@ -431,9 +519,10 @@ class WritePCAPHandlerTest: XCTestCase { XCTFail("couldn't read payloads from PCAP record") return } - XCTAssertEqual(0, packetBytes.readableBytes) // check nothing is left over + XCTAssertEqual(0, packetBytes.readableBytes) // check nothing is left over guard let tcpIPPacket1 = try payload1.payload.readTCPIPv4(), - let tcpIPPacket2 = try payload2.payload.readTCPIPv4() else { + let tcpIPPacket2 = try payload2.payload.readTCPIPv4() + else { XCTFail("couldn't read TCP/IPv4 packets") return } @@ -441,11 +530,12 @@ class WritePCAPHandlerTest: XCTestCase { XCTAssertEqual(2222, tcpIPPacket1.tcpHeader.srcPort) XCTAssertEqual(1111, tcpIPPacket2.tcpHeader.dstPort) XCTAssertEqual(2222, tcpIPPacket2.tcpHeader.srcPort) - let actualData = String(decoding: tcpIPPacket1.tcpPayload.readableBytesView, as: Unicode.UTF8.self) + - String(decoding: tcpIPPacket2.tcpPayload.readableBytesView, as: Unicode.UTF8.self) + let actualData = + String(decoding: tcpIPPacket1.tcpPayload.readableBytesView, as: Unicode.UTF8.self) + + String(decoding: tcpIPPacket2.tcpPayload.readableBytesView, as: Unicode.UTF8.self) XCTAssertEqual(expectedData, actualData) } - + func testOversizedInboundDataComesAsTwoPacketsIPv6() throws { self.channel.localAddress = try! SocketAddress(ipAddress: "::1", port: 1111) self.channel.remoteAddress = try! SocketAddress(ipAddress: "::2", port: 2222) @@ -453,7 +543,7 @@ class WritePCAPHandlerTest: XCTestCase { self.scratchBuffer.writeString(expectedData) XCTAssertNoThrow(try self.channel.writeInbound(self.scratchBuffer)) XCTAssertEqual(1, self.accumulatedPackets.count) - + guard var packetBytes = self.accumulatedPackets.first else { XCTFail("couldn't read bytes of first packet") return @@ -462,9 +552,10 @@ class WritePCAPHandlerTest: XCTestCase { XCTFail("couldn't read payloads from PCAP record") return } - XCTAssertEqual(0, packetBytes.readableBytes) // check nothing is left over + XCTAssertEqual(0, packetBytes.readableBytes) // check nothing is left over guard let tcpIPPacket1 = try payload1.payload.readTCPIPv6(), - let tcpIPPacket2 = try payload2.payload.readTCPIPv6() else { + let tcpIPPacket2 = try payload2.payload.readTCPIPv6() + else { XCTFail("couldn't read TCP/IPv6 packets") return } @@ -472,11 +563,12 @@ class WritePCAPHandlerTest: XCTestCase { XCTAssertEqual(2222, tcpIPPacket1.tcpHeader.srcPort) XCTAssertEqual(1111, tcpIPPacket2.tcpHeader.dstPort) XCTAssertEqual(2222, tcpIPPacket2.tcpHeader.srcPort) - let actualData = String(decoding: tcpIPPacket1.tcpPayload.readableBytesView, as: Unicode.UTF8.self) + - String(decoding: tcpIPPacket2.tcpPayload.readableBytesView, as: Unicode.UTF8.self) + let actualData = + String(decoding: tcpIPPacket1.tcpPayload.readableBytesView, as: Unicode.UTF8.self) + + String(decoding: tcpIPPacket2.tcpPayload.readableBytesView, as: Unicode.UTF8.self) XCTAssertEqual(expectedData, actualData) } - + func testOversizedOutboundDataComesAsTwoPacketsIPv4() throws { self.channel.localAddress = try! SocketAddress(ipAddress: "1.2.3.4", port: 1111) self.channel.remoteAddress = try! SocketAddress(ipAddress: "9.8.7.6", port: 2222) @@ -484,7 +576,7 @@ class WritePCAPHandlerTest: XCTestCase { self.scratchBuffer.writeString(expectedData) XCTAssertNoThrow(try self.channel.writeOutbound(self.scratchBuffer)) XCTAssertEqual(1, self.accumulatedPackets.count) - + guard var packetBytes = self.accumulatedPackets.first else { XCTFail("couldn't read bytes of first packet") return @@ -493,9 +585,10 @@ class WritePCAPHandlerTest: XCTestCase { XCTFail("couldn't read payloads from PCAP record") return } - XCTAssertEqual(0, packetBytes.readableBytes) // check nothing is left over + XCTAssertEqual(0, packetBytes.readableBytes) // check nothing is left over guard let tcpIPPacket1 = try payload1.payload.readTCPIPv4(), - let tcpIPPacket2 = try payload2.payload.readTCPIPv4() else { + let tcpIPPacket2 = try payload2.payload.readTCPIPv4() + else { XCTFail("couldn't read TCP/IPv4 packets") return } @@ -503,11 +596,12 @@ class WritePCAPHandlerTest: XCTestCase { XCTAssertEqual(1111, tcpIPPacket1.tcpHeader.srcPort) XCTAssertEqual(2222, tcpIPPacket2.tcpHeader.dstPort) XCTAssertEqual(1111, tcpIPPacket2.tcpHeader.srcPort) - let actualData = String(decoding: tcpIPPacket1.tcpPayload.readableBytesView, as: Unicode.UTF8.self) + - String(decoding: tcpIPPacket2.tcpPayload.readableBytesView, as: Unicode.UTF8.self) + let actualData = + String(decoding: tcpIPPacket1.tcpPayload.readableBytesView, as: Unicode.UTF8.self) + + String(decoding: tcpIPPacket2.tcpPayload.readableBytesView, as: Unicode.UTF8.self) XCTAssertEqual(expectedData, actualData) } - + func testOversizedOutboundDataComesAsTwoPacketsIPv6() throws { self.channel.localAddress = try! SocketAddress(ipAddress: "::1", port: 1111) self.channel.remoteAddress = try! SocketAddress(ipAddress: "::2", port: 2222) @@ -515,7 +609,7 @@ class WritePCAPHandlerTest: XCTestCase { self.scratchBuffer.writeString(expectedData) XCTAssertNoThrow(try self.channel.writeOutbound(self.scratchBuffer)) XCTAssertEqual(1, self.accumulatedPackets.count) - + guard var packetBytes = self.accumulatedPackets.first else { XCTFail("couldn't read bytes of first packet") return @@ -524,9 +618,10 @@ class WritePCAPHandlerTest: XCTestCase { XCTFail("couldn't read payloads from PCAP record") return } - XCTAssertEqual(0, packetBytes.readableBytes) // check nothing is left over + XCTAssertEqual(0, packetBytes.readableBytes) // check nothing is left over guard let tcpIPPacket1 = try payload1.payload.readTCPIPv6(), - let tcpIPPacket2 = try payload2.payload.readTCPIPv6() else { + let tcpIPPacket2 = try payload2.payload.readTCPIPv6() + else { XCTFail("couldn't read TCP/IPv6 packets") return } @@ -534,8 +629,9 @@ class WritePCAPHandlerTest: XCTestCase { XCTAssertEqual(1111, tcpIPPacket1.tcpHeader.srcPort) XCTAssertEqual(2222, tcpIPPacket2.tcpHeader.dstPort) XCTAssertEqual(1111, tcpIPPacket2.tcpHeader.srcPort) - let actualData = String(decoding: tcpIPPacket1.tcpPayload.readableBytesView, as: Unicode.UTF8.self) + - String(decoding: tcpIPPacket2.tcpPayload.readableBytesView, as: Unicode.UTF8.self) + let actualData = + String(decoding: tcpIPPacket1.tcpPayload.readableBytesView, as: Unicode.UTF8.self) + + String(decoding: tcpIPPacket2.tcpPayload.readableBytesView, as: Unicode.UTF8.self) XCTAssertEqual(expectedData, actualData) } @@ -566,20 +662,22 @@ class WritePCAPHandlerTest: XCTestCase { XCTAssertEqual(1, self.accumulatedPackets.count) var write1Bytes = self.accumulatedPackets.first XCTAssertNotNil(write1Bytes?.readPCAPRecord()) - XCTAssertEqual(0, write1Bytes?.readableBytes) // nothing left + XCTAssertEqual(0, write1Bytes?.readableBytes) // nothing left XCTAssertNoThrow(try self.channel.finish()) - XCTAssertEqual(2 /* the TCP connection FIN dance */, self.accumulatedPackets.count) + XCTAssertEqual(2, self.accumulatedPackets.count) // the TCP connection FIN dance var write2Bytes = self.accumulatedPackets.dropFirst().first - let records = [write2Bytes?.readPCAPRecord() /* FIN */, - write2Bytes?.readPCAPRecord() /* FIN+ACK */, - write2Bytes?.readPCAPRecord() /* ACK */] - XCTAssertEqual(0, write2Bytes?.readableBytes) // nothing left + let records = [ + write2Bytes?.readPCAPRecord(), // FIN + write2Bytes?.readPCAPRecord(), // FIN+ACK + write2Bytes?.readPCAPRecord(), // ACK + ] + XCTAssertEqual(0, write2Bytes?.readableBytes) // nothing left var ipPackets: [TCPIPv6Packet] = [] for var record in records { - XCTAssertNotNil(record) // we must have been able to parse a record - XCTAssertGreaterThan(record?.payload.readableBytes ?? -1, 0) // there must be some TCP/IP packet in there + XCTAssertNotNil(record) // we must have been able to parse a record + XCTAssertGreaterThan(record?.payload.readableBytes ?? -1, 0) // there must be some TCP/IP packet in there if let ipPacket = try record?.payload.readTCPIPv6() { ipPackets.append(ipPacket) XCTAssertEqual(0, ipPacket.tcpPayload.readableBytes) @@ -587,46 +685,64 @@ class WritePCAPHandlerTest: XCTestCase { } // FIN, local should be source, remote is destination - self.assertEqual(expectedAddress: self.channel?.localAddress, - actualIPv6Address: ipPackets[0].src, - actualPort: ipPackets[0].tcpHeader.srcPort) - self.assertEqual(expectedAddress: self.channel?.remoteAddress, - actualIPv6Address: ipPackets[0].dst, - actualPort: ipPackets[0].tcpHeader.dstPort) + self.assertEqual( + expectedAddress: self.channel?.localAddress, + actualIPv6Address: ipPackets[0].src, + actualPort: ipPackets[0].tcpHeader.srcPort + ) + self.assertEqual( + expectedAddress: self.channel?.remoteAddress, + actualIPv6Address: ipPackets[0].dst, + actualPort: ipPackets[0].tcpHeader.dstPort + ) XCTAssertEqual([.fin], ipPackets[0].tcpHeader.flags) - XCTAssertEqual(20 /* just the TCP header */, ipPackets[0].payloadLength) + XCTAssertEqual(20, ipPackets[0].payloadLength) // 20 -> just the TCP header // FIN+ACK, local should be destination, remote should be source - self.assertEqual(expectedAddress: self.channel?.remoteAddress, - actualIPv6Address: ipPackets[1].src, - actualPort: ipPackets[1].tcpHeader.srcPort) - self.assertEqual(expectedAddress: self.channel?.localAddress, - actualIPv6Address: ipPackets[1].dst, - actualPort: ipPackets[1].tcpHeader.dstPort) + self.assertEqual( + expectedAddress: self.channel?.remoteAddress, + actualIPv6Address: ipPackets[1].src, + actualPort: ipPackets[1].tcpHeader.srcPort + ) + self.assertEqual( + expectedAddress: self.channel?.localAddress, + actualIPv6Address: ipPackets[1].dst, + actualPort: ipPackets[1].tcpHeader.dstPort + ) XCTAssertEqual([.fin, .ack], ipPackets[1].tcpHeader.flags) - XCTAssertEqual(20 /* just the TCP header */, ipPackets[1].payloadLength) + XCTAssertEqual(20, ipPackets[1].payloadLength) // 20 -> just the TCP header // ACK - self.assertEqual(expectedAddress: self.channel?.localAddress, - actualIPv6Address: ipPackets[0].src, - actualPort: ipPackets[0].tcpHeader.srcPort) - self.assertEqual(expectedAddress: self.channel?.remoteAddress, - actualIPv6Address: ipPackets[0].dst, - actualPort: ipPackets[0].tcpHeader.dstPort) + self.assertEqual( + expectedAddress: self.channel?.localAddress, + actualIPv6Address: ipPackets[0].src, + actualPort: ipPackets[0].tcpHeader.srcPort + ) + self.assertEqual( + expectedAddress: self.channel?.remoteAddress, + actualIPv6Address: ipPackets[0].dst, + actualPort: ipPackets[0].tcpHeader.dstPort + ) XCTAssertEqual([.ack], ipPackets[2].tcpHeader.flags) - XCTAssertEqual(20 /* just the TCP header */, ipPackets[2].payloadLength) + XCTAssertEqual(20, ipPackets[2].payloadLength) // 20 -> just the TCP header } func testUnflushedOutboundDataIsWrittenWhenEmittingWritesOnIssue() throws { XCTAssertNoThrow(try self.channel.pipeline.removeHandler(name: "NIOWritePCAPHandler").wait()) let settings = NIOWritePCAPHandler.Settings(emitPCAPWrites: .whenIssued) - XCTAssertNoThrow(try self.channel.pipeline.addHandler(NIOWritePCAPHandler(mode: .client, - fakeLocalAddress: nil, - fakeRemoteAddress: nil, - settings: settings, - fileSink: { - self.accumulatedPackets.append($0) - })).wait()) + XCTAssertNoThrow( + try self.channel.pipeline.addHandler( + NIOWritePCAPHandler( + mode: .client, + fakeLocalAddress: nil, + fakeRemoteAddress: nil, + settings: settings, + fileSink: { + self.accumulatedPackets.append($0) + } + ) + ).wait() + ) self.channel.localAddress = try! SocketAddress(ipAddress: "1.2.3.4", port: 1111) self.channel.remoteAddress = try! SocketAddress(ipAddress: "9.8.7.6", port: 2222) self.scratchBuffer.writeStaticString("hello") @@ -662,16 +778,22 @@ class WritePCAPHandlerTest: XCTestCase { // Let's drop all writes/flushes so EmbeddedChannel won't accumulate them. XCTAssertNoThrow(try channel.pipeline.addHandler(DropAllWritesAndFlushes()).wait()) - XCTAssertNoThrow(try channel.pipeline.addHandler(NIOWritePCAPHandler(mode: .client, - fakeLocalAddress: .init(ipAddress: "::1", port: 1), - fakeRemoteAddress: .init(ipAddress: "::2", port: 2), - fileSink: { - numberOfBytesLogged += Int64($0.readableBytes) - })).wait()) + XCTAssertNoThrow( + try channel.pipeline.addHandler( + NIOWritePCAPHandler( + mode: .client, + fakeLocalAddress: .init(ipAddress: "::1", port: 1), + fakeRemoteAddress: .init(ipAddress: "::2", port: 2), + fileSink: { + numberOfBytesLogged += Int64($0.readableBytes) + } + ) + ).wait() + ) // Let's also drop all channelReads to prevent accumulation of all the data. XCTAssertNoThrow(try channel.pipeline.addHandler(DropAllChannelReads()).wait()) - let chunkSize = Int(UInt16.max - 40 /* needs to fit into the IPv4 header which adds 40 */) + let chunkSize = Int(UInt16.max - 40) // needs to fit into the IPv4 header which adds 40 self.scratchBuffer = channel.allocator.buffer(capacity: chunkSize) self.scratchBuffer.writeString(String(repeating: "X", count: chunkSize)) @@ -722,77 +844,87 @@ extension ByteBuffer { let saveSelf = self guard let srcPort = self.readInteger(as: UInt16.self), let dstPort = self.readInteger(as: UInt16.self), - let seqNo = self.readInteger(as: UInt32.self), // seq no - let ackNo = self.readInteger(as: UInt32.self), // ack no - let flagsAndFriends = self.readInteger(as: UInt16.self), // data offset + reserved bits + fancy stuff - let _ = self.readInteger(as: UInt16.self), // window size - let _ = self.readInteger(as: UInt16.self), // checksum - let _ = self.readInteger(as: UInt16.self) /* urgent pointer */ else { - self = saveSelf - return nil - } - guard (flagsAndFriends & (0xf << 12)) == (0x5 << 12) /* check that the data offset is right */ else { + let seqNo = self.readInteger(as: UInt32.self), // seq no + let ackNo = self.readInteger(as: UInt32.self), // ack no + let flagsAndFriends = self.readInteger(as: UInt16.self), // data offset + reserved bits + fancy stuff + let _ = self.readInteger(as: UInt16.self), // window size + let _ = self.readInteger(as: UInt16.self), // checksum + let _ = self.readInteger(as: UInt16.self) // urgent pointer + else { + self = saveSelf + return nil + } + + // check that the data offset is right + guard (flagsAndFriends & (0xf << 12)) == (0x5 << 12) else { throw ParsingError() } - return TCPHeader(flags: .init(rawValue: UInt8(flagsAndFriends & 0xfff)), - ackNumber: ackNo == 0 ? nil : ackNo, - sequenceNumber: seqNo, - srcPort: srcPort, - dstPort: dstPort) + return TCPHeader( + flags: .init(rawValue: UInt8(flagsAndFriends & 0xfff)), + ackNumber: ackNo == 0 ? nil : ackNo, + sequenceNumber: seqNo, + srcPort: srcPort, + dstPort: dstPort + ) } - + // read & parse a TCP/IPv4 packet, containing everything belonging to it (including payload) mutating func readTCPIPv4() throws -> TCPIPv4Packet? { struct ParsingError: Error {} let saveSelf = self guard let version = self.readInteger(as: UInt8.self), - let _ = self.readInteger(as: UInt8.self), // DSCP + let _ = self.readInteger(as: UInt8.self), // DSCP let ipv4WholeLength = self.readInteger(as: UInt16.self), - let _ = self.readInteger(as: UInt16.self), // identification - let _ = self.readInteger(as: UInt16.self), // flags & fragment offset - let _ = self.readInteger(as: UInt8.self), // TTL - let innerProtocolID = self.readInteger(as: UInt8.self), // TCP - let _ = self.readInteger(as: UInt16.self), // checksum + let _ = self.readInteger(as: UInt16.self), // identification + let _ = self.readInteger(as: UInt16.self), // flags & fragment offset + let _ = self.readInteger(as: UInt8.self), // TTL + let innerProtocolID = self.readInteger(as: UInt8.self), // TCP + let _ = self.readInteger(as: UInt16.self), // checksum let srcRaw = self.readInteger(endianness: .host, as: UInt32.self), let dstRaw = self.readInteger(endianness: .host, as: UInt32.self), var payload = self.readSlice(length: Int(ipv4WholeLength - 20)), - let tcp = try payload.readTCPHeader() else { - self = saveSelf - return nil + let tcp = try payload.readTCPHeader() + else { + self = saveSelf + return nil } - guard version == 0x45, innerProtocolID == 6 /* TCP is 6 */ else { + guard version == 0x45, innerProtocolID == 6 else { // innerProtocolID -> TCP is 6 throw ParsingError() } let src = in_addr(s_addr: srcRaw) let dst = in_addr(s_addr: dstRaw) - return TCPIPv4Packet(src: src, - dst: dst, - wholeIPPacketLength: .init(ipv4WholeLength), - tcpHeader: tcp, - tcpPayload: payload) + return TCPIPv4Packet( + src: src, + dst: dst, + wholeIPPacketLength: .init(ipv4WholeLength), + tcpHeader: tcp, + tcpPayload: payload + ) } - + // read & parse a TCP/IPv6 packet, containing everything belonging to it (including payload) mutating func readTCPIPv6() throws -> TCPIPv6Packet? { let saveSelf = self - guard let versionAndFancyStuff = self.readInteger(as: UInt32.self), // IP version (6) & fancy stuff + guard let versionAndFancyStuff = self.readInteger(as: UInt32.self), // IP version (6) & fancy stuff let payloadLength = self.readInteger(as: UInt16.self), - let innerProtocolID = self.readInteger(as: UInt8.self), // TCP - let _ = self.readInteger(as: UInt8.self), // hop limit (like TTL) + let innerProtocolID = self.readInteger(as: UInt8.self), // TCP + let _ = self.readInteger(as: UInt8.self), // hop limit (like TTL) var srcAddrBuffer = self.readSlice(length: MemoryLayout.size), var dstAddrBuffer = self.readSlice(length: MemoryLayout.size), var payload = self.readSlice(length: Int(payloadLength)), - let tcp = try payload.readTCPHeader() else { - self = saveSelf - return nil + let tcp = try payload.readTCPHeader() + else { + self = saveSelf + return nil } - guard versionAndFancyStuff >> 28 == 6 /* IPv_6_ */, innerProtocolID == 6 /* TCP is 6 */ else { + // IPv_6_ TCP is 6 + guard versionAndFancyStuff >> 28 == 6, innerProtocolID == 6 else { return nil } - + var srcAddress = in6_addr() var dstAddress = in6_addr() withUnsafeMutableBytes(of: &srcAddress) { copyDestPtr in @@ -810,37 +942,44 @@ extension ByteBuffer { } } - return TCPIPv6Packet(src: srcAddress, - dst: dstAddress, - payloadLength: .init(payloadLength), - tcpHeader: tcp, - tcpPayload: payload) + return TCPIPv6Packet( + src: srcAddress, + dst: dstAddress, + payloadLength: .init(payloadLength), + tcpHeader: tcp, + tcpPayload: payload + ) } // read a PCAP record, including all its payload mutating func readPCAPRecord() -> PCAPRecord? { - let saveSelf = self // save the buffer in case we don't have enough to parse - + let saveSelf = self // save the buffer in case we don't have enough to parse + guard let timeSecs = self.readInteger(endianness: .host, as: UInt32.self), let timeUSecs = self.readInteger(endianness: .host, as: UInt32.self), let lenPacket = self.readInteger(endianness: .host, as: UInt32.self), let lenDisk = self.readInteger(endianness: .host, as: UInt32.self), let pcapProtocolID = self.readInteger(endianness: .host, as: UInt32.self), - let payload = self.readSlice(length: Int(lenDisk - 4)) else { - self = saveSelf - return nil + let payload = self.readSlice(length: Int(lenDisk - 4)) + else { + self = saveSelf + return nil } - + assert(lenPacket == lenDisk, "\(lenPacket) != \(lenDisk)") - + let notImplementedAddress = try! SocketAddress(ipAddress: "9.9.9.9", port: 0xbad) let tcp = TCPHeader(flags: [], ackNumber: nil, sequenceNumber: 0xbad, srcPort: 0xbad, dstPort: 0xbad) - return .init(time: timeval(tv_sec: .init(timeSecs), tv_usec: .init(timeUSecs)), - header: try! PCAPRecordHeader(payloadLength: .init(lenPacket), - src: notImplementedAddress, - dst: notImplementedAddress, - tcp: tcp), - pcapProtocolID: pcapProtocolID, - payload: payload) + return .init( + time: timeval(tv_sec: .init(timeSecs), tv_usec: .init(timeUSecs)), + header: try! PCAPRecordHeader( + payloadLength: .init(lenPacket), + src: notImplementedAddress, + dst: notImplementedAddress, + tcp: tcp + ), + pcapProtocolID: pcapProtocolID, + payload: payload + ) } } diff --git a/Tests/NIOHTTPCompressionTests/HTTPRequestCompressorTest.swift b/Tests/NIOHTTPCompressionTests/HTTPRequestCompressorTest.swift index c607d146..6954d9a9 100644 --- a/Tests/NIOHTTPCompressionTests/HTTPRequestCompressorTest.swift +++ b/Tests/NIOHTTPCompressionTests/HTTPRequestCompressorTest.swift @@ -12,45 +12,51 @@ // //===----------------------------------------------------------------------===// -import XCTest import CNIOExtrasZlib import NIOCore import NIOEmbedded import NIOHTTP1 +import XCTest + @testable import NIOHTTPCompression class HTTPRequestCompressorTest: XCTestCase { - + func compressionChannel(_ compression: NIOCompression.Algorithm = .gzip) throws -> EmbeddedChannel { let channel = EmbeddedChannel() //XCTAssertNoThrow(try channel.pipeline.addHandler(HTTPRequestEncoder(), name: "encoder").wait()) - XCTAssertNoThrow(try channel.pipeline.addHandler(NIOHTTPRequestCompressor(encoding: compression), name: "compressor").wait()) + XCTAssertNoThrow( + try channel.pipeline.addHandler(NIOHTTPRequestCompressor(encoding: compression), name: "compressor").wait() + ) return channel } - + func write(body: [ByteBuffer], to channel: EmbeddedChannel) throws { let requestHead = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/") try write(head: requestHead, body: body, to: channel) } - + func write(head: HTTPRequestHead, body: [ByteBuffer], to channel: EmbeddedChannel) throws { var promiseArray = PromiseArray(on: channel.eventLoop) channel.pipeline.write(NIOAny(HTTPClientRequestPart.head(head)), promise: promiseArray.makePromise()) for bodyChunk in body { - channel.pipeline.write(NIOAny(HTTPClientRequestPart.body(.byteBuffer(bodyChunk))), promise: promiseArray.makePromise()) + channel.pipeline.write( + NIOAny(HTTPClientRequestPart.body(.byteBuffer(bodyChunk))), + promise: promiseArray.makePromise() + ) } channel.pipeline.write(NIOAny(HTTPClientRequestPart.end(nil)), promise: promiseArray.makePromise()) channel.pipeline.flush() - + try promiseArray.waitUntilComplete() } - + func writeWithIntermittantFlush(body: [ByteBuffer], to channel: EmbeddedChannel) throws { let requestHead = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/") try writeWithIntermittantFlush(head: requestHead, body: body, to: channel) } - + func writeWithIntermittantFlush(head: HTTPRequestHead, body: [ByteBuffer], to channel: EmbeddedChannel) throws { var promiseArray = PromiseArray(on: channel.eventLoop) var count = 3 @@ -69,7 +75,7 @@ class HTTPRequestCompressorTest: XCTestCase { } channel.pipeline.write(NIOAny(HTTPClientRequestPart.end(nil)), promise: promiseArray.makePromise()) channel.pipeline.flush() - + try promiseArray.waitUntilComplete() } @@ -92,19 +98,19 @@ class HTTPRequestCompressorTest: XCTestCase { } return (head: requestHead, body: byteBuffer) } - - func readVerifyPart(from channel: EmbeddedChannel, verify: (HTTPClientRequestPart)->()) throws { + + func readVerifyPart(from channel: EmbeddedChannel, verify: (HTTPClientRequestPart) -> Void) throws { channel.pipeline.read() loop: while let requestPart: HTTPClientRequestPart = try channel.readOutbound() { verify(requestPart) } } - + func testGzipContentEncoding() throws { let channel = try compressionChannel() var buffer = ByteBufferAllocator().buffer(capacity: 0) buffer.writeString("Test") - + _ = try write(body: [buffer], to: channel) try readVerifyPart(from: channel) { part in if case .head(let head) = part { @@ -112,12 +118,12 @@ class HTTPRequestCompressorTest: XCTestCase { } } } - + func testDeflateContentEncoding() throws { let channel = try compressionChannel(.deflate) var buffer = ByteBufferAllocator().buffer(capacity: 0) buffer.writeString("Test") - + _ = try write(body: [buffer], to: channel) try readVerifyPart(from: channel) { part in if case .head(let head) = part { @@ -125,19 +131,19 @@ class HTTPRequestCompressorTest: XCTestCase { } } } - + func testOneBuffer() throws { let channel = try compressionChannel() var buffer = ByteBufferAllocator().buffer(capacity: 1024 * Int.bitWidth / 8) for _ in 0..<1024 { buffer.writeInteger(Int.random(in: Int.min...Int.max)) } - + _ = try write(body: [buffer], to: channel) var result = try read(from: channel) var uncompressedBuffer = ByteBufferAllocator().buffer(capacity: buffer.readableBytes) z_stream.decompressGzip(compressedBytes: &result.body, outputBuffer: &uncompressedBuffer) - + XCTAssertEqual(buffer, uncompressedBuffer) XCTAssertEqual(result.head.headers["content-encoding"].first, "gzip") } @@ -159,11 +165,11 @@ class HTTPRequestCompressorTest: XCTestCase { var result = try read(from: channel) var uncompressedBuffer = ByteBufferAllocator().buffer(capacity: buffersConcat.readableBytes) z_stream.decompressGzip(compressedBytes: &result.body, outputBuffer: &uncompressedBuffer) - + XCTAssertEqual(buffersConcat, uncompressedBuffer) XCTAssertEqual(result.head.headers["content-encoding"].first, "gzip") } - + func testMultipleBuffersDeflate() throws { let channel = try compressionChannel(.deflate) var buffers: [ByteBuffer] = [] @@ -181,11 +187,11 @@ class HTTPRequestCompressorTest: XCTestCase { var result = try read(from: channel) var uncompressedBuffer = ByteBufferAllocator().buffer(capacity: buffersConcat.readableBytes) z_stream.decompressDeflate(compressedBytes: &result.body, outputBuffer: &uncompressedBuffer) - + XCTAssertEqual(buffersConcat, uncompressedBuffer) XCTAssertEqual(result.head.headers["content-encoding"].first, "deflate") } - + func testMultipleBuffersWithFlushes() throws { let channel = try compressionChannel() var buffers: [ByteBuffer] = [] @@ -203,7 +209,7 @@ class HTTPRequestCompressorTest: XCTestCase { var result = try read(from: channel) var uncompressedBuffer = ByteBufferAllocator().buffer(capacity: buffersConcat.readableBytes) z_stream.decompressGzip(compressedBytes: &result.body, outputBuffer: &uncompressedBuffer) - + XCTAssertEqual(buffersConcat, uncompressedBuffer) XCTAssertEqual(result.head.headers["content-encoding"].first, "gzip") XCTAssertEqual(result.head.headers["transfer-encoding"].first, "chunked") @@ -216,12 +222,15 @@ class HTTPRequestCompressorTest: XCTestCase { for _ in 0..<1024 { buffer.writeInteger(Int.random(in: Int.min...Int.max)) } - + let requestHead = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/") var promiseArray = PromiseArray(on: channel.eventLoop) channel.pipeline.write(NIOAny(HTTPClientRequestPart.head(requestHead)), promise: promiseArray.makePromise()) channel.pipeline.flush() - channel.pipeline.write(NIOAny(HTTPClientRequestPart.body(.byteBuffer(buffer))), promise: promiseArray.makePromise()) + channel.pipeline.write( + NIOAny(HTTPClientRequestPart.body(.byteBuffer(buffer))), + promise: promiseArray.makePromise() + ) channel.pipeline.write(NIOAny(HTTPClientRequestPart.end(nil)), promise: promiseArray.makePromise()) channel.pipeline.flush() try promiseArray.waitUntilComplete() @@ -229,22 +238,25 @@ class HTTPRequestCompressorTest: XCTestCase { var result = try read(from: channel) var uncompressedBuffer = ByteBufferAllocator().buffer(capacity: buffer.readableBytes) z_stream.decompressGzip(compressedBytes: &result.body, outputBuffer: &uncompressedBuffer) - + XCTAssertEqual(buffer, uncompressedBuffer) XCTAssertEqual(result.head.headers["content-encoding"].first, "gzip") } - + func testFlushBeforeEnd() throws { let channel = try compressionChannel() var buffer = ByteBufferAllocator().buffer(capacity: 1024 * Int.bitWidth / 8) for _ in 0..<1024 { buffer.writeInteger(Int.random(in: Int.min...Int.max)) } - + let requestHead = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/") var promiseArray = PromiseArray(on: channel.eventLoop) channel.pipeline.write(NIOAny(HTTPClientRequestPart.head(requestHead)), promise: promiseArray.makePromise()) - channel.pipeline.write(NIOAny(HTTPClientRequestPart.body(.byteBuffer(buffer))), promise: promiseArray.makePromise()) + channel.pipeline.write( + NIOAny(HTTPClientRequestPart.body(.byteBuffer(buffer))), + promise: promiseArray.makePromise() + ) channel.pipeline.flush() channel.pipeline.write(NIOAny(HTTPClientRequestPart.end(nil)), promise: promiseArray.makePromise()) channel.pipeline.flush() @@ -253,18 +265,18 @@ class HTTPRequestCompressorTest: XCTestCase { var result = try read(from: channel) var uncompressedBuffer = ByteBufferAllocator().buffer(capacity: buffer.readableBytes) z_stream.decompressGzip(compressedBytes: &result.body, outputBuffer: &uncompressedBuffer) - + XCTAssertEqual(buffer, uncompressedBuffer) XCTAssertEqual(result.head.headers["content-encoding"].first, "gzip") } - + func testDoubleFlush() throws { let channel = try compressionChannel() var buffer = ByteBufferAllocator().buffer(capacity: 1024 * Int.bitWidth / 8) for _ in 0..<1024 { buffer.writeInteger(Int.random(in: Int.min...Int.max)) } - + let algo = NIOCompression.Algorithm.gzip if algo == NIOCompression.Algorithm.deflate { print("Hello") @@ -272,7 +284,10 @@ class HTTPRequestCompressorTest: XCTestCase { let requestHead = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/") var promiseArray = PromiseArray(on: channel.eventLoop) channel.pipeline.write(NIOAny(HTTPClientRequestPart.head(requestHead)), promise: promiseArray.makePromise()) - channel.pipeline.write(NIOAny(HTTPClientRequestPart.body(.byteBuffer(buffer))), promise: promiseArray.makePromise()) + channel.pipeline.write( + NIOAny(HTTPClientRequestPart.body(.byteBuffer(buffer))), + promise: promiseArray.makePromise() + ) channel.pipeline.flush() channel.pipeline.flush() channel.pipeline.write(NIOAny(HTTPClientRequestPart.end(nil)), promise: promiseArray.makePromise()) @@ -282,14 +297,14 @@ class HTTPRequestCompressorTest: XCTestCase { var result = try read(from: channel) var uncompressedBuffer = ByteBufferAllocator().buffer(capacity: buffer.readableBytes) z_stream.decompressGzip(compressedBytes: &result.body, outputBuffer: &uncompressedBuffer) - + XCTAssertEqual(buffer, uncompressedBuffer) XCTAssertEqual(result.head.headers["content-encoding"].first, "gzip") } - + func testNoBody() throws { let channel = try compressionChannel() - + let requestHead = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/") var promiseArray = PromiseArray(on: channel.eventLoop) channel.pipeline.write(NIOAny(HTTPClientRequestPart.head(requestHead)), promise: promiseArray.makePromise()) @@ -301,7 +316,7 @@ class HTTPRequestCompressorTest: XCTestCase { switch part { case .head(let head): XCTAssertNil(head.headers["Content-Encoding"].first) - case.body: + case .body: XCTFail("Shouldn't return a body") case .end: break @@ -313,28 +328,30 @@ class HTTPRequestCompressorTest: XCTestCase { struct PromiseArray { var promises: [EventLoopPromise] let eventLoop: EventLoop - + init(on eventLoop: EventLoop) { self.promises = [] self.eventLoop = eventLoop } - + mutating func makePromise() -> EventLoopPromise { let promise: EventLoopPromise = eventLoop.makePromise() self.promises.append(promise) return promise } - + func waitUntilComplete() throws { let resultFutures = promises.map { $0.futureResult } _ = try EventLoopFuture.whenAllComplete(resultFutures, on: eventLoop).wait() } } -private extension ByteBuffer { +extension ByteBuffer { @discardableResult - mutating func withUnsafeMutableReadableUInt8Bytes(_ body: (UnsafeMutableBufferPointer) throws -> T) rethrows -> T { - return try self.withUnsafeMutableReadableBytes { (ptr: UnsafeMutableRawBufferPointer) -> T in + fileprivate mutating func withUnsafeMutableReadableUInt8Bytes( + _ body: (UnsafeMutableBufferPointer) throws -> T + ) rethrows -> T { + try self.withUnsafeMutableReadableBytes { (ptr: UnsafeMutableRawBufferPointer) -> T in let baseInputPointer = ptr.baseAddress?.assumingMemoryBound(to: UInt8.self) let inputBufferPointer = UnsafeMutableBufferPointer(start: baseInputPointer, count: ptr.count) return try body(inputBufferPointer) @@ -342,8 +359,10 @@ private extension ByteBuffer { } @discardableResult - mutating func writeWithUnsafeMutableUInt8Bytes(_ body: (UnsafeMutableBufferPointer) throws -> Int) rethrows -> Int { - return try self.writeWithUnsafeMutableBytes(minimumWritableBytes: 0) { (ptr: UnsafeMutableRawBufferPointer) -> Int in + fileprivate mutating func writeWithUnsafeMutableUInt8Bytes( + _ body: (UnsafeMutableBufferPointer) throws -> Int + ) rethrows -> Int { + try self.writeWithUnsafeMutableBytes(minimumWritableBytes: 0) { (ptr: UnsafeMutableRawBufferPointer) -> Int in let baseInputPointer = ptr.baseAddress?.assumingMemoryBound(to: UInt8.self) let inputBufferPointer = UnsafeMutableBufferPointer(start: baseInputPointer, count: ptr.count) return try body(inputBufferPointer) @@ -351,16 +370,17 @@ private extension ByteBuffer { } } -private extension z_stream { - static func decompressDeflate(compressedBytes: inout ByteBuffer, outputBuffer: inout ByteBuffer) { +extension z_stream { + fileprivate static func decompressDeflate(compressedBytes: inout ByteBuffer, outputBuffer: inout ByteBuffer) { decompress(compressedBytes: &compressedBytes, outputBuffer: &outputBuffer, windowSize: 15) } - static func decompressGzip(compressedBytes: inout ByteBuffer, outputBuffer: inout ByteBuffer) { + fileprivate static func decompressGzip(compressedBytes: inout ByteBuffer, outputBuffer: inout ByteBuffer) { decompress(compressedBytes: &compressedBytes, outputBuffer: &outputBuffer, windowSize: 16 + 15) } - private static func decompress(compressedBytes: inout ByteBuffer, outputBuffer: inout ByteBuffer, windowSize: Int32) { + private static func decompress(compressedBytes: inout ByteBuffer, outputBuffer: inout ByteBuffer, windowSize: Int32) + { compressedBytes.withUnsafeMutableReadableUInt8Bytes { inputPointer in outputBuffer.writeWithUnsafeMutableUInt8Bytes { outputPointer -> Int in var stream = z_stream() @@ -389,4 +409,3 @@ private extension z_stream { } } } - diff --git a/Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest.swift b/Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest.swift index 2da83713..94190e94 100644 --- a/Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest.swift +++ b/Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest.swift @@ -12,14 +12,16 @@ // //===----------------------------------------------------------------------===// -import XCTest import CNIOExtrasZlib import NIOCore import NIOEmbedded +import XCTest + @testable import NIOHTTP1 @testable import NIOHTTPCompression -private let testString = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." +private let testString = + "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." private final class DecompressedAssert: ChannelInboundHandler { typealias InboundIn = HTTPServerRequestPart @@ -49,7 +51,16 @@ class HTTPRequestDecompressorTest: XCTestCase { let compressed = compress(buffer, "gzip") let headers = HTTPHeaders([("Content-Encoding", "gzip"), ("Content-Length", "\(compressed.readableBytes)")]) - try channel.writeInbound(HTTPServerRequestPart.head(.init(version: .init(major: 1, minor: 1), method: .POST, uri: "https://nio.swift.org/test", headers: headers))) + try channel.writeInbound( + HTTPServerRequestPart.head( + .init( + version: .init(major: 1, minor: 1), + method: .POST, + uri: "https://nio.swift.org/test", + headers: headers + ) + ) + ) XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.body(compressed))) } @@ -60,8 +71,17 @@ class HTTPRequestDecompressorTest: XCTestCase { let decompressed = ByteBuffer.of(bytes: Array(repeating: 0, count: 500)) let compressed = compress(decompressed, "gzip") let headers = HTTPHeaders([("Content-Encoding", "gzip"), ("Content-Length", "\(compressed.readableBytes)")]) - try channel.writeInbound(HTTPServerRequestPart.head(.init(version: .init(major: 1, minor: 1), method: .POST, uri: "https://nio.swift.org/test", headers: headers))) - + try channel.writeInbound( + HTTPServerRequestPart.head( + .init( + version: .init(major: 1, minor: 1), + method: .POST, + uri: "https://nio.swift.org/test", + headers: headers + ) + ) + ) + do { try channel.writeInbound(HTTPServerRequestPart.body(compressed)) XCTFail("writeShouldFail") @@ -80,10 +100,19 @@ class HTTPRequestDecompressorTest: XCTestCase { let channel = EmbeddedChannel() let decompressed = ByteBuffer.of(bytes: Array(repeating: 0, count: 200)) let compressed = compress(decompressed, "gzip") - try channel.pipeline.addHandler(NIOHTTPRequestDecompressor(limit: .size(decompressed.readableBytes-1))).wait() + try channel.pipeline.addHandler(NIOHTTPRequestDecompressor(limit: .size(decompressed.readableBytes - 1))).wait() let headers = HTTPHeaders([("Content-Encoding", "gzip"), ("Content-Length", "\(compressed.readableBytes)")]) - try channel.writeInbound(HTTPServerRequestPart.head(.init(version: .init(major: 1, minor: 1), method: .POST, uri: "https://nio.swift.org/test", headers: headers))) - + try channel.writeInbound( + HTTPServerRequestPart.head( + .init( + version: .init(major: 1, minor: 1), + method: .POST, + uri: "https://nio.swift.org/test", + headers: headers + ) + ) + ) + do { try channel.writeInbound(HTTPServerRequestPart.body(compressed)) XCTFail("writeInbound should fail") @@ -110,7 +139,7 @@ class HTTPRequestDecompressorTest: XCTestCase { (actual: "gzip", announced: "deflate"), (actual: "deflate", announced: "gzip"), ] - + for algorithm in algorithms { let compressed: ByteBuffer var headers = HTTPHeaders() @@ -123,7 +152,16 @@ class HTTPRequestDecompressorTest: XCTestCase { headers.add(name: "Content-Length", value: "\(compressed.readableBytes)") XCTAssertNoThrow( - try channel.writeInbound(HTTPServerRequestPart.head(.init(version: .init(major: 1, minor: 1), method: .POST, uri: "https://nio.swift.org/test", headers: headers))) + try channel.writeInbound( + HTTPServerRequestPart.head( + .init( + version: .init(major: 1, minor: 1), + method: .POST, + uri: "https://nio.swift.org/test", + headers: headers + ) + ) + ) ) XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.body(compressed))) @@ -138,7 +176,16 @@ class HTTPRequestDecompressorTest: XCTestCase { let channel = EmbeddedChannel() try channel.pipeline.addHandler(NIOHTTPRequestDecompressor(limit: .none)).wait() let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "\(compressed.readableBytes)")]) - try channel.writeInbound(HTTPServerRequestPart.head(.init(version: .init(major: 1, minor: 1), method: .POST, uri: "https://nio.swift.org/test", headers: headers))) + try channel.writeInbound( + HTTPServerRequestPart.head( + .init( + version: .init(major: 1, minor: 1), + method: .POST, + uri: "https://nio.swift.org/test", + headers: headers + ) + ) + ) XCTAssertThrowsError(try channel.writeInbound(HTTPServerRequestPart.body(compressed))) } @@ -150,7 +197,16 @@ class HTTPRequestDecompressorTest: XCTestCase { let channel = EmbeddedChannel() try channel.pipeline.addHandler(NIOHTTPRequestDecompressor(limit: .none)).wait() let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "\(compressed.readableBytes)")]) - try channel.writeInbound(HTTPServerRequestPart.head(.init(version: .init(major: 1, minor: 1), method: .POST, uri: "https://nio.swift.org/test", headers: headers))) + try channel.writeInbound( + HTTPServerRequestPart.head( + .init( + version: .init(major: 1, minor: 1), + method: .POST, + uri: "https://nio.swift.org/test", + headers: headers + ) + ) + ) XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.body(compressed))) XCTAssertThrowsError(try channel.writeInbound(HTTPServerRequestPart.end(nil))) @@ -176,7 +232,14 @@ class HTTPRequestDecompressorTest: XCTestCase { return buffer } - let rc = CNIOExtrasZlib_deflateInit2(&stream, Z_DEFAULT_COMPRESSION, Z_DEFLATED, windowBits, 8, Z_DEFAULT_STRATEGY) + let rc = CNIOExtrasZlib_deflateInit2( + &stream, + Z_DEFAULT_COMPRESSION, + Z_DEFLATED, + windowBits, + 8, + Z_DEFAULT_STRATEGY + ) XCTAssertEqual(Z_OK, rc) defer { @@ -190,15 +253,19 @@ class HTTPRequestDecompressorTest: XCTestCase { body.readWithUnsafeMutableReadableBytes { dataPtr in let typedPtr = dataPtr.baseAddress!.assumingMemoryBound(to: UInt8.self) - let typedDataPtr = UnsafeMutableBufferPointer(start: typedPtr, - count: dataPtr.count) + let typedDataPtr = UnsafeMutableBufferPointer( + start: typedPtr, + count: dataPtr.count + ) stream.avail_in = UInt32(typedDataPtr.count) stream.next_in = typedDataPtr.baseAddress! buffer.writeWithUnsafeMutableBytes(minimumWritableBytes: 0) { outputPtr in - let typedOutputPtr = UnsafeMutableBufferPointer(start: outputPtr.baseAddress!.assumingMemoryBound(to: UInt8.self), - count: outputPtr.count) + let typedOutputPtr = UnsafeMutableBufferPointer( + start: outputPtr.baseAddress!.assumingMemoryBound(to: UInt8.self), + count: outputPtr.count + ) stream.avail_out = UInt32(typedOutputPtr.count) stream.next_out = typedOutputPtr.baseAddress! let rc = deflate(&stream, Z_FINISH) diff --git a/Tests/NIOHTTPCompressionTests/HTTPResponseCompressorTest.swift b/Tests/NIOHTTPCompressionTests/HTTPResponseCompressorTest.swift index 99a56730..cc32ecdc 100644 --- a/Tests/NIOHTTPCompressionTests/HTTPResponseCompressorTest.swift +++ b/Tests/NIOHTTPCompressionTests/HTTPResponseCompressorTest.swift @@ -12,16 +12,17 @@ // //===----------------------------------------------------------------------===// -import XCTest import CNIOExtrasZlib +import NIOConcurrencyHelpers import NIOCore import NIOEmbedded import NIOHTTP1 -import NIOConcurrencyHelpers +import XCTest + @testable import NIOHTTPCompression private class PromiseOrderer { - private var promiseArray: Array> + private var promiseArray: [EventLoopPromise] private let eventLoop: EventLoop internal init(eventLoop: EventLoop) { @@ -56,10 +57,12 @@ private class PromiseOrderer { } } -private extension ByteBuffer { +extension ByteBuffer { @discardableResult - mutating func withUnsafeMutableReadableUInt8Bytes(_ body: (UnsafeMutableBufferPointer) throws -> T) rethrows -> T { - return try self.withUnsafeMutableReadableBytes { (ptr: UnsafeMutableRawBufferPointer) -> T in + fileprivate mutating func withUnsafeMutableReadableUInt8Bytes( + _ body: (UnsafeMutableBufferPointer) throws -> T + ) rethrows -> T { + try self.withUnsafeMutableReadableBytes { (ptr: UnsafeMutableRawBufferPointer) -> T in let baseInputPointer = ptr.baseAddress?.assumingMemoryBound(to: UInt8.self) let inputBufferPointer = UnsafeMutableBufferPointer(start: baseInputPointer, count: ptr.count) return try body(inputBufferPointer) @@ -67,15 +70,17 @@ private extension ByteBuffer { } @discardableResult - mutating func writeWithUnsafeMutableUInt8Bytes(_ body: (UnsafeMutableBufferPointer) throws -> Int) rethrows -> Int { - return try self.writeWithUnsafeMutableBytes(minimumWritableBytes: 0) { (ptr: UnsafeMutableRawBufferPointer) -> Int in + fileprivate mutating func writeWithUnsafeMutableUInt8Bytes( + _ body: (UnsafeMutableBufferPointer) throws -> Int + ) rethrows -> Int { + try self.writeWithUnsafeMutableBytes(minimumWritableBytes: 0) { (ptr: UnsafeMutableRawBufferPointer) -> Int in let baseInputPointer = ptr.baseAddress?.assumingMemoryBound(to: UInt8.self) let inputBufferPointer = UnsafeMutableBufferPointer(start: baseInputPointer, count: ptr.count) return try body(inputBufferPointer) } } - mutating func merge(_ others: S) -> ByteBuffer where S.Element == ByteBuffer { + fileprivate mutating func merge(_ others: S) -> ByteBuffer where S.Element == ByteBuffer { for var buffer in others { self.writeBuffer(&buffer) } @@ -83,16 +88,17 @@ private extension ByteBuffer { } } -private extension z_stream { - static func decompressDeflate(compressedBytes: inout ByteBuffer, outputBuffer: inout ByteBuffer) { +extension z_stream { + fileprivate static func decompressDeflate(compressedBytes: inout ByteBuffer, outputBuffer: inout ByteBuffer) { decompress(compressedBytes: &compressedBytes, outputBuffer: &outputBuffer, windowSize: 15) } - static func decompressGzip(compressedBytes: inout ByteBuffer, outputBuffer: inout ByteBuffer) { + fileprivate static func decompressGzip(compressedBytes: inout ByteBuffer, outputBuffer: inout ByteBuffer) { decompress(compressedBytes: &compressedBytes, outputBuffer: &outputBuffer, windowSize: 16 + 15) } - private static func decompress(compressedBytes: inout ByteBuffer, outputBuffer: inout ByteBuffer, windowSize: Int32) { + private static func decompress(compressedBytes: inout ByteBuffer, outputBuffer: inout ByteBuffer, windowSize: Int32) + { compressedBytes.withUnsafeMutableReadableUInt8Bytes { inputPointer in outputBuffer.writeWithUnsafeMutableUInt8Bytes { outputPointer -> Int in var stream = z_stream() @@ -148,12 +154,16 @@ class HTTPResponseCompressorTest: XCTestCase { channel.pipeline.write(NIOAny(HTTPServerResponsePart.head(head)), promise: promiseOrderer.makePromise()) for bodyChunk in body { - channel.pipeline.write(NIOAny(HTTPServerResponsePart.body(.byteBuffer(bodyChunk))), - promise: promiseOrderer.makePromise()) + channel.pipeline.write( + NIOAny(HTTPServerResponsePart.body(.byteBuffer(bodyChunk))), + promise: promiseOrderer.makePromise() + ) } - channel.pipeline.write(NIOAny(HTTPServerResponsePart.end(nil)), - promise: promiseOrderer.makePromise()) + channel.pipeline.write( + NIOAny(HTTPServerResponsePart.end(nil)), + promise: promiseOrderer.makePromise() + ) channel.pipeline.flush() // Get all the promises to fire. @@ -165,25 +175,31 @@ class HTTPResponseCompressorTest: XCTestCase { var writeCount = 0 channel.pipeline.write(NIOAny(HTTPServerResponsePart.head(head)), promise: promiseOrderer.makePromise()) for bodyChunk in body { - channel.pipeline.write(NIOAny(HTTPServerResponsePart.body(.byteBuffer(bodyChunk))), - promise: promiseOrderer.makePromise()) + channel.pipeline.write( + NIOAny(HTTPServerResponsePart.body(.byteBuffer(bodyChunk))), + promise: promiseOrderer.makePromise() + ) writeCount += 1 if writeCount % 3 == 0 { channel.pipeline.flush() } } - channel.pipeline.write(NIOAny(HTTPServerResponsePart.end(nil)), - promise: promiseOrderer.makePromise()) + channel.pipeline.write( + NIOAny(HTTPServerResponsePart.end(nil)), + promise: promiseOrderer.makePromise() + ) channel.pipeline.flush() // Get all the promises to fire. try promiseOrderer.waitUntilComplete() } - private func compressResponse(head: HTTPResponseHead, - body: [ByteBuffer], - channel: EmbeddedChannel, - writeStrategy: WriteStrategy = .once) throws -> (HTTPResponseHead, [ByteBuffer]) { + private func compressResponse( + head: HTTPResponseHead, + body: [ByteBuffer], + channel: EmbeddedChannel, + writeStrategy: WriteStrategy = .once + ) throws -> (HTTPResponseHead, [ByteBuffer]) { switch writeStrategy { case .once: try writeOneChunk(head: head, body: body, channel: channel) @@ -225,10 +241,12 @@ class HTTPResponseCompressorTest: XCTestCase { return (head!, dataChunks) } - private func assertDecompressedResponseMatches(responseData: inout ByteBuffer, - expectedResponse: ByteBuffer, - allocator: ByteBufferAllocator, - decompressor: (inout ByteBuffer, inout ByteBuffer) -> Void) { + private func assertDecompressedResponseMatches( + responseData: inout ByteBuffer, + expectedResponse: ByteBuffer, + allocator: ByteBufferAllocator, + decompressor: (inout ByteBuffer, inout ByteBuffer) -> Void + ) { var outputBuffer = allocator.buffer(capacity: expectedResponse.readableBytes) decompressor(&responseData, &outputBuffer) XCTAssertEqual(expectedResponse, outputBuffer) @@ -255,10 +273,12 @@ class HTTPResponseCompressorTest: XCTestCase { bodyChunks.append(bodyBuffer.getSlice(at: index, length: 2)!) } - let data = try compressResponse(head: response, - body: bodyChunks, - channel: channel, - writeStrategy: writeStrategy) + let data = try compressResponse( + head: response, + body: bodyChunks, + channel: channel, + writeStrategy: writeStrategy + ) let compressedResponse = data.0 var compressedChunks = data.1 var compressedBody = compressedChunks[0].merge(compressedChunks[1...]) @@ -266,7 +286,10 @@ class HTTPResponseCompressorTest: XCTestCase { switch writeStrategy { case .once: - XCTAssertEqual(compressedResponse.headers[canonicalForm: "content-length"], ["\(compressedBody.readableBytes)"]) + XCTAssertEqual( + compressedResponse.headers[canonicalForm: "content-length"], + ["\(compressedBody.readableBytes)"] + ) XCTAssertEqual(compressedResponse.headers[canonicalForm: "transfer-encoding"], []) case .intermittentFlushes: XCTAssertEqual(compressedResponse.headers[canonicalForm: "content-length"], []) @@ -277,10 +300,12 @@ class HTTPResponseCompressorTest: XCTestCase { XCTAssertEqual(compressedResponse.headers, assertHeaders) } - assertDecompressedResponseMatches(responseData: &compressedBody, - expectedResponse: bodyBuffer, - allocator: channel.allocator, - decompressor: z_stream.decompressDeflate) + assertDecompressedResponseMatches( + responseData: &compressedBody, + expectedResponse: bodyBuffer, + allocator: channel.allocator, + decompressor: z_stream.decompressDeflate + ) } private func assertGzippedResponse( @@ -304,10 +329,12 @@ class HTTPResponseCompressorTest: XCTestCase { bodyChunks.append(bodyBuffer.getSlice(at: index, length: 2)!) } - let data = try compressResponse(head: response, - body: bodyChunks, - channel: channel, - writeStrategy: writeStrategy) + let data = try compressResponse( + head: response, + body: bodyChunks, + channel: channel, + writeStrategy: writeStrategy + ) let compressedResponse = data.0 var compressedChunks = data.1 var compressedBody = compressedChunks[0].merge(compressedChunks[1...]) @@ -315,7 +342,10 @@ class HTTPResponseCompressorTest: XCTestCase { switch writeStrategy { case .once: - XCTAssertEqual(compressedResponse.headers[canonicalForm: "content-length"], ["\(compressedBody.readableBytes)"]) + XCTAssertEqual( + compressedResponse.headers[canonicalForm: "content-length"], + ["\(compressedBody.readableBytes)"] + ) XCTAssertEqual(compressedResponse.headers[canonicalForm: "transfer-encoding"], []) case .intermittentFlushes: XCTAssertEqual(compressedResponse.headers[canonicalForm: "content-length"], []) @@ -326,10 +356,12 @@ class HTTPResponseCompressorTest: XCTestCase { XCTAssertEqual(compressedResponse.headers, assertHeaders) } - assertDecompressedResponseMatches(responseData: &compressedBody, - expectedResponse: bodyBuffer, - allocator: channel.allocator, - decompressor: z_stream.decompressGzip) + assertDecompressedResponseMatches( + responseData: &compressedBody, + expectedResponse: bodyBuffer, + allocator: channel.allocator, + decompressor: z_stream.decompressGzip + ) } private func assertUncompressedResponse( @@ -353,10 +385,12 @@ class HTTPResponseCompressorTest: XCTestCase { bodyChunks.append(bodyBuffer.getSlice(at: index, length: 2)!) } - let data = try compressResponse(head: response, - body: bodyChunks, - channel: channel, - writeStrategy: writeStrategy) + let data = try compressResponse( + head: response, + body: bodyChunks, + channel: channel, + writeStrategy: writeStrategy + ) let compressedResponse = data.0 var compressedChunks = data.1 let uncompressedBody = compressedChunks[0].merge(compressedChunks[1...]) @@ -386,7 +420,7 @@ class HTTPResponseCompressorTest: XCTestCase { try sendRequest(acceptEncoding: "deflate", channel: channel) try assertDeflatedResponse(channel: channel) } - + func testExplicitInitialByteBufferCapacity() throws { /// This test it to make sure there is no ambiguity choosing an initializer. let channel = try compressionChannel(compressor: HTTPResponseCompressor(initialByteBufferCapacity: 2048)) @@ -566,7 +600,7 @@ class HTTPResponseCompressorTest: XCTestCase { switch err { case HTTPResponseCompressor.CompressionError.uncompressedWritesPending: () - // ok + // ok default: XCTFail("\(err)") } @@ -674,9 +708,11 @@ class HTTPResponseCompressorTest: XCTestCase { try sendRequest(acceptEncoding: "deflate", channel: channel) try assertDeflatedResponse(channel: channel) - XCTAssertNoThrow(try channel.pipeline.context(handlerType: HTTPResponseCompressor.self).flatMap { context in - channel.pipeline.removeHandler(context: context) - }.wait()) + XCTAssertNoThrow( + try channel.pipeline.context(handlerType: HTTPResponseCompressor.self).flatMap { context in + channel.pipeline.removeHandler(context: context) + }.wait() + ) try sendRequest(acceptEncoding: "deflate", channel: channel) try assertUncompressedResponse(channel: channel) @@ -691,9 +727,11 @@ class HTTPResponseCompressorTest: XCTestCase { try sendRequest(acceptEncoding: "deflate;q=2.2, gzip;q=0.3", channel: channel) - let head = HTTPResponseHead(version: .init(major: 1, minor: 1), - status: .noContent, - headers: .init()) + let head = HTTPResponseHead( + version: .init(major: 1, minor: 1), + status: .noContent, + headers: .init() + ) try channel.writeOutbound(HTTPServerResponsePart.head(head)) try channel.writeOutbound(HTTPServerResponsePart.end(nil)) @@ -735,16 +773,16 @@ class HTTPResponseCompressorTest: XCTestCase { } } } - + func testConditionalCompressionEnabled() throws { let predicateWasCalled = expectation(description: "Predicate was called") let compressor = HTTPResponseCompressor { responseHeaders, isCompressionSupported in defer { predicateWasCalled.fulfill() } - XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json"]) + XCTAssertEqual(responseHeaders.headers, ["Content-Type": "json"]) XCTAssertEqual(isCompressionSupported, true) return .compressIfPossible } - + let channel = try compressionChannel(compressor: compressor) defer { XCTAssertNoThrow(try channel.finish()) @@ -753,66 +791,66 @@ class HTTPResponseCompressorTest: XCTestCase { try sendRequest(acceptEncoding: "deflate", channel: channel) try assertDeflatedResponse( channel: channel, - responseHeaders: ["Content-Type" : "json"], + responseHeaders: ["Content-Type": "json"], assertHeaders: [ - "Content-Type" : "json", - "Content-Encoding" : "deflate", - "Content-Length" : "23", + "Content-Type": "json", + "Content-Encoding": "deflate", + "Content-Length": "23", ] ) - + waitForExpectations(timeout: 0) } - + func testUnsupportedRequestConditionalCompressionEnabled() throws { let predicateWasCalled = expectation(description: "Predicate was called") let compressor = HTTPResponseCompressor { responseHeaders, isCompressionSupported in defer { predicateWasCalled.fulfill() } - XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json"]) + XCTAssertEqual(responseHeaders.headers, ["Content-Type": "json"]) XCTAssertEqual(isCompressionSupported, false) return .compressIfPossible } - + let channel = try compressionChannel(compressor: compressor) defer { XCTAssertNoThrow(try channel.finish()) } - + try sendRequest(acceptEncoding: nil, channel: channel) try assertUncompressedResponse( channel: channel, - responseHeaders: ["Content-Type" : "json"], + responseHeaders: ["Content-Type": "json"], assertHeaders: [ - "Content-Type" : "json", - "transfer-encoding" : "chunked", + "Content-Type": "json", + "transfer-encoding": "chunked", ] ) - + waitForExpectations(timeout: 0) } - + func testUnsupportedStatusConditionalCompressionEnabled() throws { let predicateWasCalled = expectation(description: "Predicate was called") let compressor = HTTPResponseCompressor { responseHeaders, isCompressionSupported in defer { predicateWasCalled.fulfill() } XCTAssertEqual(responseHeaders.status, .notModified) - XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json"]) + XCTAssertEqual(responseHeaders.headers, ["Content-Type": "json"]) XCTAssertEqual(isCompressionSupported, false) return .compressIfPossible } - + let channel = EmbeddedChannel() XCTAssertNoThrow(try channel.pipeline.addHandler(compressor).wait()) defer { XCTAssertNoThrow(try channel.finish()) } - + try sendRequest(acceptEncoding: "deflate", channel: channel) - + let head = HTTPResponseHead( version: .init(major: 1, minor: 1), status: .notModified, - headers: ["Content-Type" : "json"] + headers: ["Content-Type": "json"] ) try channel.writeOutbound(HTTPServerResponsePart.head(head)) try channel.writeOutbound(HTTPServerResponsePart.end(nil)) @@ -826,19 +864,19 @@ class HTTPResponseCompressorTest: XCTestCase { case .end: break } } - + waitForExpectations(timeout: 0) } - + func testConditionalCompressionDisabled() throws { let predicateWasCalled = expectation(description: "Predicate was called") let compressor = HTTPResponseCompressor { responseHeaders, isCompressionSupported in defer { predicateWasCalled.fulfill() } - XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json"]) + XCTAssertEqual(responseHeaders.headers, ["Content-Type": "json"]) XCTAssertEqual(isCompressionSupported, true) return .doNotCompress } - + let channel = try compressionChannel(compressor: compressor) defer { XCTAssertNoThrow(try channel.finish()) @@ -847,65 +885,65 @@ class HTTPResponseCompressorTest: XCTestCase { try sendRequest(acceptEncoding: "deflate", channel: channel) try assertUncompressedResponse( channel: channel, - responseHeaders: ["Content-Type" : "json"], + responseHeaders: ["Content-Type": "json"], assertHeaders: [ - "Content-Type" : "json", - "transfer-encoding" : "chunked", + "Content-Type": "json", + "transfer-encoding": "chunked", ] ) - + waitForExpectations(timeout: 0) } - + func testUnsupportedRequestConditionalCompressionDisabled() throws { let predicateWasCalled = expectation(description: "Predicate was called") let compressor = HTTPResponseCompressor { responseHeaders, isCompressionSupported in defer { predicateWasCalled.fulfill() } - XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json"]) + XCTAssertEqual(responseHeaders.headers, ["Content-Type": "json"]) XCTAssertEqual(isCompressionSupported, false) return .doNotCompress } - + let channel = try compressionChannel(compressor: compressor) defer { XCTAssertNoThrow(try channel.finish()) } - + try sendRequest(acceptEncoding: nil, channel: channel) try assertUncompressedResponse( channel: channel, - responseHeaders: ["Content-Type" : "json"], + responseHeaders: ["Content-Type": "json"], assertHeaders: [ - "Content-Type" : "json", - "transfer-encoding" : "chunked", + "Content-Type": "json", + "transfer-encoding": "chunked", ] ) - + waitForExpectations(timeout: 0) } - + func testUnsupportedStatusConditionalCompressionDisabled() throws { let predicateWasCalled = expectation(description: "Predicate was called") let compressor = HTTPResponseCompressor { responseHeaders, isCompressionSupported in defer { predicateWasCalled.fulfill() } XCTAssertEqual(responseHeaders.status, .notModified) - XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json"]) + XCTAssertEqual(responseHeaders.headers, ["Content-Type": "json"]) XCTAssertEqual(isCompressionSupported, false) return .doNotCompress } - + let channel = EmbeddedChannel() XCTAssertNoThrow(try channel.pipeline.addHandler(compressor).wait()) defer { XCTAssertNoThrow(try channel.finish()) } - + try sendRequest(acceptEncoding: "deflate", channel: channel) - + let head = HTTPResponseHead( version: .init(major: 1, minor: 1), status: .notModified, - headers: ["Content-Type" : "json"] + headers: ["Content-Type": "json"] ) try channel.writeOutbound(HTTPServerResponsePart.head(head)) try channel.writeOutbound(HTTPServerResponsePart.end(nil)) @@ -919,48 +957,51 @@ class HTTPResponseCompressorTest: XCTestCase { case .end: break } } - + waitForExpectations(timeout: 0) } - + func testConditionalCompressionModifiedHeaders() throws { let predicateWasCalled = expectation(description: "Predicate was called") predicateWasCalled.expectedFulfillmentCount = 2 let compressor = HTTPResponseCompressor { responseHeaders, isCompressionSupported in defer { predicateWasCalled.fulfill() } let isEnabled = responseHeaders.headers[canonicalForm: "x-compression"].first == "enable" - XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json", "X-Compression" : isEnabled ? "enable" : "disable"]) + XCTAssertEqual( + responseHeaders.headers, + ["Content-Type": "json", "X-Compression": isEnabled ? "enable" : "disable"] + ) responseHeaders.headers.remove(name: "X-Compression") XCTAssertEqual(isCompressionSupported, true) return isEnabled ? .compressIfPossible : .doNotCompress } - + let channel = try compressionChannel(compressor: compressor) defer { XCTAssertNoThrow(try channel.finish()) } - + try sendRequest(acceptEncoding: "deflate", channel: channel) try assertDeflatedResponse( channel: channel, - responseHeaders: ["Content-Type" : "json", "X-Compression" : "enable"], + responseHeaders: ["Content-Type": "json", "X-Compression": "enable"], assertHeaders: [ - "Content-Type" : "json", - "Content-Encoding" : "deflate", - "Content-Length" : "23", + "Content-Type": "json", + "Content-Encoding": "deflate", + "Content-Length": "23", ] ) - + try sendRequest(acceptEncoding: "deflate", channel: channel) try assertUncompressedResponse( channel: channel, - responseHeaders: ["Content-Type" : "json", "X-Compression" : "disable"], + responseHeaders: ["Content-Type": "json", "X-Compression": "disable"], assertHeaders: [ - "Content-Type" : "json", - "transfer-encoding" : "chunked", + "Content-Type": "json", + "transfer-encoding": "chunked", ] ) - + waitForExpectations(timeout: 0) } } @@ -978,17 +1019,17 @@ extension EventLoopFuture { } else { let lock = NIOLock() let group = DispatchGroup() - var fulfilled = false // protected by lock + var fulfilled = false // protected by lock group.enter() self.eventLoop.execute { - let isFulfilled = self.isFulfilled // This will now enter the above branch. + let isFulfilled = self.isFulfilled // This will now enter the above branch. lock.withLock { fulfilled = isFulfilled } group.leave() } - group.wait() // this is very nasty but this is for tests only, so... + group.wait() // this is very nasty but this is for tests only, so... return lock.withLock { fulfilled } } } diff --git a/Tests/NIOHTTPCompressionTests/HTTPResponseDecompressorTest.swift b/Tests/NIOHTTPCompressionTests/HTTPResponseDecompressorTest.swift index 6fc6c767..8a268502 100644 --- a/Tests/NIOHTTPCompressionTests/HTTPResponseDecompressorTest.swift +++ b/Tests/NIOHTTPCompressionTests/HTTPResponseDecompressorTest.swift @@ -12,10 +12,11 @@ // //===----------------------------------------------------------------------===// -import XCTest import CNIOExtrasZlib -@testable import NIOCore import NIOEmbedded +import XCTest + +@testable import NIOCore @testable import NIOHTTP1 @testable import NIOHTTPCompression @@ -25,32 +26,42 @@ class HTTPResponseDecompressorTest: XCTestCase { try channel.pipeline.addHandler(NIOHTTPResponseDecompressor(limit: .none)).wait() let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "13")]) - try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers))) + try channel.writeInbound( + HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)) + ) let body = ByteBuffer.of(bytes: [120, 156, 75, 76, 28, 5, 200, 0, 0, 248, 66, 103, 17]) XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.body(body))) } - + func testDecompressionLimitSizeWithContentLenghtHeaderSucceeds() { let channel = EmbeddedChannel() XCTAssertNoThrow(try channel.pipeline.addHandler(NIOHTTPResponseDecompressor(limit: .size(272))).wait()) let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "13")]) - - XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)))) + + XCTAssertNoThrow( + try channel.writeInbound( + HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)) + ) + ) // this compressed payload is 272 bytes long uncompressed let body = ByteBuffer.of(bytes: [120, 156, 75, 76, 28, 5, 200, 0, 0, 248, 66, 103, 17]) XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.body(body))) } - + func testDecompressionLimitSizeWithContentLenghtHeaderFails() { let channel = EmbeddedChannel() XCTAssertNoThrow(try channel.pipeline.addHandler(NIOHTTPResponseDecompressor(limit: .size(271))).wait()) let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "13")]) - - XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)))) + + XCTAssertNoThrow( + try channel.writeInbound( + HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)) + ) + ) // this compressed payload is 272 bytes long uncompressed let body = ByteBuffer.of(bytes: [120, 156, 75, 76, 28, 5, 200, 0, 0, 248, 66, 103, 17]) @@ -58,14 +69,18 @@ class HTTPResponseDecompressorTest: XCTestCase { XCTAssertEqual(error as? NIOHTTPDecompression.DecompressionError, .limit) } } - + func testDecompressionLimitSizeWithoutContentLenghtHeaderSucceeds() { let channel = EmbeddedChannel() XCTAssertNoThrow(try channel.pipeline.addHandler(NIOHTTPResponseDecompressor(limit: .size(272))).wait()) - + let headers = HTTPHeaders([("Content-Encoding", "deflate")]) - - XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)))) + + XCTAssertNoThrow( + try channel.writeInbound( + HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)) + ) + ) // this compressed payload is 272 bytes long uncompressed let body = ByteBuffer.of(bytes: [120, 156, 75, 76, 28, 5, 200, 0, 0, 248, 66, 103, 17]) @@ -75,10 +90,14 @@ class HTTPResponseDecompressorTest: XCTestCase { func testDecompressionLimitSizeWithoutContentLenghtHeaderFails() { let channel = EmbeddedChannel() XCTAssertNoThrow(try channel.pipeline.addHandler(NIOHTTPResponseDecompressor(limit: .size(271))).wait()) - + let headers = HTTPHeaders([("Content-Encoding", "deflate")]) - - XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)))) + + XCTAssertNoThrow( + try channel.writeInbound( + HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)) + ) + ) // this compressed payload is 272 bytes long uncompressed let body = ByteBuffer.of(bytes: [120, 156, 75, 76, 28, 5, 200, 0, 0, 248, 66, 103, 17]) @@ -92,21 +111,29 @@ class HTTPResponseDecompressorTest: XCTestCase { XCTAssertNoThrow(try channel.pipeline.addHandler(NIOHTTPResponseDecompressor(limit: .ratio(21))).wait()) let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "13")]) - - XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)))) + + XCTAssertNoThrow( + try channel.writeInbound( + HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)) + ) + ) // this compressed payload is 272 bytes long uncompressed let body = ByteBuffer.of(bytes: [120, 156, 75, 76, 28, 5, 200, 0, 0, 248, 66, 103, 17]) XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.body(body))) } - + func testDecompressionLimitRatioWithContentLenghtHeaderFails() { let channel = EmbeddedChannel() XCTAssertNoThrow(try channel.pipeline.addHandler(NIOHTTPResponseDecompressor(limit: .ratio(20))).wait()) let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "13")]) - - XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)))) + + XCTAssertNoThrow( + try channel.writeInbound( + HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)) + ) + ) // this compressed payload is 272 bytes long uncompressed let body = ByteBuffer.of(bytes: [120, 156, 75, 76, 28, 5, 200, 0, 0, 248, 66, 103, 17]) @@ -114,14 +141,18 @@ class HTTPResponseDecompressorTest: XCTestCase { XCTAssertEqual(error as? NIOHTTPDecompression.DecompressionError, .limit) } } - + func testDecompressionLimitRatioWithoutContentLenghtHeaderSucceeds() { let channel = EmbeddedChannel() XCTAssertNoThrow(try channel.pipeline.addHandler(NIOHTTPResponseDecompressor(limit: .ratio(21))).wait()) - + let headers = HTTPHeaders([("Content-Encoding", "deflate")]) - - XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)))) + + XCTAssertNoThrow( + try channel.writeInbound( + HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)) + ) + ) // this compressed payload is 272 bytes long uncompressed let body = ByteBuffer.of(bytes: [120, 156, 75, 76, 28, 5, 200, 0, 0, 248, 66, 103, 17]) @@ -131,10 +162,14 @@ class HTTPResponseDecompressorTest: XCTestCase { func testDecompressionLimitRatioWithoutContentLenghtHeaderFails() { let channel = EmbeddedChannel() XCTAssertNoThrow(try channel.pipeline.addHandler(NIOHTTPResponseDecompressor(limit: .ratio(20))).wait()) - + let headers = HTTPHeaders([("Content-Encoding", "deflate")]) - - XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)))) + + XCTAssertNoThrow( + try channel.writeInbound( + HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)) + ) + ) // this compressed payload is 272 bytes long uncompressed let body = ByteBuffer.of(bytes: [120, 156, 75, 76, 28, 5, 200, 0, 0, 248, 66, 103, 17]) @@ -152,7 +187,14 @@ class HTTPResponseDecompressorTest: XCTestCase { let body = ByteBuffer.of(bytes: [120, 156, 75, 76, 28, 5, 200, 0, 0, 248, 66, 103, 17]) for i in 0..<3 { - XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers))), "\(i)") + XCTAssertNoThrow( + try channel.writeInbound( + HTTPClientResponsePart.head( + .init(version: .init(major: 1, minor: 1), status: .ok, headers: headers) + ) + ), + "\(i)" + ) XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.body(body)), "\(i)") XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.end(nil)), "\(i)") } @@ -164,7 +206,8 @@ class HTTPResponseDecompressorTest: XCTestCase { var body = "" for _ in 1...1000 { - body += "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." + body += + "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." } let algorithms: [(actual: String, announced: String)?] = [ nil, @@ -185,14 +228,23 @@ class HTTPResponseDecompressorTest: XCTestCase { } headers.add(name: "Content-Length", value: "\(compressed.readableBytes)") - XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)))) + XCTAssertNoThrow( + try channel.writeInbound( + HTTPClientResponsePart.head( + .init(version: .init(major: 1, minor: 1), status: .ok, headers: headers) + ) + ) + ) XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.body(compressed))) XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.end(nil))) - + var head: HTTPClientResponsePart? XCTAssertNoThrow(head = try channel.readInbound(as: HTTPClientResponsePart.self)) - XCTAssertEqual(head, HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers))) - + XCTAssertEqual( + head, + HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)) + ) + // the response is chunked var next: HTTPClientResponsePart? XCTAssertNoThrow(next = try channel.readInbound(as: HTTPClientResponsePart.self)) @@ -211,16 +263,17 @@ class HTTPResponseDecompressorTest: XCTestCase { XCTAssertEqual(buffer, ByteBuffer.of(string: body)) } } - + func testDecompressionWithoutContentLength() { let channel = EmbeddedChannel() XCTAssertNoThrow(try channel.pipeline.addHandler(NIOHTTPResponseDecompressor(limit: .none)).wait()) var body = "" for _ in 1...1000 { - body += "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." + body += + "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." } - + let algorithms: [(actual: String, announced: String)?] = [ nil, (actual: "gzip", announced: "gzip"), @@ -239,14 +292,23 @@ class HTTPResponseDecompressorTest: XCTestCase { compressed = ByteBuffer.of(string: body) } - XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)))) + XCTAssertNoThrow( + try channel.writeInbound( + HTTPClientResponsePart.head( + .init(version: .init(major: 1, minor: 1), status: .ok, headers: headers) + ) + ) + ) XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.body(compressed))) XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.end(nil))) - + var head: HTTPClientResponsePart? XCTAssertNoThrow(head = try channel.readInbound(as: HTTPClientResponsePart.self)) - XCTAssertEqual(head, HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers))) - + XCTAssertEqual( + head, + HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)) + ) + // the response is chunked var next: HTTPClientResponsePart? XCTAssertNoThrow(next = try channel.readInbound(as: HTTPClientResponsePart.self)) @@ -276,7 +338,9 @@ class HTTPResponseDecompressorTest: XCTestCase { let channel = EmbeddedChannel() try channel.pipeline.addHandler(NIOHTTPResponseDecompressor(limit: .none)).wait() let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "\(compressed.readableBytes)")]) - try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers))) + try channel.writeInbound( + HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)) + ) XCTAssertThrowsError(try channel.writeInbound(HTTPClientResponsePart.body(compressed))) } @@ -288,7 +352,9 @@ class HTTPResponseDecompressorTest: XCTestCase { let channel = EmbeddedChannel() try channel.pipeline.addHandler(NIOHTTPResponseDecompressor(limit: .none)).wait() let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "\(compressed.readableBytes)")]) - try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers))) + try channel.writeInbound( + HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)) + ) XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.body(compressed))) XCTAssertThrowsError(try channel.writeInbound(HTTPClientResponsePart.end(nil))) @@ -314,7 +380,14 @@ class HTTPResponseDecompressorTest: XCTestCase { return buffer } - let rc = CNIOExtrasZlib_deflateInit2(&stream, Z_DEFAULT_COMPRESSION, Z_DEFLATED, windowBits, 8, Z_DEFAULT_STRATEGY) + let rc = CNIOExtrasZlib_deflateInit2( + &stream, + Z_DEFAULT_COMPRESSION, + Z_DEFLATED, + windowBits, + 8, + Z_DEFAULT_STRATEGY + ) XCTAssertEqual(Z_OK, rc) defer { @@ -328,15 +401,19 @@ class HTTPResponseDecompressorTest: XCTestCase { body.readWithUnsafeMutableReadableBytes { dataPtr in let typedPtr = dataPtr.baseAddress!.assumingMemoryBound(to: UInt8.self) - let typedDataPtr = UnsafeMutableBufferPointer(start: typedPtr, - count: dataPtr.count) + let typedDataPtr = UnsafeMutableBufferPointer( + start: typedPtr, + count: dataPtr.count + ) stream.avail_in = UInt32(typedDataPtr.count) stream.next_in = typedDataPtr.baseAddress! buffer.writeWithUnsafeMutableBytes(minimumWritableBytes: 0) { outputPtr in - let typedOutputPtr = UnsafeMutableBufferPointer(start: outputPtr.baseAddress!.assumingMemoryBound(to: UInt8.self), - count: outputPtr.count) + let typedOutputPtr = UnsafeMutableBufferPointer( + start: outputPtr.baseAddress!.assumingMemoryBound(to: UInt8.self), + count: outputPtr.count + ) stream.avail_out = UInt32(typedOutputPtr.count) stream.next_out = typedOutputPtr.baseAddress! let rc = deflate(&stream, Z_FINISH) diff --git a/Tests/NIOHTTPTypesHTTP1Tests/NIOHTTPTypesHTTP1Tests.swift b/Tests/NIOHTTPTypesHTTP1Tests/NIOHTTPTypesHTTP1Tests.swift index 1baec754..e4d86531 100644 --- a/Tests/NIOHTTPTypesHTTP1Tests/NIOHTTPTypesHTTP1Tests.swift +++ b/Tests/NIOHTTPTypesHTTP1Tests/NIOHTTPTypesHTTP1Tests.swift @@ -48,41 +48,65 @@ final class NIOHTTPTypesHTTP1Tests: XCTestCase { super.tearDown() } - static let request = HTTPRequest(method: .get, scheme: "https", authority: "www.example.com", path: "/", headerFields: [ - .accept: "*/*", - .acceptEncoding: "gzip", - .acceptEncoding: "br", - .cookie: "a=b", - .cookie: "c=d", - .trailer: "X-Foo", - ]) - - static let requestNoSplitCookie = HTTPRequest(method: .get, scheme: "https", authority: "www.example.com", path: "/", headerFields: [ - .accept: "*/*", - .acceptEncoding: "gzip", - .acceptEncoding: "br", - .cookie: "a=b; c=d", - .trailer: "X-Foo", - ]) - - static let oldRequest = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/", headers: [ - "Host": "www.example.com", - "Accept": "*/*", - "Accept-Encoding": "gzip", - "Accept-Encoding": "br", - "Cookie": "a=b; c=d", - "Trailer": "X-Foo", - ]) - - static let response = HTTPResponse(status: .ok, headerFields: [ - .server: "HTTPServer/1.0", - .trailer: "X-Foo", - ]) - - static let oldResponse = HTTPResponseHead(version: .http1_1, status: .ok, headers: [ - "Server": "HTTPServer/1.0", - "Trailer": "X-Foo", - ]) + static let request = HTTPRequest( + method: .get, + scheme: "https", + authority: "www.example.com", + path: "/", + headerFields: [ + .accept: "*/*", + .acceptEncoding: "gzip", + .acceptEncoding: "br", + .cookie: "a=b", + .cookie: "c=d", + .trailer: "X-Foo", + ] + ) + + static let requestNoSplitCookie = HTTPRequest( + method: .get, + scheme: "https", + authority: "www.example.com", + path: "/", + headerFields: [ + .accept: "*/*", + .acceptEncoding: "gzip", + .acceptEncoding: "br", + .cookie: "a=b; c=d", + .trailer: "X-Foo", + ] + ) + + static let oldRequest = HTTPRequestHead( + version: .http1_1, + method: .GET, + uri: "/", + headers: [ + "Host": "www.example.com", + "Accept": "*/*", + "Accept-Encoding": "gzip", + "Accept-Encoding": "br", + "Cookie": "a=b; c=d", + "Trailer": "X-Foo", + ] + ) + + static let response = HTTPResponse( + status: .ok, + headerFields: [ + .server: "HTTPServer/1.0", + .trailer: "X-Foo", + ] + ) + + static let oldResponse = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: [ + "Server": "HTTPServer/1.0", + "Trailer": "X-Foo", + ] + ) static let trailers: HTTPFields = [.xFoo: "Bar"] diff --git a/Tests/NIOHTTPTypesHTTP2Tests/NIOHTTPTypesHTTP2Tests.swift b/Tests/NIOHTTPTypesHTTP2Tests/NIOHTTPTypesHTTP2Tests.swift index 830768d7..769b7255 100644 --- a/Tests/NIOHTTPTypesHTTP2Tests/NIOHTTPTypesHTTP2Tests.swift +++ b/Tests/NIOHTTPTypesHTTP2Tests/NIOHTTPTypesHTTP2Tests.swift @@ -63,14 +63,20 @@ final class NIOHTTPTypesHTTP2Tests: XCTestCase { super.tearDown() } - static let request = HTTPRequest(method: .get, scheme: "https", authority: "www.example.com", path: "/", headerFields: [ - .accept: "*/*", - .acceptEncoding: "gzip", - .acceptEncoding: "br", - .trailer: "X-Foo", - .cookie: "a=b", - .cookie: "c=d", - ]) + static let request = HTTPRequest( + method: .get, + scheme: "https", + authority: "www.example.com", + path: "/", + headerFields: [ + .accept: "*/*", + .acceptEncoding: "gzip", + .acceptEncoding: "br", + .trailer: "X-Foo", + .cookie: "a=b", + .cookie: "c=d", + ] + ) static let oldRequest: HPACKHeaders = [ ":method": "GET", @@ -85,10 +91,13 @@ final class NIOHTTPTypesHTTP2Tests: XCTestCase { "cookie": "c=d", ] - static let response = HTTPResponse(status: .ok, headerFields: [ - .server: "HTTPServer/1.0", - .trailer: "X-Foo", - ]) + static let response = HTTPResponse( + status: .ok, + headerFields: [ + .server: "HTTPServer/1.0", + .trailer: "X-Foo", + ] + ) static let oldResponse: HPACKHeaders = [ ":status": "200", diff --git a/Tests/NIONFS3Tests/NFS3FileSystemTests.swift b/Tests/NIONFS3Tests/NFS3FileSystemTests.swift index 2f95200a..55cf29ad 100644 --- a/Tests/NIONFS3Tests/NFS3FileSystemTests.swift +++ b/Tests/NIONFS3Tests/NFS3FileSystemTests.swift @@ -25,11 +25,23 @@ final class NFS3FileSystemTests: XCTestCase { } func readdirplus(_ call: NFS3CallReadDirPlus, promise: EventLoopPromise) { - promise.succeed(.init(result: .okay(.init(cookieVerifier: .init(rawValue: 11111), - entries: [.init(fileID: .init(rawValue: 22222), - fileName: "file", - cookie: .init(rawValue: 33333))], - eof: true)))) + promise.succeed( + .init( + result: .okay( + .init( + cookieVerifier: .init(rawValue: 11111), + entries: [ + .init( + fileID: .init(rawValue: 22222), + fileName: "file", + cookie: .init(rawValue: 33333) + ) + ], + eof: true + ) + ) + ) + ) } func mount(_ call: MountCallMount, promise: EventLoopPromise) { @@ -83,17 +95,31 @@ final class NFS3FileSystemTests: XCTestCase { } let fs = MyOnlyReadDirPlusFS() let promise = eventLoop.makePromise(of: NFS3ReplyReadDir.self) - fs.readdir(.init(fileHandle: .init(123), - cookie: .init(rawValue: 234), - cookieVerifier: .init(rawValue: 345), - maxResultByteCount: .init(rawValue: 456)), - promise: promise) + fs.readdir( + .init( + fileHandle: .init(123), + cookie: .init(rawValue: 234), + cookieVerifier: .init(rawValue: 345), + maxResultByteCount: .init(rawValue: 456) + ), + promise: promise + ) let actualResult = try promise.futureResult.wait() - let expectedResult = NFS3ReplyReadDir(result: .okay(.init(cookieVerifier: .init(rawValue: 11111), - entries: [.init(fileID: .init(rawValue: 22222), - fileName: "file", - cookie: .init(rawValue: 33333))], - eof: true))) + let expectedResult = NFS3ReplyReadDir( + result: .okay( + .init( + cookieVerifier: .init(rawValue: 11111), + entries: [ + .init( + fileID: .init(rawValue: 22222), + fileName: "file", + cookie: .init(rawValue: 33333) + ) + ], + eof: true + ) + ) + ) XCTAssertEqual(expectedResult, actualResult) } } diff --git a/Tests/NIONFS3Tests/NFS3ReplyEncoderTest.swift b/Tests/NIONFS3Tests/NFS3ReplyEncoderTest.swift index ed9c12c2..dff5784e 100644 --- a/Tests/NIONFS3Tests/NFS3ReplyEncoderTest.swift +++ b/Tests/NIONFS3Tests/NFS3ReplyEncoderTest.swift @@ -12,9 +12,9 @@ // //===----------------------------------------------------------------------===// -import XCTest -import NIONFS3 import NIOCore +import NIONFS3 +import XCTest final class NFS3ReplyEncoderTest: XCTestCase { func testPartialReadEncoding() { @@ -22,14 +22,32 @@ final class NFS3ReplyEncoderTest: XCTestCase { let expectedPayload = ByteBuffer(repeating: UInt8(ascii: "j"), count: payloadLength) let expectedFillBytes = (4 - (payloadLength % 4)) % 4 - let reply = RPCNFS3Reply(rpcReply: RPCReply(xid: 12345, - status: .messageAccepted(.init(verifier: .init(flavor: .noAuth, - opaque: nil), - status: .success))), - nfsReply: .read(.init(result: .okay(.init(attributes: nil, - count: .init(rawValue: 7), - eof: false, - data: expectedPayload))))) + let reply = RPCNFS3Reply( + rpcReply: RPCReply( + xid: 12345, + status: .messageAccepted( + .init( + verifier: .init( + flavor: .noAuth, + opaque: nil + ), + status: .success + ) + ) + ), + nfsReply: .read( + .init( + result: .okay( + .init( + attributes: nil, + count: .init(rawValue: 7), + eof: false, + data: expectedPayload + ) + ) + ) + ) + ) var partialSerialisation = ByteBuffer() let (bytesWritten, nextStep) = partialSerialisation.writeRPCNFS3ReplyPartially(reply) @@ -47,16 +65,24 @@ final class NFS3ReplyEncoderTest: XCTestCase { XCTAssertEqual(bytesWruttenFull, fullSerialisation.readableBytes) XCTAssert(fullSerialisation.readableBytesView.starts(with: partialSerialisation.readableBytesView)) - XCTAssert(fullSerialisation.readableBytesView - .dropFirst(partialSerialisation.readableBytes) - .prefix(expectedPayload.readableBytes) - .elementsEqual(expectedPayload.readableBytesView)) + XCTAssert( + fullSerialisation.readableBytesView + .dropFirst(partialSerialisation.readableBytes) + .prefix(expectedPayload.readableBytes) + .elementsEqual(expectedPayload.readableBytesView) + ) - XCTAssertEqual(partialSerialisation.readableBytes + payloadLength + expectedFillBytes, - fullSerialisation.readableBytes) - XCTAssertEqual(UInt32(payloadLength), - partialSerialisation.getInteger(at: partialSerialisation.writerIndex - 4, - as: UInt32.self)) + XCTAssertEqual( + partialSerialisation.readableBytes + payloadLength + expectedFillBytes, + fullSerialisation.readableBytes + ) + XCTAssertEqual( + UInt32(payloadLength), + partialSerialisation.getInteger( + at: partialSerialisation.writerIndex - 4, + as: UInt32.self + ) + ) } } @@ -64,14 +90,32 @@ final class NFS3ReplyEncoderTest: XCTestCase { for payloadLength in 0..<1 { let expectedPayload = ByteBuffer(repeating: UInt8(ascii: "j"), count: payloadLength) - let expectedReply = RPCNFS3Reply(rpcReply: RPCReply(xid: 12345, - status: .messageAccepted(.init(verifier: .init(flavor: .noAuth, - opaque: nil), - status: .success))), - nfsReply: .read(.init(result: .okay(.init(attributes: nil, - count: .init(rawValue: 7), - eof: false, - data: expectedPayload))))) + let expectedReply = RPCNFS3Reply( + rpcReply: RPCReply( + xid: 12345, + status: .messageAccepted( + .init( + verifier: .init( + flavor: .noAuth, + opaque: nil + ), + status: .success + ) + ) + ), + nfsReply: .read( + .init( + result: .okay( + .init( + attributes: nil, + count: .init(rawValue: 7), + eof: false, + data: expectedPayload + ) + ) + ) + ) + ) var fullSerialisation = ByteBuffer() let bytesWrittenFull = fullSerialisation.writeRPCNFS3Reply(expectedReply) @@ -84,9 +128,11 @@ final class NFS3ReplyEncoderTest: XCTestCase { var actualNFS3Reply: NFS3ReplyRead? = nil XCTAssertNoThrow(actualNFS3Reply = try actualReply.1.readNFS3ReplyRead()) XCTAssertEqual(0, actualReply.1.readableBytes) - XCTAssertEqual(expectedReply.nfsReply, - actualNFS3Reply.map { NFS3Reply.read($0) }, - "parsing failed for payload length \(payloadLength)") + XCTAssertEqual( + expectedReply.nfsReply, + actualNFS3Reply.map { NFS3Reply.read($0) }, + "parsing failed for payload length \(payloadLength)" + ) } } } diff --git a/Tests/NIONFS3Tests/NFS3RoundtripTests.swift b/Tests/NIONFS3Tests/NFS3RoundtripTests.swift index e78e9fa6..7ffc9d53 100644 --- a/Tests/NIONFS3Tests/NFS3RoundtripTests.swift +++ b/Tests/NIONFS3Tests/NFS3RoundtripTests.swift @@ -13,9 +13,9 @@ //===----------------------------------------------------------------------===// import NIOCore +import NIONFS3 import NIOTestUtils import XCTest -import NIONFS3 final class NFS3RoundtripTests: XCTestCase { func testRegularCallsRoundtrip() { @@ -31,17 +31,33 @@ final class NFS3RoundtripTests: XCTestCase { let nullCall1 = NFS3Call.null(.init()) let pathConfCall1 = NFS3Call.pathconf(.init(object: NFS3FileHandle(#line))) let readCall1 = NFS3Call.read(.init(fileHandle: NFS3FileHandle(#line), offset: 123, count: 456)) - let readDirPlusCall1 = NFS3Call.readdirplus(.init(fileHandle: NFS3FileHandle(#line), cookie: 345, - cookieVerifier: 879, dirCount: 23488, maxCount: 2342888)) - let readDirCall1 = NFS3Call.readdir(.init(fileHandle: NFS3FileHandle(#line), cookie: 345, cookieVerifier: 879, maxResultByteCount: 234797)) + let readDirPlusCall1 = NFS3Call.readdirplus( + .init( + fileHandle: NFS3FileHandle(#line), + cookie: 345, + cookieVerifier: 879, + dirCount: 23488, + maxCount: 2_342_888 + ) + ) + let readDirCall1 = NFS3Call.readdir( + .init(fileHandle: NFS3FileHandle(#line), cookie: 345, cookieVerifier: 879, maxResultByteCount: 234797) + ) let readlinkCall1 = NFS3Call.readlink(.init(symlink: NFS3FileHandle(#line))) - let setattrCall1 = NFS3Call.setattr(.init(object: NFS3FileHandle(#line), - newAttributes: .init(mode: 0o146, - uid: 1, gid: 2, - size: 3, - atime: .init(seconds: 4, nanoseconds: 5), - mtime: .init(seconds: 6, nanoseconds: 7)), - guard: .init(seconds: 8, nanoseconds: 0))) + let setattrCall1 = NFS3Call.setattr( + .init( + object: NFS3FileHandle(#line), + newAttributes: .init( + mode: 0o146, + uid: 1, + gid: 2, + size: 3, + atime: .init(seconds: 4, nanoseconds: 5), + mtime: .init(seconds: 6, nanoseconds: 7) + ), + guard: .init(seconds: 8, nanoseconds: 0) + ) + ) var xid: UInt32 = 0 func makeInputOutputPair(_ nfsCall: NFS3Call) -> (ByteBuffer, [RPCNFS3Call]) { @@ -55,55 +71,82 @@ final class NFS3RoundtripTests: XCTestCase { return (buffer, [rpcNFS3Call]) } - XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder(inputOutputPairs: [ - makeInputOutputPair(mountCall1), - makeInputOutputPair(mountCall2), - makeInputOutputPair(unmountCall1), - makeInputOutputPair(accessCall1), - makeInputOutputPair(fsInfoCall1), - makeInputOutputPair(fsStatCall1), - makeInputOutputPair(getattrCall1), - makeInputOutputPair(lookupCall1), - makeInputOutputPair(nullCall1), - makeInputOutputPair(mountCallNull), - makeInputOutputPair(pathConfCall1), - makeInputOutputPair(readCall1), - makeInputOutputPair(readDirCall1), - makeInputOutputPair(readDirPlusCall1), - makeInputOutputPair(readlinkCall1), - makeInputOutputPair(setattrCall1), - ], - decoderFactory: { NFS3CallDecoder() })) + XCTAssertNoThrow( + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [ + makeInputOutputPair(mountCall1), + makeInputOutputPair(mountCall2), + makeInputOutputPair(unmountCall1), + makeInputOutputPair(accessCall1), + makeInputOutputPair(fsInfoCall1), + makeInputOutputPair(fsStatCall1), + makeInputOutputPair(getattrCall1), + makeInputOutputPair(lookupCall1), + makeInputOutputPair(nullCall1), + makeInputOutputPair(mountCallNull), + makeInputOutputPair(pathConfCall1), + makeInputOutputPair(readCall1), + makeInputOutputPair(readDirCall1), + makeInputOutputPair(readDirPlusCall1), + makeInputOutputPair(readlinkCall1), + makeInputOutputPair(setattrCall1), + ], + decoderFactory: { NFS3CallDecoder() } + ) + ) } func testCallsWithMaxIntegersRoundtrip() { - let accessCall1 = NFS3Call.access(NFS3CallAccess(object: NFS3FileHandle(.max), - access: NFS3Access(rawValue: .max))) + let accessCall1 = NFS3Call.access( + NFS3CallAccess( + object: NFS3FileHandle(.max), + access: NFS3Access(rawValue: .max) + ) + ) let fsInfoCall1 = NFS3Call.fsinfo(.init(fsroot: NFS3FileHandle(.max))) let fsStatCall1 = NFS3Call.fsstat(.init(fsroot: NFS3FileHandle(.max))) let getattrCall1 = NFS3Call.getattr(.init(fileHandle: NFS3FileHandle(.max))) let lookupCall1 = NFS3Call.lookup(.init(dir: NFS3FileHandle(.max), name: "⚠️")) let pathConfCall1 = NFS3Call.pathconf(.init(object: NFS3FileHandle(.max))) - let readCall1 = NFS3Call.read(.init(fileHandle: NFS3FileHandle(.max), - offset: .init(rawValue: .max), count: .init(rawValue: .max))) - let readDirPlusCall1 = NFS3Call.readdirplus(.init(fileHandle: NFS3FileHandle(.max), - cookie: .init(rawValue: .max), - cookieVerifier: .init(rawValue: .max), - dirCount: .init(rawValue: .max), - maxCount: .init(rawValue: .max))) - let readDirCall1 = NFS3Call.readdir(.init(fileHandle: NFS3FileHandle(.max), - cookie: .init(rawValue: .max), - cookieVerifier: .init(rawValue: .max), - maxResultByteCount: .init(rawValue: .max))) + let readCall1 = NFS3Call.read( + .init( + fileHandle: NFS3FileHandle(.max), + offset: .init(rawValue: .max), + count: .init(rawValue: .max) + ) + ) + let readDirPlusCall1 = NFS3Call.readdirplus( + .init( + fileHandle: NFS3FileHandle(.max), + cookie: .init(rawValue: .max), + cookieVerifier: .init(rawValue: .max), + dirCount: .init(rawValue: .max), + maxCount: .init(rawValue: .max) + ) + ) + let readDirCall1 = NFS3Call.readdir( + .init( + fileHandle: NFS3FileHandle(.max), + cookie: .init(rawValue: .max), + cookieVerifier: .init(rawValue: .max), + maxResultByteCount: .init(rawValue: .max) + ) + ) let readlinkCall1 = NFS3Call.readlink(.init(symlink: NFS3FileHandle(.max))) - let setattrCall1 = NFS3Call.setattr(.init(object: NFS3FileHandle(.max), - newAttributes: .init(mode: .init(rawValue: .max), - uid: .init(rawValue: .max), - gid: .init(rawValue: .max), - size: .init(rawValue: .max), - atime: .init(seconds: .max, nanoseconds: .max), - mtime: .init(seconds: .max, nanoseconds: .max)), - guard: .init(seconds: .max, nanoseconds: .max))) + let setattrCall1 = NFS3Call.setattr( + .init( + object: NFS3FileHandle(.max), + newAttributes: .init( + mode: .init(rawValue: .max), + uid: .init(rawValue: .max), + gid: .init(rawValue: .max), + size: .init(rawValue: .max), + atime: .init(seconds: .max, nanoseconds: .max), + mtime: .init(seconds: .max, nanoseconds: .max) + ), + guard: .init(seconds: .max, nanoseconds: .max) + ) + ) var xid: UInt32 = 0 func makeInputOutputPair(_ nfsCall: NFS3Call) -> (ByteBuffer, [RPCNFS3Call]) { @@ -117,48 +160,66 @@ final class NFS3RoundtripTests: XCTestCase { return (buffer, [rpcNFS3Call]) } - XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder(inputOutputPairs: [ - makeInputOutputPair(accessCall1), - makeInputOutputPair(fsInfoCall1), - makeInputOutputPair(fsStatCall1), - makeInputOutputPair(getattrCall1), - makeInputOutputPair(lookupCall1), - makeInputOutputPair(pathConfCall1), - makeInputOutputPair(readCall1), - makeInputOutputPair(readDirPlusCall1), - makeInputOutputPair(readDirCall1), - makeInputOutputPair(readlinkCall1), - makeInputOutputPair(setattrCall1), - ], - decoderFactory: { NFS3CallDecoder() })) + XCTAssertNoThrow( + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [ + makeInputOutputPair(accessCall1), + makeInputOutputPair(fsInfoCall1), + makeInputOutputPair(fsStatCall1), + makeInputOutputPair(getattrCall1), + makeInputOutputPair(lookupCall1), + makeInputOutputPair(pathConfCall1), + makeInputOutputPair(readCall1), + makeInputOutputPair(readDirPlusCall1), + makeInputOutputPair(readDirCall1), + makeInputOutputPair(readlinkCall1), + makeInputOutputPair(setattrCall1), + ], + decoderFactory: { NFS3CallDecoder() } + ) + ) } func testRegularOkayRepliesRoundtrip() { func makeRandomFileAttr() -> NFS3FileAttr { - return .init(type: .init(rawValue: .random(in: 1 ... 7))!, - mode: .init(rawValue: .random(in: 0o000 ... 0o777)), - nlink: .random(in: .min ... .max), - uid: .init(rawValue: .random(in: .min ... .max)), - gid: .init(rawValue: .random(in: .min ... .max)), - size: .init(rawValue: .random(in: .min ... .max)), - used: .init(rawValue: .random(in: .min ... .max)), - rdev: .init(rawValue: .random(in: .min ... .max)), - fsid: .random(in: .min ... .max), - fileid: .init(rawValue: .random(in: .min ... .max)), - atime: .init(seconds: .random(in: .min ... .max), - nanoseconds: .random(in: .min ... .max)), - mtime: .init(seconds: .random(in: .min ... .max), - nanoseconds: .random(in: .min ... .max)), - ctime: .init(seconds: .random(in: .min ... .max), - nanoseconds: .random(in: .min ... .max))) + .init( + type: .init(rawValue: .random(in: 1...7))!, + mode: .init(rawValue: .random(in: 0o000...0o777)), + nlink: .random(in: .min ... .max), + uid: .init(rawValue: .random(in: .min ... .max)), + gid: .init(rawValue: .random(in: .min ... .max)), + size: .init(rawValue: .random(in: .min ... .max)), + used: .init(rawValue: .random(in: .min ... .max)), + rdev: .init(rawValue: .random(in: .min ... .max)), + fsid: .random(in: .min ... .max), + fileid: .init(rawValue: .random(in: .min ... .max)), + atime: .init( + seconds: .random(in: .min ... .max), + nanoseconds: .random(in: .min ... .max) + ), + mtime: .init( + seconds: .random(in: .min ... .max), + nanoseconds: .random(in: .min ... .max) + ), + ctime: .init( + seconds: .random(in: .min ... .max), + nanoseconds: .random(in: .min ... .max) + ) + ) } let mountNullReply1 = NFS3Reply.mountNull let mountReply1 = NFS3Reply.mount(MountReplyMount(result: .okay(.init(fileHandle: NFS3FileHandle(#line))))) let mountReply2 = NFS3Reply.mount(.init(result: .okay(.init(fileHandle: NFS3FileHandle(#line))))) let unmountReply1 = NFS3Reply.unmount(.init()) - let accessReply1 = NFS3Reply.access(.init(result: .okay(.init(dirAttributes: makeRandomFileAttr(), access: .allReadOnly)))) - let fsInfoReply1 = NFS3Reply.fsinfo(.init(result: - .okay(.init(attributes: makeRandomFileAttr(), + let accessReply1 = NFS3Reply.access( + .init(result: .okay(.init(dirAttributes: makeRandomFileAttr(), access: .allReadOnly))) + ) + let fsInfoReply1 = NFS3Reply.fsinfo( + .init( + result: + .okay( + .init( + attributes: makeRandomFileAttr(), rtmax: .random(in: .min ... .max), rtpref: .random(in: .min ... .max), rtmult: .random(in: .min ... .max), @@ -167,72 +228,167 @@ final class NFS3RoundtripTests: XCTestCase { wtmult: .random(in: .min ... .max), dtpref: .random(in: .min ... .max), maxFileSize: .init(rawValue: .random(in: .min ... .max)), - timeDelta: .init(seconds: .random(in: .min ... .max), - nanoseconds: .random(in: .min ... .max)), - properties: .init(rawValue: .random(in: .min ... .max)))))) - let fsStatReply1 = NFS3Reply.fsstat(.init(result: - .okay(.init(attributes: makeRandomFileAttr(), + timeDelta: .init( + seconds: .random(in: .min ... .max), + nanoseconds: .random(in: .min ... .max) + ), + properties: .init(rawValue: .random(in: .min ... .max)) + ) + ) + ) + ) + let fsStatReply1 = NFS3Reply.fsstat( + .init( + result: + .okay( + .init( + attributes: makeRandomFileAttr(), tbytes: .init(rawValue: .random(in: .min ... .max)), fbytes: .init(rawValue: .random(in: .min ... .max)), abytes: .init(rawValue: .random(in: .min ... .max)), tfiles: .init(rawValue: .random(in: .min ... .max)), ffiles: .init(rawValue: .random(in: .min ... .max)), afiles: .init(rawValue: .random(in: .min ... .max)), - invarsec: .random(in: .min ... .max))))) + invarsec: .random(in: .min ... .max) + ) + ) + ) + ) let getattrReply1 = NFS3Reply.getattr(.init(result: .okay(.init(attributes: makeRandomFileAttr())))) - let lookupReply1 = NFS3Reply.lookup(.init(result: - .okay(.init(fileHandle: NFS3FileHandle(.random(in: .min ... .max)), + let lookupReply1 = NFS3Reply.lookup( + .init( + result: + .okay( + .init( + fileHandle: NFS3FileHandle(.random(in: .min ... .max)), attributes: makeRandomFileAttr(), - dirAttributes: makeRandomFileAttr())))) + dirAttributes: makeRandomFileAttr() + ) + ) + ) + ) let nullReply1 = NFS3Reply.null - let pathConfReply1 = NFS3Reply.pathconf(.init(result: .okay(.init(attributes: makeRandomFileAttr(), - linkMax: .random(in: .min ... .max), - nameMax: .random(in: .min ... .max), - noTrunc: .random(), - chownRestricted: .random(), - caseInsensitive: .random(), - casePreserving: .random())))) - let readReply1 = NFS3Reply.read(.init(result: .okay(.init(attributes: makeRandomFileAttr(), - count: .init(rawValue: .random(in: .min ... .max)), - eof: .random(), - data: ByteBuffer(string: "abc"))))) - let readDirPlusReply1 = NFS3Reply.readdirplus(.init(result: - .okay(.init(dirAttributes: makeRandomFileAttr(), + let pathConfReply1 = NFS3Reply.pathconf( + .init( + result: .okay( + .init( + attributes: makeRandomFileAttr(), + linkMax: .random(in: .min ... .max), + nameMax: .random(in: .min ... .max), + noTrunc: .random(), + chownRestricted: .random(), + caseInsensitive: .random(), + casePreserving: .random() + ) + ) + ) + ) + let readReply1 = NFS3Reply.read( + .init( + result: .okay( + .init( + attributes: makeRandomFileAttr(), + count: .init(rawValue: .random(in: .min ... .max)), + eof: .random(), + data: ByteBuffer(string: "abc") + ) + ) + ) + ) + let readDirPlusReply1 = NFS3Reply.readdirplus( + .init( + result: + .okay( + .init( + dirAttributes: makeRandomFileAttr(), cookieVerifier: .init(rawValue: .random(in: .min ... .max)), - entries: [.init(fileID: .init(rawValue: .random(in: .min ... .max)), - fileName: "asd", - cookie: .init(rawValue: .random(in: .min ... .max)), - nameAttributes: makeRandomFileAttr(), - nameHandle: NFS3FileHandle(.random(in: .min ... .max)))], - eof: .random())))) - let readDirReply1 = NFS3Reply.readdir(.init(result: - .okay(.init(dirAttributes: makeRandomFileAttr(), + entries: [ + .init( + fileID: .init(rawValue: .random(in: .min ... .max)), + fileName: "asd", + cookie: .init(rawValue: .random(in: .min ... .max)), + nameAttributes: makeRandomFileAttr(), + nameHandle: NFS3FileHandle(.random(in: .min ... .max)) + ) + ], + eof: .random() + ) + ) + ) + ) + let readDirReply1 = NFS3Reply.readdir( + .init( + result: + .okay( + .init( + dirAttributes: makeRandomFileAttr(), cookieVerifier: .init(rawValue: .random(in: .min ... .max)), entries: [ - .init(fileID: .init(rawValue: .random(in: .min ... .max)), - fileName: "asd", - cookie: .init(rawValue: .random(in: .min ... .max)))], - eof: .random())))) - let readlinkReply1 = NFS3Reply.readlink(.init(result: .okay(.init(symlinkAttributes: makeRandomFileAttr(), - target: "he")))) - let setattrReply1 = NFS3Reply.setattr(.init(result: - .okay(.init(wcc: .init(before: .some(.init(size: .init(rawValue: .random(in: .min ... .max)), - mtime: .init(seconds: .random(in: .min ... .max), - nanoseconds: .random(in: .min ... .max)), - ctime: .init(seconds: .random(in: .min ... .max), - nanoseconds: .random(in: .min ... .max)))), - after: makeRandomFileAttr()))))) + .init( + fileID: .init(rawValue: .random(in: .min ... .max)), + fileName: "asd", + cookie: .init(rawValue: .random(in: .min ... .max)) + ) + ], + eof: .random() + ) + ) + ) + ) + let readlinkReply1 = NFS3Reply.readlink( + .init( + result: .okay( + .init( + symlinkAttributes: makeRandomFileAttr(), + target: "he" + ) + ) + ) + ) + let setattrReply1 = NFS3Reply.setattr( + .init( + result: + .okay( + .init( + wcc: .init( + before: .some( + .init( + size: .init(rawValue: .random(in: .min ... .max)), + mtime: .init( + seconds: .random(in: .min ... .max), + nanoseconds: .random(in: .min ... .max) + ), + ctime: .init( + seconds: .random(in: .min ... .max), + nanoseconds: .random(in: .min ... .max) + ) + ) + ), + after: makeRandomFileAttr() + ) + ) + ) + ) + ) var xid: UInt32 = 0 var prepopulatedProcs: [UInt32: RPCNFS3ProcedureID] = [:] func makeInputOutputPair(_ nfsReply: NFS3Reply) -> (ByteBuffer, [RPCNFS3Reply]) { var buffer = ByteBuffer() xid += 1 - let rpcNFS3Reply = RPCNFS3Reply(rpcReply: - .init(xid: xid, - status: .messageAccepted(.init(verifier: .init(flavor: .noAuth, opaque: nil), - status: .success))), - nfsReply: nfsReply) + let rpcNFS3Reply = RPCNFS3Reply( + rpcReply: + .init( + xid: xid, + status: .messageAccepted( + .init( + verifier: .init(flavor: .noAuth, opaque: nil), + status: .success + ) + ) + ), + nfsReply: nfsReply + ) prepopulatedProcs[xid] = .init(nfsReply) let oldReadableBytes = buffer.readableBytes let writtenBytes = buffer.writeRPCNFS3Reply(rpcNFS3Reply) @@ -241,26 +397,33 @@ final class NFS3RoundtripTests: XCTestCase { return (buffer, [rpcNFS3Reply]) } - XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( - inputOutputPairs: [ - makeInputOutputPair(mountNullReply1), - makeInputOutputPair(mountReply1), - makeInputOutputPair(mountReply2), - makeInputOutputPair(unmountReply1), - makeInputOutputPair(accessReply1), - makeInputOutputPair(fsInfoReply1), - makeInputOutputPair(fsStatReply1), - makeInputOutputPair(getattrReply1), - makeInputOutputPair(lookupReply1), - makeInputOutputPair(nullReply1), - makeInputOutputPair(pathConfReply1), - makeInputOutputPair(readReply1), - makeInputOutputPair(readDirPlusReply1), - makeInputOutputPair(readDirReply1), - makeInputOutputPair(readlinkReply1), - makeInputOutputPair(setattrReply1), - ], - decoderFactory: { NFS3ReplyDecoder(prepopulatedProcecedures: prepopulatedProcs, - allowDuplicateReplies: true) })) + XCTAssertNoThrow( + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [ + makeInputOutputPair(mountNullReply1), + makeInputOutputPair(mountReply1), + makeInputOutputPair(mountReply2), + makeInputOutputPair(unmountReply1), + makeInputOutputPair(accessReply1), + makeInputOutputPair(fsInfoReply1), + makeInputOutputPair(fsStatReply1), + makeInputOutputPair(getattrReply1), + makeInputOutputPair(lookupReply1), + makeInputOutputPair(nullReply1), + makeInputOutputPair(pathConfReply1), + makeInputOutputPair(readReply1), + makeInputOutputPair(readDirPlusReply1), + makeInputOutputPair(readDirReply1), + makeInputOutputPair(readlinkReply1), + makeInputOutputPair(setattrReply1), + ], + decoderFactory: { + NFS3ReplyDecoder( + prepopulatedProcecedures: prepopulatedProcs, + allowDuplicateReplies: true + ) + } + ) + ) } } diff --git a/Tests/NIOSOCKSTests/ClientGreeting+Tests.swift b/Tests/NIOSOCKSTests/ClientGreeting+Tests.swift index ddd89aa8..45d634fa 100644 --- a/Tests/NIOSOCKSTests/ClientGreeting+Tests.swift +++ b/Tests/NIOSOCKSTests/ClientGreeting+Tests.swift @@ -13,22 +13,28 @@ //===----------------------------------------------------------------------===// import NIOCore -@testable import NIOSOCKS import XCTest +@testable import NIOSOCKS + public class ClientGreetingTests: XCTestCase { - + func testInitFromBuffer() { var buffer = ByteBuffer() buffer.writeBytes([0x05, 0x01, 0x00]) XCTAssertNoThrow(XCTAssertEqual(try buffer.readClientGreeting(), .init(methods: [.noneRequired]))) XCTAssertEqual(buffer.readableBytes, 0) - + buffer.writeBytes([0x05, 0x03, 0x00, 0x01, 0x02]) - XCTAssertNoThrow(XCTAssertEqual(try buffer.readClientGreeting(), .init(methods: [.noneRequired, .gssapi, .usernamePassword]))) + XCTAssertNoThrow( + XCTAssertEqual( + try buffer.readClientGreeting(), + .init(methods: [.noneRequired, .gssapi, .usernamePassword]) + ) + ) XCTAssertEqual(buffer.readableBytes, 0) } - + func testWriting() { var buffer = ByteBuffer() let greeting = ClientGreeting(methods: [.noneRequired]) diff --git a/Tests/NIOSOCKSTests/ClientRequest+Tests.swift b/Tests/NIOSOCKSTests/ClientRequest+Tests.swift index f2980eaf..f50b418c 100644 --- a/Tests/NIOSOCKSTests/ClientRequest+Tests.swift +++ b/Tests/NIOSOCKSTests/ClientRequest+Tests.swift @@ -13,63 +13,81 @@ //===----------------------------------------------------------------------===// import NIOCore -@testable import NIOSOCKS import XCTest +@testable import NIOSOCKS + public class ClientRequestTests: XCTestCase { - + } // MARK: - SOCKSRequest extension ClientRequestTests { - + func testWriteClientRequest() { var buffer = ByteBuffer() let req = SOCKSRequest(command: .connect, addressType: .address(try! .init(ipAddress: "192.168.1.1", port: 80))) XCTAssertEqual(buffer.writeClientRequest(req), 10) XCTAssertEqual(buffer.readableBytes, 10) - XCTAssertEqual(buffer.readBytes(length: 10)!, - [0x05, 0x01, 0x00, 1, 192, 168, 1, 1, 0x00, 0x50]) + XCTAssertEqual( + buffer.readBytes(length: 10)!, + [0x05, 0x01, 0x00, 1, 192, 168, 1, 1, 0x00, 0x50] + ) } - + } // MARK: - AddressType extension ClientRequestTests { - + func testReadAddressType() { var ipv4 = ByteBuffer(bytes: [0x01, 0x0a, 0x0b, 0x0c, 0x0d, 0x00, 0x50]) XCTAssertEqual(ipv4.readableBytes, 7) - XCTAssertNoThrow(XCTAssertEqual(try ipv4.readAddressType(), .address(try! .init(ipAddress: "10.11.12.13", port: 80)))) + XCTAssertNoThrow( + XCTAssertEqual(try ipv4.readAddressType(), .address(try! .init(ipAddress: "10.11.12.13", port: 80))) + ) XCTAssertEqual(ipv4.readableBytes, 0) - - var domain = ByteBuffer(bytes: [0x03, 0x0a, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x63, 0x6f, 0x6d, 0x00, 0x50]) + + var domain = ByteBuffer(bytes: [ + 0x03, 0x0a, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x63, 0x6f, 0x6d, 0x00, 0x50, + ]) XCTAssertEqual(domain.readableBytes, 14) XCTAssertNoThrow(XCTAssertEqual(try domain.readAddressType(), .domain("google.com", port: 80))) XCTAssertEqual(domain.readableBytes, 0) - - var ipv6 = ByteBuffer(bytes: [0x04, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0xa0, 0x00, 0x50]) + + var ipv6 = ByteBuffer(bytes: [ + 0x04, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0xa0, 0x00, + 0x50, + ]) XCTAssertEqual(ipv6.readableBytes, 19) - XCTAssertNoThrow(XCTAssertEqual(try ipv6.readAddressType(), .address(try! .init(ipAddress: "0102:0304:0506:0708:090a:0b0c:0d0e:0fa0", port: 80)))) + XCTAssertNoThrow( + XCTAssertEqual( + try ipv6.readAddressType(), + .address(try! .init(ipAddress: "0102:0304:0506:0708:090a:0b0c:0d0e:0fa0", port: 80)) + ) + ) XCTAssertEqual(ipv6.readableBytes, 0) - + } - - func testWriteAddressType(){ + + func testWriteAddressType() { var ipv4 = ByteBuffer() XCTAssertEqual(ipv4.writeAddressType(.address(try! .init(ipAddress: "192.168.1.1", port: 80))), 7) XCTAssertEqual(ipv4.readBytes(length: 5)!, [1, 192, 168, 1, 1]) XCTAssertEqual(ipv4.readInteger(as: UInt16.self)!, 80) - + var ipv6 = ByteBuffer() - XCTAssertEqual(ipv6.writeAddressType(.address(try! .init(ipAddress: "0001:0002:0003:0004:0005:0006:0007:0008", port: 80))), 19) + XCTAssertEqual( + ipv6.writeAddressType(.address(try! .init(ipAddress: "0001:0002:0003:0004:0005:0006:0007:0008", port: 80))), + 19 + ) XCTAssertEqual(ipv6.readBytes(length: 17)!, [4, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8]) XCTAssertEqual(ipv6.readInteger(as: UInt16.self)!, 80) - + var domain = ByteBuffer() XCTAssertEqual(domain.writeAddressType(.domain("127.0.0.1", port: 80)), 13) XCTAssertEqual(domain.readBytes(length: 11)!, [3, 9, 49, 50, 55, 46, 48, 46, 48, 46, 49]) XCTAssertEqual(domain.readInteger(as: UInt16.self)!, 80) } - + } diff --git a/Tests/NIOSOCKSTests/ClientStateMachine+Tests.swift b/Tests/NIOSOCKSTests/ClientStateMachine+Tests.swift index d374479d..b2337dd5 100644 --- a/Tests/NIOSOCKSTests/ClientStateMachine+Tests.swift +++ b/Tests/NIOSOCKSTests/ClientStateMachine+Tests.swift @@ -13,65 +13,70 @@ //===----------------------------------------------------------------------===// import NIOCore -@testable import NIOSOCKS import XCTest +@testable import NIOSOCKS + public class ClientStateMachineTests: XCTestCase { - + func testUsualWorkflow() { - + // create state machine and immediately connect var stateMachine = ClientStateMachine() XCTAssertTrue(stateMachine.shouldBeginHandshake) XCTAssertNoThrow(XCTAssertEqual(try stateMachine.connectionEstablished(), .sendGreeting)) XCTAssertFalse(stateMachine.proxyEstablished) - + // send the client greeting XCTAssertNoThrow(try stateMachine.sendClientGreeting(.init(methods: [.noneRequired]))) XCTAssertFalse(stateMachine.shouldBeginHandshake) XCTAssertFalse(stateMachine.proxyEstablished) - + // provide the given server greeting, check what to do next var serverGreeting = ByteBuffer(bytes: [0x05, 0x00]) XCTAssertNoThrow(XCTAssertEqual(try stateMachine.receiveBuffer(&serverGreeting), .sendRequest)) XCTAssertFalse(stateMachine.shouldBeginHandshake) XCTAssertFalse(stateMachine.proxyEstablished) - + // finish authentication XCTAssertFalse(stateMachine.shouldBeginHandshake) XCTAssertFalse(stateMachine.proxyEstablished) - + // send the client request - XCTAssertNoThrow(try stateMachine.sendClientRequest(.init(command: .bind, addressType: .address(try! .init(ipAddress: "192.168.1.1", port: 80))))) + XCTAssertNoThrow( + try stateMachine.sendClientRequest( + .init(command: .bind, addressType: .address(try! .init(ipAddress: "192.168.1.1", port: 80))) + ) + ) XCTAssertFalse(stateMachine.shouldBeginHandshake) XCTAssertFalse(stateMachine.proxyEstablished) - + // recieve server response var serverResponse = ByteBuffer(bytes: [0x05, 0x00, 0x00, 0x01, 0x01, 0x02, 0x03, 0x04, 0x00, 0x50]) XCTAssertNoThrow(XCTAssertEqual(try stateMachine.receiveBuffer(&serverResponse), .proxyEstablished)) - + // proxy should be good to go XCTAssertFalse(stateMachine.shouldBeginHandshake) XCTAssertTrue(stateMachine.proxyEstablished) } - + // Once an error occurs the state machine // should refuse to progress further, as // the connection should instead be closed. func testErrorsAreHandled() { - + // prepare the state machine var stateMachine = ClientStateMachine() XCTAssertNoThrow(XCTAssertEqual(try stateMachine.connectionEstablished(), .sendGreeting)) XCTAssertNoThrow(try stateMachine.sendClientGreeting(.init(methods: [.noneRequired]))) - + // write some invalid bytes from the server // the state machine should throw var buffer = ByteBuffer(bytes: [0xFF, 0xFF]) XCTAssertThrowsError(try stateMachine.receiveBuffer(&buffer)) { e in XCTAssertTrue(e is SOCKSError.InvalidProtocolVersion) } - + // Now write some valid bytes. This time // the state machine should throw an // UnexpectedRead, as we should have closed diff --git a/Tests/NIOSOCKSTests/Helpers+Tests.swift b/Tests/NIOSOCKSTests/Helpers+Tests.swift index ad2eb189..30135d3f 100644 --- a/Tests/NIOSOCKSTests/Helpers+Tests.swift +++ b/Tests/NIOSOCKSTests/Helpers+Tests.swift @@ -13,35 +13,40 @@ //===----------------------------------------------------------------------===// import NIOCore -@testable import NIOSOCKS import XCTest +@testable import NIOSOCKS + public class HelperTests: XCTestCase { - + // Returning nil should unwind the changes func testUnwindingReturnNil() { var buffer = ByteBuffer(bytes: [1, 2, 3, 4, 5]) - XCTAssertNil(buffer.parseUnwindingIfNeeded { buffer -> Int? in - XCTAssertEqual(buffer.readBytes(length: 5), [1, 2, 3, 4, 5]) - return nil - }) + XCTAssertNil( + buffer.parseUnwindingIfNeeded { buffer -> Int? in + XCTAssertEqual(buffer.readBytes(length: 5), [1, 2, 3, 4, 5]) + return nil + } + ) XCTAssertEqual(buffer, ByteBuffer(bytes: [1, 2, 3, 4, 5])) } - + func testUnwindingThrowError() { - + struct TestError: Error, Hashable {} - + var buffer = ByteBuffer(bytes: [1, 2, 3, 4, 5]) - XCTAssertThrowsError(try buffer.parseUnwindingIfNeeded { buffer -> Int? in - XCTAssertEqual(buffer.readBytes(length: 5), [1, 2, 3, 4, 5]) - throw TestError() - }) { e in + XCTAssertThrowsError( + try buffer.parseUnwindingIfNeeded { buffer -> Int? in + XCTAssertEqual(buffer.readBytes(length: 5), [1, 2, 3, 4, 5]) + throw TestError() + } + ) { e in XCTAssertEqual(e as? TestError, TestError()) } XCTAssertEqual(buffer, ByteBuffer(bytes: [1, 2, 3, 4, 5])) } - + // If we don't return nil and don't throw an error then all should be good func testUnwindingNotRequired() { var buffer = ByteBuffer(bytes: [1, 2, 3, 4, 5]) @@ -50,5 +55,5 @@ public class HelperTests: XCTestCase { } XCTAssertEqual(buffer, ByteBuffer(bytes: [])) } - + } diff --git a/Tests/NIOSOCKSTests/MethodSelection+Tests.swift b/Tests/NIOSOCKSTests/MethodSelection+Tests.swift index 7a3c7c2a..60c59005 100644 --- a/Tests/NIOSOCKSTests/MethodSelection+Tests.swift +++ b/Tests/NIOSOCKSTests/MethodSelection+Tests.swift @@ -13,22 +13,23 @@ //===----------------------------------------------------------------------===// import NIOCore -@testable import NIOSOCKS import XCTest +@testable import NIOSOCKS + public class MethodSelectionTests: XCTestCase { - + func testReadFromByteBuffer() { var buffer = ByteBuffer(bytes: [0x05, 0x00]) XCTAssertEqual(buffer.readableBytes, 2) XCTAssertNoThrow(XCTAssertEqual(try buffer.readMethodSelection(), .init(method: .noneRequired))) XCTAssertEqual(buffer.readableBytes, 0) } - + func testWriteToByteBuffer() { var buffer = ByteBuffer() XCTAssertEqual(buffer.writeMethodSelection(.init(method: .noneRequired)), 2) XCTAssertEqual(buffer, ByteBuffer(bytes: [0x05, 0x00])) } - + } diff --git a/Tests/NIOSOCKSTests/SOCKSServerHandshakeHandler+Tests.swift b/Tests/NIOSOCKSTests/SOCKSServerHandshakeHandler+Tests.swift index 04cf0461..d75366ee 100644 --- a/Tests/NIOSOCKSTests/SOCKSServerHandshakeHandler+Tests.swift +++ b/Tests/NIOSOCKSTests/SOCKSServerHandshakeHandler+Tests.swift @@ -14,22 +14,23 @@ import NIOCore import NIOEmbedded -@testable import NIOSOCKS import XCTest +@testable import NIOSOCKS + class PromiseTestHandler: ChannelInboundHandler { typealias InboundIn = ClientMessage - + let expectedGreeting: ClientGreeting let expectedRequest: SOCKSRequest let expectedData: ByteBuffer - + var hadGreeting: Bool = false var hadRequest: Bool = false var hadData: Bool = false - + var hadSOCKSEstablishedProxyUserEvent: Bool = false - + public init( expectedGreeting: ClientGreeting, expectedRequest: SOCKSRequest, @@ -39,7 +40,7 @@ class PromiseTestHandler: ChannelInboundHandler { self.expectedRequest = expectedRequest self.expectedData = expectedData } - + func channelRead(context: ChannelHandlerContext, data: NIOAny) { let message = self.unwrapInboundIn(data) switch message { @@ -54,7 +55,7 @@ class PromiseTestHandler: ChannelInboundHandler { hadData = true } } - + func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { switch event { case is SOCKSProxyEstablishedEvent: @@ -67,10 +68,10 @@ class PromiseTestHandler: ChannelInboundHandler { } class SOCKSServerHandlerTests: XCTestCase { - + var channel: EmbeddedChannel! var handler: SOCKSServerHandshakeHandler! - + override func setUp() { XCTAssertNil(self.channel) self.handler = SOCKSServerHandshakeHandler() @@ -81,7 +82,7 @@ class SOCKSServerHandlerTests: XCTestCase { XCTAssertNotNil(self.channel) self.channel = nil } - + func assertOutputBuffer(_ bytes: [UInt8], line: UInt = #line) { do { if var buffer = try self.channel.readOutbound(as: ByteBuffer.self) { @@ -93,15 +94,15 @@ class SOCKSServerHandlerTests: XCTestCase { XCTFail("\(error)", line: line) } } - + func writeOutbound(_ message: ServerMessage, line: UInt = #line) { XCTAssertNoThrow(try self.channel.writeOutbound(message), line: line) } - + func writeInbound(_ bytes: [UInt8], line: UInt = #line) { XCTAssertNoThrow(try self.channel.writeInbound(ByteBuffer(bytes: bytes)), line: line) } - + func assertInbound(_ bytes: [UInt8], line: UInt = #line) { do { if var buffer = try self.channel.readInbound(as: ByteBuffer.self) { @@ -113,7 +114,7 @@ class SOCKSServerHandlerTests: XCTestCase { XCTFail("\(error)") } } - + func assertInbound(_ message: ClientMessage, line: UInt = #line) { do { if let actual = try self.channel.readInbound(as: ClientMessage.self) { @@ -125,10 +126,13 @@ class SOCKSServerHandlerTests: XCTestCase { XCTFail("\(error)", line: line) } } - + func testTypicalWorkflow() { let expectedGreeting = ClientGreeting(methods: [.init(value: 0xAA)]) - let expectedRequest = SOCKSRequest(command: .connect, addressType: .address(try! .init(ipAddress: "127.0.0.1", port: 80))) + let expectedRequest = SOCKSRequest( + command: .connect, + addressType: .address(try! .init(ipAddress: "127.0.0.1", port: 80)) + ) let expectedData = ByteBuffer(bytes: [0x01, 0x02, 0x03, 0x04]) let testHandler = PromiseTestHandler( expectedGreeting: expectedGreeting, @@ -136,39 +140,44 @@ class SOCKSServerHandlerTests: XCTestCase { expectedData: expectedData ) XCTAssertNoThrow(try self.channel.pipeline.addHandler(testHandler).wait()) - + // wait for the greeting XCTAssertFalse(testHandler.hadGreeting) self.writeInbound([0x05, 0x01, 0xAA]) XCTAssertTrue(testHandler.hadGreeting) - + // write the auth selection self.writeOutbound(.selectedAuthenticationMethod(.init(method: .init(value: 0xAA)))) self.assertOutputBuffer([0x05, 0xAA]) - + XCTAssertFalse(testHandler.hadData) self.writeInbound([0x01, 0x02, 0x03, 0x04]) XCTAssertTrue(testHandler.hadData) - + // finish authentication - nothing should be written // as this is informing the state machine only self.writeOutbound(.authenticationData(ByteBuffer(bytes: [0xFF, 0xFF]), complete: true)) self.assertOutputBuffer([0xFF, 0xFF]) - + // write the request XCTAssertFalse(testHandler.hadRequest) self.writeInbound([0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80]) XCTAssertTrue(testHandler.hadRequest) XCTAssertFalse(testHandler.hadSOCKSEstablishedProxyUserEvent) - self.writeOutbound(.response(.init(reply: .succeeded, boundAddress: .address(try! .init(ipAddress: "127.0.0.1", port: 80))))) + self.writeOutbound( + .response(.init(reply: .succeeded, boundAddress: .address(try! .init(ipAddress: "127.0.0.1", port: 80)))) + ) XCTAssertTrue(testHandler.hadSOCKSEstablishedProxyUserEvent) self.assertOutputBuffer([0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0, 80]) } - + // tests dripfeeding to ensure we buffer data correctly func testTypicalWorkflowDripfeed() { let expectedGreeting = ClientGreeting(methods: [.gssapi]) - let expectedRequest = SOCKSRequest(command: .connect, addressType: .address(try! .init(ipAddress: "127.0.0.1", port: 80))) + let expectedRequest = SOCKSRequest( + command: .connect, + addressType: .address(try! .init(ipAddress: "127.0.0.1", port: 80)) + ) let expectedData = ByteBuffer(string: "1234") let testHandler = PromiseTestHandler( expectedGreeting: expectedGreeting, @@ -176,7 +185,7 @@ class SOCKSServerHandlerTests: XCTestCase { expectedData: expectedData ) XCTAssertNoThrow(try self.channel.pipeline.addHandler(testHandler).wait()) - + // wait for the greeting XCTAssertFalse(testHandler.hadGreeting) self.writeInbound([0x05]) @@ -186,15 +195,21 @@ class SOCKSServerHandlerTests: XCTestCase { self.writeInbound([0x01]) self.assertOutputBuffer([]) XCTAssertTrue(testHandler.hadGreeting) - + // write the auth selection - XCTAssertNoThrow(try self.channel.writeOutbound(ServerMessage.selectedAuthenticationMethod(.init(method: .gssapi)))) + XCTAssertNoThrow( + try self.channel.writeOutbound(ServerMessage.selectedAuthenticationMethod(.init(method: .gssapi))) + ) self.assertOutputBuffer([0x05, 0x01]) - + // finish authentication with some bytes - XCTAssertNoThrow(try self.channel.writeOutbound(ServerMessage.authenticationData(ByteBuffer(bytes: [0xFF, 0xFF]), complete: true))) + XCTAssertNoThrow( + try self.channel.writeOutbound( + ServerMessage.authenticationData(ByteBuffer(bytes: [0xFF, 0xFF]), complete: true) + ) + ) self.assertOutputBuffer([0xFF, 0xFF]) - + // write the request XCTAssertFalse(testHandler.hadRequest) self.writeInbound([0x05, 0x01]) @@ -204,7 +219,7 @@ class SOCKSServerHandlerTests: XCTestCase { self.writeInbound([127, 0, 0, 1, 0, 80]) XCTAssertTrue(testHandler.hadRequest) } - + // write nonsense bytes that should be caught inbound func testInboundErrorsAreHandled() { let buffer = ByteBuffer(bytes: [0xFF, 0xFF, 0xFF]) @@ -212,24 +227,28 @@ class SOCKSServerHandlerTests: XCTestCase { XCTAssertTrue(e is SOCKSError.InvalidProtocolVersion) } } - + // write something that will be be invalid for the state machine's // current state, causing an error to be thrown func testOutboundErrorsAreHandled() { - XCTAssertThrowsError(try self.channel.writeAndFlush(ServerMessage.authenticationData(ByteBuffer(bytes: [0xFF, 0xFF]), complete: true)).wait()) { e in + XCTAssertThrowsError( + try self.channel.writeAndFlush( + ServerMessage.authenticationData(ByteBuffer(bytes: [0xFF, 0xFF]), complete: true) + ).wait() + ) { e in XCTAssertTrue(e is SOCKSError.InvalidServerState) } } - + func testFlushOnHandlerRemoved() { self.writeInbound([0x05, 0x01]) self.assertInbound([]) XCTAssertNoThrow(try self.channel.pipeline.removeHandler(self.handler).wait()) self.assertInbound([0x05, 0x01]) } - + func testForceHandlerRemovalAfterAuth() { - + // go through auth self.writeInbound([0x05, 0x01, 0x01]) self.writeOutbound(.selectedAuthenticationMethod(.init(method: .gssapi))) @@ -237,54 +256,64 @@ class SOCKSServerHandlerTests: XCTestCase { self.writeOutbound(.authenticationData(ByteBuffer(), complete: true)) self.assertOutputBuffer([]) self.writeInbound([0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80]) - self.writeOutbound(.response(.init(reply: .succeeded, boundAddress: .address(try! .init(ipAddress: "127.0.0.1", port: 80))))) + self.writeOutbound( + .response(.init(reply: .succeeded, boundAddress: .address(try! .init(ipAddress: "127.0.0.1", port: 80)))) + ) self.assertOutputBuffer([0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0, 80]) - + // auth complete, try to write data without // removing the handler, it should fail - XCTAssertThrowsError(try self.channel.writeOutbound(ServerMessage.authenticationData(ByteBuffer(string: "hello, world!"), complete: false))) + XCTAssertThrowsError( + try self.channel.writeOutbound( + ServerMessage.authenticationData(ByteBuffer(string: "hello, world!"), complete: false) + ) + ) } - + func testAutoAuthenticationComplete() { - + // server selects none-required, this should mean we can continue without // having to manually inform the state machine self.writeInbound([0x05, 0x01, 0x00]) self.writeOutbound(.selectedAuthenticationMethod(.init(method: .noneRequired))) self.assertOutputBuffer([0x05, 0x00]) - + // if we try and write the request then the data would be read // as authentication data, and so the server wouldn't reply // with a response self.writeInbound([0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80]) - self.writeOutbound(.response(.init(reply: .succeeded, boundAddress: .address(try! .init(ipAddress: "127.0.0.1", port: 80))))) + self.writeOutbound( + .response(.init(reply: .succeeded, boundAddress: .address(try! .init(ipAddress: "127.0.0.1", port: 80)))) + ) self.assertOutputBuffer([0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0, 80]) } - + func testAutoAuthenticationCompleteWithManualCompletion() { - + // server selects none-required, this should mean we can continue without // having to manually inform the state machine. However, informing the state // machine manually shouldn't break anything. self.writeInbound([0x05, 0x01, 0x00]) self.writeOutbound(.selectedAuthenticationMethod(.init(method: .noneRequired))) self.assertOutputBuffer([0x05, 0x00]) - + // complete authentication, but nothing should be written // to the network self.writeOutbound(.authenticationData(ByteBuffer(), complete: true)) self.assertOutputBuffer([]) - + // if we try and write the request then the data would be read // as authentication data, and so the server wouldn't reply // with a response self.writeInbound([0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80]) - self.writeOutbound(.response(.init(reply: .succeeded, boundAddress: .address(try! .init(ipAddress: "127.0.0.1", port: 80))))) + self.writeOutbound( + .response(.init(reply: .succeeded, boundAddress: .address(try! .init(ipAddress: "127.0.0.1", port: 80)))) + ) self.assertOutputBuffer([0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0, 80]) } - + func testEagerClientRequestBeforeAuthenticationComplete() { - + // server selects none-required, this should mean we can continue without // having to manually inform the state machine. However, informing the state // machine manually shouldn't break anything. @@ -292,14 +321,14 @@ class SOCKSServerHandlerTests: XCTestCase { self.assertInbound(.greeting(.init(methods: [.gssapi]))) self.writeOutbound(.selectedAuthenticationMethod(.init(method: .gssapi))) self.assertOutputBuffer([0x05, 0x01]) - + // at this point authentication isn't complete // so if the client sends a request then the // server will read those as authentication bytes self.writeInbound([0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80]) self.assertInbound(.authenticationData(ByteBuffer(bytes: [0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80]))) } - + func testManualAuthenticationFailureExtraBytes() { // server selects none-required, this should mean we can continue without // having to manually inform the state machine. However, informing the state @@ -307,13 +336,15 @@ class SOCKSServerHandlerTests: XCTestCase { self.writeInbound([0x05, 0x01, 0x00]) self.writeOutbound(.selectedAuthenticationMethod(.init(method: .noneRequired))) self.assertOutputBuffer([0x05, 0x00]) - + // invalid authentication completion // we've selected `noneRequired`, so no // bytes should be written - XCTAssertThrowsError(try self.channel.writeOutbound(ServerMessage.authenticationData(ByteBuffer(bytes: [0x00]), complete: true))) + XCTAssertThrowsError( + try self.channel.writeOutbound(ServerMessage.authenticationData(ByteBuffer(bytes: [0x00]), complete: true)) + ) } - + func testManualAuthenticationFailureInvalidCompletion() { // server selects none-required, this should mean we can continue without // having to manually inform the state machine. However, informing the state @@ -321,11 +352,13 @@ class SOCKSServerHandlerTests: XCTestCase { self.writeInbound([0x05, 0x01, 0x00]) self.writeOutbound(.selectedAuthenticationMethod(.init(method: .noneRequired))) self.assertOutputBuffer([0x05, 0x00]) - + // invalid authentication completion // authentication should have already completed // as we selected `noneRequired`, so sending // `complete = false` should be an error - XCTAssertThrowsError(try self.channel.writeOutbound(ServerMessage.authenticationData(ByteBuffer(bytes: []), complete: false))) + XCTAssertThrowsError( + try self.channel.writeOutbound(ServerMessage.authenticationData(ByteBuffer(bytes: []), complete: false)) + ) } } diff --git a/Tests/NIOSOCKSTests/ServerResponse+Tests.swift b/Tests/NIOSOCKSTests/ServerResponse+Tests.swift index d35f83ce..f69c67ac 100644 --- a/Tests/NIOSOCKSTests/ServerResponse+Tests.swift +++ b/Tests/NIOSOCKSTests/ServerResponse+Tests.swift @@ -13,21 +13,26 @@ //===----------------------------------------------------------------------===// import NIOCore -@testable import NIOSOCKS import XCTest +@testable import NIOSOCKS + public class ServerResponseTests: XCTestCase { } // MARK: - ServeResponse extension ServerResponseTests { - + func testServerResponseReadFromByteBuffer() { var buffer = ByteBuffer(bytes: [0x05, 0x00, 0x00, 0x01, 0x01, 0x02, 0x03, 0x04, 0x00, 0x50]) XCTAssertEqual(buffer.readableBytes, 10) - XCTAssertNoThrow(XCTAssertEqual(try buffer.readServerResponse(), - .init(reply: .succeeded, boundAddress: .address(try! .init(ipAddress: "1.2.3.4", port: 80))))) + XCTAssertNoThrow( + XCTAssertEqual( + try buffer.readServerResponse(), + .init(reply: .succeeded, boundAddress: .address(try! .init(ipAddress: "1.2.3.4", port: 80))) + ) + ) XCTAssertEqual(buffer.readableBytes, 0) } - + } diff --git a/Tests/NIOSOCKSTests/ServerStateMachine+Tests.swift b/Tests/NIOSOCKSTests/ServerStateMachine+Tests.swift index b9c40525..45556f5d 100644 --- a/Tests/NIOSOCKSTests/ServerStateMachine+Tests.swift +++ b/Tests/NIOSOCKSTests/ServerStateMachine+Tests.swift @@ -13,59 +13,60 @@ //===----------------------------------------------------------------------===// import NIOCore -@testable import NIOSOCKS import XCTest +@testable import NIOSOCKS + public class ServerStateMachineTests: XCTestCase { - + func testUsualWorkflow() { - + // create state machine and immediately connect var stateMachine = ServerStateMachine() XCTAssertNoThrow(try stateMachine.connectionEstablished()) XCTAssertFalse(stateMachine.proxyEstablished) - + // send the client greeting var greeting = ByteBuffer(bytes: [0x05, 0x01, 0x00]) XCTAssertNoThrow(try stateMachine.receiveBuffer(&greeting)) XCTAssertFalse(stateMachine.proxyEstablished) - + // provide the given server greeting XCTAssertNoThrow(try stateMachine.sendAuthenticationMethod(.init(method: .noneRequired))) XCTAssertFalse(stateMachine.proxyEstablished) - + // send the client request var request = ByteBuffer(bytes: [0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80]) XCTAssertNoThrow(try stateMachine.receiveBuffer(&request)) XCTAssertFalse(stateMachine.proxyEstablished) - + // recieve server response let response = SOCKSResponse(reply: .succeeded, boundAddress: .domain("127.0.0.1", port: 80)) XCTAssertNoThrow(try stateMachine.sendServerResponse(response)) - + // proxy should be good to go XCTAssertTrue(stateMachine.proxyEstablished) } - + // Once an error occurs the state machine // should refuse to progress further, as // the connection should instead be closed. func testErrorsAreHandled() { - + // prepare the state machine var stateMachine = ServerStateMachine() var greeting = ByteBuffer(bytes: [0x05, 0x01, 0x00]) XCTAssertNoThrow(try stateMachine.connectionEstablished()) XCTAssertNoThrow(try stateMachine.receiveBuffer(&greeting)) XCTAssertNoThrow(try stateMachine.sendAuthenticationMethod(.init(method: .noneRequired))) - + // write some invalid bytes from the client // the state machine should throw var buffer = ByteBuffer(bytes: [0xFF, 0xFF]) XCTAssertThrowsError(try stateMachine.receiveBuffer(&buffer)) { e in XCTAssertTrue(e is SOCKSError.InvalidProtocolVersion) } - + // Now write some valid bytes. This time // the state machine should throw an // UnexpectedRead, as we should have closed @@ -75,7 +76,7 @@ public class ServerStateMachineTests: XCTestCase { XCTAssertTrue(e is SOCKSError.UnexpectedRead) } } - + func testBytesArentConsumedOnError() { var stateMachine = ServerStateMachine() XCTAssertNoThrow(try stateMachine.connectionEstablished()) diff --git a/Tests/NIOSOCKSTests/SocksClientHandler+Tests.swift b/Tests/NIOSOCKSTests/SocksClientHandler+Tests.swift index 53866a34..5462fdcd 100644 --- a/Tests/NIOSOCKSTests/SocksClientHandler+Tests.swift +++ b/Tests/NIOSOCKSTests/SocksClientHandler+Tests.swift @@ -14,20 +14,21 @@ import NIOCore import NIOEmbedded -@testable import NIOSOCKS import XCTest +@testable import NIOSOCKS + class SocksClientHandlerTests: XCTestCase { - + var channel: EmbeddedChannel! var handler: SOCKSClientHandler! - + override func setUp() { XCTAssertNil(self.channel) self.handler = SOCKSClientHandler(targetAddress: .address(try! .init(ipAddress: "192.168.1.1", port: 80))) self.channel = EmbeddedChannel(handler: self.handler) } - + func connect() { try! self.channel.connect(to: .init(ipAddress: "127.0.0.1", port: 80)).wait() } @@ -36,7 +37,7 @@ class SocksClientHandlerTests: XCTestCase { XCTAssertNotNil(self.channel) self.channel = nil } - + func assertOutputBuffer(_ bytes: [UInt8], line: UInt = #line) { if var buffer = try! self.channel.readOutbound(as: ByteBuffer.self) { XCTAssertEqual(buffer.readBytes(length: buffer.readableBytes), bytes, line: line) @@ -44,37 +45,37 @@ class SocksClientHandlerTests: XCTestCase { XCTFail("Expected bytes but found none") } } - + func writeInbound(_ bytes: [UInt8], line: UInt = #line) { try! self.channel.writeInbound(ByteBuffer(bytes: bytes)) } - + func assertInbound(_ bytes: [UInt8], line: UInt = #line) { var buffer = try! self.channel.readInbound(as: ByteBuffer.self) XCTAssertEqual(buffer!.readBytes(length: buffer!.readableBytes), bytes, line: line) } - + func testTypicalWorkflow() { - + let clientHandler = MockSOCKSClientHandler() XCTAssertNoThrow(try self.channel.pipeline.syncOperations.addHandler(clientHandler)) - + self.connect() - + // the client should start the handshake instantly self.assertOutputBuffer([0x05, 0x01, 0x00]) - + // server selects an authentication method self.writeInbound([0x05, 0x00]) - + // client sends the request self.assertOutputBuffer([0x05, 0x01, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50]) - + // server replies yay XCTAssertFalse(clientHandler.hadSOCKSEstablishedProxyUserEvent) self.writeInbound([0x05, 0x00, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50]) XCTAssertTrue(clientHandler.hadSOCKSEstablishedProxyUserEvent) - + // any inbound data should now go straight through self.writeInbound([1, 2, 3, 4, 5]) self.assertInbound([1, 2, 3, 4, 5]) @@ -83,60 +84,60 @@ class SocksClientHandlerTests: XCTestCase { XCTAssertNoThrow(try self.channel.writeOutbound(ByteBuffer(bytes: [1, 2, 3, 4, 5]))) self.assertOutputBuffer([1, 2, 3, 4, 5]) } - + // Tests that if we write alot of data at the start then // that data will be written after the client has completed // the socks handshake. func testThatBufferingWorks() { self.connect() - + let writePromise = self.channel.eventLoop.makePromise(of: Void.self) self.channel.writeAndFlush(ByteBuffer(bytes: [1, 2, 3, 4, 5]), promise: writePromise) self.assertOutputBuffer([0x05, 0x01, 0x00]) self.writeInbound([0x05, 0x00]) self.assertOutputBuffer([0x05, 0x01, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50]) self.writeInbound([0x05, 0x00, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50]) - + XCTAssertNoThrow(try writePromise.futureResult.wait()) self.assertOutputBuffer([1, 2, 3, 4, 5]) } - + func testBufferingWithMark() { self.connect() - + let writePromise1 = self.channel.eventLoop.makePromise(of: Void.self) let writePromise2 = self.channel.eventLoop.makePromise(of: Void.self) self.channel.write(ByteBuffer(bytes: [1, 2, 3]), promise: writePromise1) self.channel.flush() self.channel.write(ByteBuffer(bytes: [4, 5, 6]), promise: writePromise2) - + self.assertOutputBuffer([0x05, 0x01, 0x00]) self.writeInbound([0x05, 0x00]) self.assertOutputBuffer([0x05, 0x01, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50]) self.writeInbound([0x05, 0x00, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50]) - + XCTAssertNoThrow(try writePromise1.futureResult.wait()) self.assertOutputBuffer([1, 2, 3]) - + XCTAssertNoThrow(try self.channel.writeAndFlush(ByteBuffer(bytes: [7, 8, 9])).wait()) XCTAssertNoThrow(try writePromise2.futureResult.wait()) self.assertOutputBuffer([4, 5, 6]) self.assertOutputBuffer([7, 8, 9]) } - + func testTypicalWorkflowDripfeed() { self.connect() - + // the client should start the handshake instantly self.assertOutputBuffer([0x05, 0x01, 0x00]) - + // server selects authentication method // once the dripfeed is complete we should get the client request self.writeInbound([0x05]) self.assertOutputBuffer([]) self.writeInbound([0x00]) self.assertOutputBuffer([0x05, 0x01, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50]) - + // drip feed server response self.writeInbound([0x05, 0x00, 0x00, 0x01]) self.assertOutputBuffer([]) @@ -145,31 +146,31 @@ class SocksClientHandlerTests: XCTestCase { self.writeInbound([1, 1]) self.assertOutputBuffer([]) self.writeInbound([0x00, 0x50]) - + // any inbound data should now go straight through self.writeInbound([1, 2, 3, 4, 5]) self.assertInbound([1, 2, 3, 4, 5]) } - + func testInvalidAuthenticationMethod() { self.connect() - + class ErrorHandler: ChannelInboundHandler { typealias InboundIn = ByteBuffer - + var promise: EventLoopPromise - + init(promise: EventLoopPromise) { self.promise = promise } - + func errorCaught(context: ChannelHandlerContext, error: Error) { promise.fail(error) } } - + self.assertOutputBuffer([0x05, 0x01, 0x00]) - + // server requests an auth method we don't support let promise = self.channel.eventLoop.makePromise(of: Void.self) try! self.channel.pipeline.addHandler(ErrorHandler(promise: promise), position: .last).wait() @@ -178,29 +179,29 @@ class SocksClientHandlerTests: XCTestCase { XCTAssertTrue(e is SOCKSError.InvalidAuthenticationSelection) } } - + func testProxyConnectionFailed() { self.connect() - + class ErrorHandler: ChannelInboundHandler { typealias InboundIn = ByteBuffer - + var promise: EventLoopPromise - + init(promise: EventLoopPromise) { self.promise = promise } - + func errorCaught(context: ChannelHandlerContext, error: Error) { promise.fail(error) } } - + // start handshake, send request self.assertOutputBuffer([0x05, 0x01, 0x00]) self.writeInbound([0x05, 0x00]) self.assertOutputBuffer([0x05, 0x01, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50]) - + // server replies with an error let promise = self.channel.eventLoop.makePromise(of: Void.self) try! self.channel.pipeline.addHandler(ErrorHandler(promise: promise), position: .last).wait() @@ -209,45 +210,45 @@ class SocksClientHandlerTests: XCTestCase { XCTAssertEqual(e as? SOCKSError.ConnectionFailed, .init(reply: .serverFailure)) } } - + func testDelayedConnection() { // we shouldn't start the handshake until the client // has connected self.assertOutputBuffer([]) - + self.connect() - + // now the handshake should have started self.assertOutputBuffer([0x05, 0x01, 0x00]) } - + func testDelayedHandlerAdded() { - + // reset the channel that was set up automatically XCTAssertNoThrow(try self.channel.close().wait()) self.channel = EmbeddedChannel() self.handler = SOCKSClientHandler(targetAddress: .domain("127.0.0.1", port: 1234)) XCTAssertNoThrow(try self.channel.connect(to: .init(ipAddress: "127.0.0.1", port: 80)).wait()) XCTAssertTrue(self.channel.isActive) - + // there shouldn't be anything outbound self.assertOutputBuffer([]) - + // add the handler, there should be outbound data immediately XCTAssertNoThrow(self.channel.pipeline.addHandler(handler)) self.assertOutputBuffer([0x05, 0x01, 0x00]) } - + func testHandlerRemovalAfterEstablishEvent() { class SOCKSEventHandler: ChannelInboundHandler { typealias InboundIn = NIOAny - + var establishedPromise: EventLoopPromise - + init(establishedPromise: EventLoopPromise) { self.establishedPromise = establishedPromise } - + func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { switch event { case is SOCKSProxyEstablishedEvent: @@ -258,70 +259,72 @@ class SocksClientHandlerTests: XCTestCase { context.fireUserInboundEventTriggered(event) } } - + let establishPromise = self.channel.eventLoop.makePromise(of: Void.self) let removalPromise = self.channel.eventLoop.makePromise(of: Void.self) establishPromise.futureResult.whenSuccess { _ in self.channel.pipeline.removeHandler(self.handler).cascade(to: removalPromise) } - - XCTAssertNoThrow(try self.channel.pipeline.addHandler(SOCKSEventHandler(establishedPromise: establishPromise)).wait()) - + + XCTAssertNoThrow( + try self.channel.pipeline.addHandler(SOCKSEventHandler(establishedPromise: establishPromise)).wait() + ) + self.connect() - + // these writes should be buffered to be send out once the connection is established. self.channel.write(ByteBuffer(bytes: [1, 2, 3]), promise: nil) self.channel.flush() self.channel.write(ByteBuffer(bytes: [4, 5, 6]), promise: nil) - + self.assertOutputBuffer([0x05, 0x01, 0x00]) self.writeInbound([0x05, 0x00]) self.assertOutputBuffer([0x05, 0x01, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50]) self.writeInbound([0x05, 0x00, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50]) - + self.assertOutputBuffer([1, 2, 3]) - + XCTAssertNoThrow(try self.channel.writeAndFlush(ByteBuffer(bytes: [7, 8, 9])).wait()) - + self.assertOutputBuffer([4, 5, 6]) self.assertOutputBuffer([7, 8, 9]) - + XCTAssertNoThrow(try removalPromise.futureResult.wait()) XCTAssertThrowsError(try self.channel.pipeline.syncOperations.handler(type: SOCKSClientHandler.self)) { XCTAssertEqual($0 as? ChannelPipelineError, .notFound) } } - + func testHandlerRemovalBeforeConnectionIsEstablished() { self.connect() - + // these writes should be buffered to be send out once the connection is established. self.channel.write(ByteBuffer(bytes: [1, 2, 3]), promise: nil) self.channel.flush() self.channel.write(ByteBuffer(bytes: [4, 5, 6]), promise: nil) - + self.assertOutputBuffer([0x05, 0x01, 0x00]) self.writeInbound([0x05, 0x00]) self.assertOutputBuffer([0x05, 0x01, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50]) - + // we try to remove the handler before the connection is established. let removalPromise = self.channel.eventLoop.makePromise(of: Void.self) self.channel.pipeline.removeHandler(self.handler, promise: removalPromise) - + // establishes the connection self.writeInbound([0x05, 0x00, 0x00, 0x01, 192, 168, 1, 1, 0x00, 0x50]) - + // write six more bytes - those should be passed through right away self.writeInbound([1, 2, 3, 4, 5, 6]) self.assertInbound([1, 2, 3, 4, 5, 6]) - + self.assertOutputBuffer([1, 2, 3]) - + XCTAssertNoThrow(try self.channel.writeAndFlush(ByteBuffer(bytes: [7, 8, 9])).wait()) - + self.assertOutputBuffer([4, 5, 6]) self.assertOutputBuffer([7, 8, 9]) - + XCTAssertNoThrow(try removalPromise.futureResult.wait()) XCTAssertThrowsError(try self.channel.pipeline.syncOperations.handler(type: SOCKSClientHandler.self)) { XCTAssertEqual($0 as? ChannelPipelineError, .notFound) @@ -331,11 +334,11 @@ class SocksClientHandlerTests: XCTestCase { class MockSOCKSClientHandler: ChannelInboundHandler { typealias InboundIn = NIOAny - + var hadSOCKSEstablishedProxyUserEvent: Bool = false - + init() {} - + func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { switch event { case is SOCKSProxyEstablishedEvent: diff --git a/docker/Dockerfile b/docker/Dockerfile deleted file mode 100644 index 2d3e61c7..00000000 --- a/docker/Dockerfile +++ /dev/null @@ -1,25 +0,0 @@ -ARG swift_version=5.7 -ARG ubuntu_version=focal -ARG base_image=swift:$swift_version-$ubuntu_version -FROM $base_image -# needed to do again after FROM due to docker limitation -ARG swift_version -ARG ubuntu_version - -# set as UTF-8 -RUN apt-get update && apt-get install -y locales locales-all -ENV LC_ALL en_US.UTF-8 -ENV LANG en_US.UTF-8 -ENV LANGUAGE en_US.UTF-8 - -# dependencies -RUN apt-get update && apt-get install -y wget -RUN apt-get update && apt-get install -y lsof dnsutils netcat-openbsd net-tools curl jq # used by integration tests -RUN apt-get update && apt-get install -y zlib1g-dev - -# ruby for soundness -RUN apt-get update && apt-get install -y ruby ruby-dev libsqlite3-dev build-essential - -# tools -RUN mkdir -p $HOME/.tools -RUN echo 'export PATH="$HOME/.tools:$PATH"' >> $HOME/.profile diff --git a/docker/docker-compose.2204.510.yaml b/docker/docker-compose.2204.510.yaml deleted file mode 100644 index 5b823e85..00000000 --- a/docker/docker-compose.2204.510.yaml +++ /dev/null @@ -1,21 +0,0 @@ -version: "3" - -services: - - runtime-setup: - image: swift-nio-extras:22.04-5.10 - build: - args: - ubuntu_version: "jammy" - swift_version: "5.10" - - test: - image: swift-nio-extras:22.04-5.10 - environment: - - IMPORT_CHECK_ARG=--explicit-target-dependency-import-check error - - documentation-check: - image: swift-nio-extras:22.04-5.10 - - shell: - image: swift-nio-extras:22.04-5.10 diff --git a/docker/docker-compose.2204.58.yaml b/docker/docker-compose.2204.58.yaml deleted file mode 100644 index 67b39947..00000000 --- a/docker/docker-compose.2204.58.yaml +++ /dev/null @@ -1,21 +0,0 @@ -version: "3" - -services: - - runtime-setup: - image: swift-nio-extras:22.04-5.8 - build: - args: - ubuntu_version: "jammy" - swift_version: "5.8" - - test: - image: swift-nio-extras:22.04-5.8 - environment: - - IMPORT_CHECK_ARG=--explicit-target-dependency-import-check error - - documentation-check: - image: swift-nio-extras:22.04-5.8 - - shell: - image: swift-nio-extras:22.04-5.8 diff --git a/docker/docker-compose.2204.59.yaml b/docker/docker-compose.2204.59.yaml deleted file mode 100644 index fec03c20..00000000 --- a/docker/docker-compose.2204.59.yaml +++ /dev/null @@ -1,21 +0,0 @@ -version: "3" - -services: - - runtime-setup: - image: swift-nio-extras:22.04-5.9 - build: - args: - ubuntu_version: "jammy" - swift_version: "5.9" - - test: - image: swift-nio-extras:22.04-5.9 - environment: - - IMPORT_CHECK_ARG=--explicit-target-dependency-import-check error - - documentation-check: - image: swift-nio-extras:22.04-5.9 - - shell: - image: swift-nio-extras:22.04-5.9 diff --git a/docker/docker-compose.2204.main.yaml b/docker/docker-compose.2204.main.yaml deleted file mode 100644 index 4a9c1331..00000000 --- a/docker/docker-compose.2204.main.yaml +++ /dev/null @@ -1,20 +0,0 @@ -version: "3" - -services: - - runtime-setup: - image: swift-nio-extras:22.04-main - build: - args: - base_image: "swiftlang/swift:nightly-main-jammy" - - test: - image: swift-nio-extras:22.04-main - environment: - - IMPORT_CHECK_ARG=--explicit-target-dependency-import-check error - - documentation-check: - image: swift-nio-extras:22.04-main - - shell: - image: swift-nio-extras:22.04-main diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml deleted file mode 100644 index ecf5f533..00000000 --- a/docker/docker-compose.yaml +++ /dev/null @@ -1,41 +0,0 @@ -# this file is not designed to be run directly -# instead, use the docker-compose.. files -# eg docker-compose -f docker/docker-compose.yaml -f docker/docker-compose.1604.41.yaml run test -version: "3" - -services: - - runtime-setup: - image: swift-nio-extras:default - build: - context: . - dockerfile: Dockerfile - - common: &common - image: swift-nio-extras:default - depends_on: [runtime-setup] - volumes: - - ~/.ssh:/root/.ssh - - ..:/code:z - working_dir: /code - cap_drop: - - CAP_NET_RAW - - CAP_NET_BIND_SERVICE - - soundness: - <<: *common - command: /bin/bash -xcl "./scripts/soundness.sh" - - documentation-check: - <<: *common - command: /bin/bash -xcl "./scripts/check-docs.sh" - - test: - <<: *common - command: /bin/bash -xcl "cat /etc/lsb-release && swift -version && swift test -Xswiftc -warnings-as-errors --enable-test-discovery $${SANITIZER_ARG-} $${IMPORT_CHECK_ARG-}" - - # util - - shell: - <<: *common - entrypoint: /bin/bash diff --git a/scripts/check-docs.sh b/scripts/check-docs.sh deleted file mode 100755 index 9405b7eb..00000000 --- a/scripts/check-docs.sh +++ /dev/null @@ -1,23 +0,0 @@ -#!/bin/bash -##===----------------------------------------------------------------------===## -## -## This source file is part of the SwiftNIO open source project -## -## Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors -## Licensed under Apache License v2.0 -## -## See LICENSE.txt for license information -## See CONTRIBUTORS.txt for the list of SwiftNIO project authors -## -## SPDX-License-Identifier: Apache-2.0 -## -##===----------------------------------------------------------------------===## - -set -eu - -raw_targets=$(sed -E -n -e 's/^.* - documentation_targets: \[(.*)\].*$/\1/p' .spi.yml) -targets=(${raw_targets//,/ }) - -for target in "${targets[@]}"; do - swift package plugin generate-documentation --target "$target" --warnings-as-errors --analyze --level detailed -done diff --git a/scripts/check_no_api_breakages.sh b/scripts/check_no_api_breakages.sh deleted file mode 100755 index ec3dc312..00000000 --- a/scripts/check_no_api_breakages.sh +++ /dev/null @@ -1,130 +0,0 @@ -#!/bin/bash -##===----------------------------------------------------------------------===## -## -## This source file is part of the SwiftNIO open source project -## -## Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors -## Licensed under Apache License v2.0 -## -## See LICENSE.txt for license information -## See CONTRIBUTORS.txt for the list of SwiftNIO project authors -## -## SPDX-License-Identifier: Apache-2.0 -## -##===----------------------------------------------------------------------===## - -set -eu - -# repodir -function all_modules() { - local repodir="$1" - ( - set -eu - cd "$repodir" - swift package dump-package | jq '.products | - map(select(.type | has("library") )) | - map(.name) | .[]' | tr -d '"' - ) -} - -# repodir tag output -function build_and_do() { - local repodir=$1 - local tag=$2 - local output=$3 - - ( - cd "$repodir" - git checkout -q "$tag" - swift build - while read -r module; do - swift api-digester -sdk "$sdk" -dump-sdk -module "$module" \ - -o "$output/$module.json" -I "$repodir/.build/debug" - done < <(all_modules "$repodir") - ) -} - -function usage() { - echo >&2 "Usage: $0 REPO-GITHUB-URL NEW-VERSION OLD-VERSIONS..." - echo >&2 - echo >&2 "This script requires a Swift 5.1+ toolchain." - echo >&2 - echo >&2 "Examples:" - echo >&2 - echo >&2 "Check between main and tag 1.2.0 of swift-nio-extras:" - echo >&2 " $0 https://github.com/apple/swift-nio-extras main 1.2.0" - echo >&2 - echo >&2 "Check between HEAD and commit 64cf63d7 using the provided toolchain:" - echo >&2 " xcrun --toolchain org.swift.5120190702a $0 ../some-local-repo HEAD 64cf63d7" -} - -if [[ $# -lt 3 ]]; then - usage - exit 1 -fi - -sdk=/ -if [[ "$(uname -s)" == Darwin ]]; then - sdk=$(xcrun --show-sdk-path) -fi - -hash jq 2> /dev/null || { echo >&2 "ERROR: jq must be installed"; exit 1; } -tmpdir=$(mktemp -d /tmp/.check-api_XXXXXX) -repo_url=$1 -new_tag=$2 -shift 2 - -repodir="$tmpdir/repo" -git clone "$repo_url" "$repodir" -git -C "$repodir" fetch -q origin '+refs/pull/*:refs/remotes/origin/pr/*' -errors=0 - -for old_tag in "$@"; do - mkdir "$tmpdir/api-old" - mkdir "$tmpdir/api-new" - - echo "Checking public API breakages from $old_tag to $new_tag" - - build_and_do "$repodir" "$new_tag" "$tmpdir/api-new/" - build_and_do "$repodir" "$old_tag" "$tmpdir/api-old/" - - for f in "$tmpdir/api-new"/*; do - f=$(basename "$f") - report="$tmpdir/$f.report" - if [[ ! -f "$tmpdir/api-old/$f" ]]; then - echo "NOTICE: NEW MODULE $f" - continue - fi - - echo -n "Checking $f... " - - # Since 5.2 on Linux setting sdk to / has started reporting errors - see SR-11143 - if [[ "$(uname -s)" == "Darwin" || "$(swift --version | grep -c 'swift-5.1')" != 0 ]]; then - swift api-digester -sdk "$sdk" -diagnose-sdk \ - --input-paths "$tmpdir/api-old/$f" -input-paths "$tmpdir/api-new/$f" 2>&1 \ - > "$report" 2>&1 - else - swift api-digester -diagnose-sdk \ - --input-paths "$tmpdir/api-old/$f" -input-paths "$tmpdir/api-new/$f" 2>&1 \ - > "$report" 2>&1 - fi - - if ! shasum "$report" | grep -q afd2a1b542b33273920d65821deddc653063c700; then - echo ERROR - echo >&2 "==============================" - echo >&2 "ERROR: public API change in $f" - echo >&2 "==============================" - cat >&2 "$report" - errors=$(( errors + 1 )) - else - echo OK - fi - done - rm -rf "$tmpdir/api-new" "$tmpdir/api-old" -done - -if [[ "$errors" == 0 ]]; then - echo "OK, all seems good" -fi -echo done -exit "$errors" diff --git a/scripts/soundness.sh b/scripts/soundness.sh deleted file mode 100755 index 8ab47356..00000000 --- a/scripts/soundness.sh +++ /dev/null @@ -1,142 +0,0 @@ -#!/bin/bash -##===----------------------------------------------------------------------===## -## -## This source file is part of the SwiftNIO open source project -## -## Copyright (c) 2017-2019 Apple Inc. and the SwiftNIO project authors -## Licensed under Apache License v2.0 -## -## See LICENSE.txt for license information -## See CONTRIBUTORS.txt for the list of SwiftNIO project authors -## -## SPDX-License-Identifier: Apache-2.0 -## -##===----------------------------------------------------------------------===## - -set -eu -here="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" - -function replace_acceptable_years() { - # this needs to replace all acceptable forms with 'YEARS' - sed -e 's/20[12][7890123]-20[12][7890123]/YEARS/' -e 's/20[12][890123]/YEARS/' -} - -printf "=> Checking for unacceptable language... " -# This greps for unacceptable terminology. The square bracket[s] are so that -# "git grep" doesn't find the lines that greps :). -unacceptable_terms=( - -e blacklis[t] - -e whitelis[t] - -e slav[e] - -e sanit[y] -) - -# We have to exclude the code of conduct as it gives examples of unacceptable -# language. -if git grep --color=never -i "${unacceptable_terms[@]}" -- . ":(exclude)CODE_OF_CONDUCT.md" > /dev/null; then - printf "\033[0;31mUnacceptable language found.\033[0m\n" - git grep -i "${unacceptable_terms[@]}" -- . ":(exclude)CODE_OF_CONDUCT.md" - exit 1 -fi -printf "\033[0;32mokay.\033[0m\n" - -# This checks for the umbrella NIO module. -printf "=> Checking for imports of umbrella NIO module... " -if git grep --color=never -i "^[ \t]*import \+NIO[ \t]*$" > /dev/null; then - printf "\033[0;31mUmbrella imports found.\033[0m\n" - git grep -i "^[ \t]*import \+NIO[ \t]*$" - exit 1 -fi -printf "\033[0;32mokay.\033[0m\n" - -printf "=> Checking license headers... " -tmp=$(mktemp /tmp/.swift-nio-soundness_XXXXXX) - -for language in swift-or-c bash dtrace; do - declare -a matching_files - declare -a exceptions - expections=( ) - matching_files=( -name '*' ) - case "$language" in - swift-or-c) - exceptions=( -name c_nio_http_parser.c -o -name c_nio_http_parser.h -o -name cpp_magic.h -o -name Package.swift -o -name CNIOSHA1.h -o -name c_nio_sha1.c -o -name ifaddrs-android.c -o -name ifaddrs-android.h -o -name 'Package@swift*.swift') - matching_files=( -name '*.swift' -o -name '*.c' -o -name '*.h' ) - cat > "$tmp" <<"EOF" -//===----------------------------------------------------------------------===// -// -// This source file is part of the SwiftNIO open source project -// -// Copyright (c) YEARS Apple Inc. and the SwiftNIO project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of SwiftNIO project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -EOF - ;; - bash) - matching_files=( -name '*.sh' ) - cat > "$tmp" <<"EOF" -#!/bin/bash -##===----------------------------------------------------------------------===## -## -## This source file is part of the SwiftNIO open source project -## -## Copyright (c) YEARS Apple Inc. and the SwiftNIO project authors -## Licensed under Apache License v2.0 -## -## See LICENSE.txt for license information -## See CONTRIBUTORS.txt for the list of SwiftNIO project authors -## -## SPDX-License-Identifier: Apache-2.0 -## -##===----------------------------------------------------------------------===## -EOF - ;; - dtrace) - matching_files=( -name '*.d' ) - cat > "$tmp" <<"EOF" -#!/usr/sbin/dtrace -q -s -/*===----------------------------------------------------------------------===* - * - * This source file is part of the SwiftNIO open source project - * - * Copyright (c) YEARS Apple Inc. and the SwiftNIO project authors - * Licensed under Apache License v2.0 - * - * See LICENSE.txt for license information - * See CONTRIBUTORS.txt for the list of SwiftNIO project authors - * - * SPDX-License-Identifier: Apache-2.0 - * - *===----------------------------------------------------------------------===*/ -EOF - ;; - *) - echo >&2 "ERROR: unknown language '$language'" - ;; - esac - - expected_lines=$(cat "$tmp" | wc -l) - expected_sha=$(cat "$tmp" | shasum) - - ( - cd "$here/.." - find . \ - \( \! -path './.build/*' -a \ - \( "${matching_files[@]}" \) -a \ - \( \! \( "${exceptions[@]}" \) \) \) | while read line; do - if [[ "$(cat "$line" | replace_acceptable_years | head -n $expected_lines | shasum)" != "$expected_sha" ]]; then - printf "\033[0;31mmissing headers in file '$line'!\033[0m\n" - diff -u <(cat "$line" | replace_acceptable_years | head -n $expected_lines) "$tmp" - exit 1 - fi - done - printf "\033[0;32mokay.\033[0m\n" - ) -done - -rm "$tmp" From 3a47b8e1e72c4ab066e238f5bbcd6f0b16748839 Mon Sep 17 00:00:00 2001 From: Rick Newton-Rogers Date: Fri, 25 Oct 2024 16:32:35 +0100 Subject: [PATCH 2/3] add main.yml --- .github/workflows/main.yml | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 .github/workflows/main.yml diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 00000000..537cc88b --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,17 @@ +name: Main + +on: + push: + branches: [main] + +jobs: + unit-tests: + name: Unit tests + uses: apple/swift-nio/.github/workflows/unit_tests.yml@main + with: + linux_5_8_enabled: false + linux_5_9_arguments_override: "-Xswiftc -warnings-as-errors --explicit-target-dependency-import-check error" + linux_5_10_arguments_override: "-Xswiftc -warnings-as-errors --explicit-target-dependency-import-check error" + linux_6_0_arguments_override: "-Xswiftc -warnings-as-errors --explicit-target-dependency-import-check error -Xswiftc -require-explicit-sendable" + linux_nightly_6_0_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -require-explicit-sendable" + linux_nightly_main_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -require-explicit-sendable" From c198c58dc7ac8558e7860543bd803e48e33c0970 Mon Sep 17 00:00:00 2001 From: Rick Newton-Rogers Date: Mon, 28 Oct 2024 11:07:04 +0000 Subject: [PATCH 3/3] Sendability warnings --- .../HTTPServerWithQuiescingDemo/main.swift | 9 +++--- .../NIOExtras/HTTP1ProxyConnectHandler.swift | 3 +- Sources/NIOWritePartialPCAPDemo/main.swift | 29 ++++++++++--------- Tests/NIOExtrasTests/PCAPRingBufferTest.swift | 23 +++++---------- .../SynchronizedFileSinkTests.swift | 4 +-- 5 files changed, 32 insertions(+), 36 deletions(-) diff --git a/Sources/HTTPServerWithQuiescingDemo/main.swift b/Sources/HTTPServerWithQuiescingDemo/main.swift index 491fc241..f9da3985 100644 --- a/Sources/HTTPServerWithQuiescingDemo/main.swift +++ b/Sources/HTTPServerWithQuiescingDemo/main.swift @@ -31,8 +31,9 @@ private final class HTTPHandler: ChannelInboundHandler { self.wrapOutboundOut(.head(HTTPResponseHead(version: head.version, status: .badRequest))), promise: nil ) + let loopBoundContext = NIOLoopBound.init(context, eventLoop: context.eventLoop) context.writeAndFlush(self.wrapOutboundOut(.end(nil))).whenComplete { (_: Result<(), Error>) in - context.close(promise: nil) + loopBoundContext.value.close(promise: nil) } return } @@ -59,10 +60,10 @@ private final class HTTPHandler: ChannelInboundHandler { context.writeAndFlush(self.wrapOutboundOut(.body(.byteBuffer(buffer))), promise: nil) buffer.clear() buffer.writeStaticString("done with the request now\n") + let loopBoundContext = NIOLoopBound.init(context, eventLoop: context.eventLoop) _ = context.eventLoop.scheduleTask(in: .seconds(30)) { [buffer] in - context.write(self.wrapOutboundOut(.body(.byteBuffer(buffer))), promise: nil) - context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) - + loopBoundContext.value.write(self.wrapOutboundOut(.body(.byteBuffer(buffer))), promise: nil) + loopBoundContext.value.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) } } } diff --git a/Sources/NIOExtras/HTTP1ProxyConnectHandler.swift b/Sources/NIOExtras/HTTP1ProxyConnectHandler.swift index 8ee5fe9c..eb9dc2fd 100644 --- a/Sources/NIOExtras/HTTP1ProxyConnectHandler.swift +++ b/Sources/NIOExtras/HTTP1ProxyConnectHandler.swift @@ -162,13 +162,14 @@ public final class NIOHTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableC return } + let loopBoundContext = NIOLoopBound.init(context, eventLoop: context.eventLoop) let timeout = context.eventLoop.scheduleTask(deadline: self.deadline) { switch self.state { case .initialized: preconditionFailure("How can we have a scheduled timeout, if the connection is not even up?") case .connectSent, .headReceived: - self.failWithError(Error.httpProxyHandshakeTimeout(), context: context) + self.failWithError(Error.httpProxyHandshakeTimeout(), context: loopBoundContext.value) case .failed, .completed: break diff --git a/Sources/NIOWritePartialPCAPDemo/main.swift b/Sources/NIOWritePartialPCAPDemo/main.swift index f7b94774..abcddc9f 100644 --- a/Sources/NIOWritePartialPCAPDemo/main.swift +++ b/Sources/NIOWritePartialPCAPDemo/main.swift @@ -130,21 +130,22 @@ let allDonePromise = group.next().makePromise(of: Void.self) let maximumFragments = 4 let connection = try ClientBootstrap(group: group.next()) .channelInitializer { channel in - let pcapRingBuffer = NIOPCAPRingBuffer( - maximumFragments: maximumFragments, - maximumBytes: 1_000_000 - ) - return channel.pipeline.addHandler( - NIOWritePCAPHandler( - mode: .client, - fileSink: pcapRingBuffer.addFragment + channel.eventLoop.makeCompletedFuture { + let pcapRingBuffer = NIOPCAPRingBuffer( + maximumFragments: maximumFragments, + maximumBytes: 1_000_000 ) - ).flatMap { - channel.pipeline.addHTTPClientHandlers() - }.flatMap { - channel.pipeline.addHandler(TriggerPCAPHandler(pcapRingBuffer: pcapRingBuffer, sink: fileSink.write)) - }.flatMap { - channel.pipeline.addHandler(SendSimpleSequenceRequestHandler(allDonePromise: allDonePromise)) + try channel.pipeline.syncOperations.addHandler( + NIOWritePCAPHandler( + mode: .client, + fileSink: pcapRingBuffer.addFragment + ) + ) + try channel.pipeline.syncOperations.addHTTPClientHandlers() + try channel.pipeline.syncOperations.addHandlers([ + TriggerPCAPHandler(pcapRingBuffer: pcapRingBuffer, sink: fileSink.write), + SendSimpleSequenceRequestHandler(allDonePromise: allDonePromise), + ]) } } .connect(host: "httpbin.org", port: 80) diff --git a/Tests/NIOExtrasTests/PCAPRingBufferTest.swift b/Tests/NIOExtrasTests/PCAPRingBufferTest.swift index 7534fadc..cd4ddff3 100644 --- a/Tests/NIOExtrasTests/PCAPRingBufferTest.swift +++ b/Tests/NIOExtrasTests/PCAPRingBufferTest.swift @@ -204,7 +204,6 @@ class PCAPRingBufferTest: XCTestCase { ) } - let channel = EmbeddedChannel() let trigger = self.dataForTests()[0.. EventLoopFuture in - let triggerHandler = TriggerOnCumulativeSizeHandler( - triggerBytes: trigger, - pcapRingBuffer: pcapRingBuffer, - sink: testRecordedBytes - ) - return channel.pipeline.addHandler(triggerHandler, name: "trigger") - } - XCTAssertNoThrow(try addHandlers.wait()) + let channel = EmbeddedChannel(handlers: [ + NIOWritePCAPHandler(mode: .client, fileSink: pcapRingBuffer.addFragment), + TriggerOnCumulativeSizeHandler( + triggerBytes: trigger, + pcapRingBuffer: pcapRingBuffer, + sink: testRecordedBytes + ), + ]) channel.localAddress = try! SocketAddress(ipAddress: "255.255.255.254", port: Int(UInt16.max) - 1) XCTAssertNoThrow(try channel.connect(to: .init(ipAddress: "1.2.3.4", port: 5678)).wait()) diff --git a/Tests/NIOExtrasTests/SynchronizedFileSinkTests.swift b/Tests/NIOExtrasTests/SynchronizedFileSinkTests.swift index 9a9ce277..5f806322 100644 --- a/Tests/NIOExtrasTests/SynchronizedFileSinkTests.swift +++ b/Tests/NIOExtrasTests/SynchronizedFileSinkTests.swift @@ -67,7 +67,7 @@ private func withTemporaryFile( _ body: (NIOCore.NIOFileHandle, String) throws -> T ) throws -> T { let temporaryFilePath = "\(temporaryDirectory)/nio_extras_\(UUID())" - FileManager.default.createFile(atPath: temporaryFilePath, contents: content?.data(using: .utf8)) + XCTAssertTrue(FileManager.default.createFile(atPath: temporaryFilePath, contents: content?.data(using: .utf8))) defer { XCTAssertNoThrow(try FileManager.default.removeItem(atPath: temporaryFilePath)) } @@ -86,7 +86,7 @@ private func withTemporaryFile( _ body: (NIOCore.NIOFileHandle, String) async throws -> T ) async throws -> T { let temporaryFilePath = "\(temporaryDirectory)/nio_extras_\(UUID())" - FileManager.default.createFile(atPath: temporaryFilePath, contents: content?.data(using: .utf8)) + XCTAssertTrue(FileManager.default.createFile(atPath: temporaryFilePath, contents: content?.data(using: .utf8))) defer { XCTAssertNoThrow(try FileManager.default.removeItem(atPath: temporaryFilePath)) }