From e8000754be4dc63fea88f121bb8f1703f4a6a5af Mon Sep 17 00:00:00 2001 From: Cory Benfield Date: Fri, 16 Sep 2022 10:30:05 +0200 Subject: [PATCH] Backport #177 to 1.9 (#179) Motivation Currently we don't confirm that the decompression has completed successfully. This means that we can incorrectly spin forever attempting to decompress past the end of a message, and that we can fail to notice that a message is truncated. Neither of these is good. Modifications Propagate the message zlib gives us as to whether or not decompression is done, and keep track of it. Add some tests written by @vojtarylko to validate the behaviour. Result Correctly police the bounds of the messages. --- .../HTTPDecompression.swift | 56 ++++++++++++++++--- .../HTTPRequestDecompressor.swift | 20 ++++++- .../HTTPResponseDecompressor.swift | 20 ++++++- .../HTTPRequestDecompressorTest+XCTest.swift | 2 + .../HTTPRequestDecompressorTest.swift | 26 ++++++++- .../HTTPResponseDecompressorTest+XCTest.swift | 2 + .../HTTPResponseDecompressorTest.swift | 25 +++++++++ 7 files changed, 137 insertions(+), 14 deletions(-) diff --git a/Sources/NIOHTTPCompression/HTTPDecompression.swift b/Sources/NIOHTTPCompression/HTTPDecompression.swift index 2568ac22..459bc288 100644 --- a/Sources/NIOHTTPCompression/HTTPDecompression.swift +++ b/Sources/NIOHTTPCompression/HTTPDecompression.swift @@ -52,6 +52,30 @@ public enum NIOHTTPDecompression { case initializationError(Int) } + // Would have been public, but this is a backport and cannot add new API. + internal struct ExtraDecompressionError: Error, Hashable, CustomStringConvertible { + private var backing: Backing + + private enum Backing { + case invalidTrailingData + case truncatedData + } + + private init(_ backing: Backing) { + self.backing = backing + } + + /// Decompression completed but there was invalid trailing data behind the compressed data. + static let invalidTrailingData = ExtraDecompressionError(.invalidTrailingData) + + /// The decompressed data was incorrectly truncated. + static let truncatedData = ExtraDecompressionError(.truncatedData) + + var description: String { + return String(describing: self.backing) + } + } + enum CompressionAlgorithm: String { case gzip case deflate @@ -86,12 +110,15 @@ public enum NIOHTTPDecompression { self.limit = limit } - mutating func decompress(part: inout ByteBuffer, buffer: inout ByteBuffer, compressedLength: Int) throws { - self.inflated += try self.stream.inflatePart(input: &part, output: &buffer) + mutating func decompress(part: inout ByteBuffer, buffer: inout ByteBuffer, compressedLength: Int) throws -> InflateResult { + let result = try self.stream.inflatePart(input: &part, output: &buffer) + self.inflated += result.written if self.limit.exceeded(compressed: compressedLength, decompressed: self.inflated) { throw NIOHTTPDecompression.DecompressionError.limit } + + return result } mutating func initializeDecoder(encoding: NIOHTTPDecompression.CompressionAlgorithm) throws { @@ -112,9 +139,10 @@ public enum NIOHTTPDecompression { } extension z_stream { - mutating func inflatePart(input: inout ByteBuffer, output: inout ByteBuffer) throws -> Int { + mutating func inflatePart(input: inout ByteBuffer, output: inout ByteBuffer) throws -> InflateResult { let minimumCapacity = input.readableBytes * 2 - var written = 0 + var inflateResult = InflateResult(written: 0, complete: false) + try input.readWithUnsafeMutableReadableBytes { pointer in self.avail_in = UInt32(pointer.count) self.next_in = CNIOExtrasZlib_voidPtr_to_BytefPtr(pointer.baseAddress!) @@ -126,24 +154,34 @@ extension z_stream { self.next_out = nil } - written += try self.inflatePart(to: &output, minimumCapacity: minimumCapacity) + inflateResult = try self.inflatePart(to: &output, minimumCapacity: minimumCapacity) return pointer.count - Int(self.avail_in) } - return written + return inflateResult } - private mutating func inflatePart(to buffer: inout ByteBuffer, minimumCapacity: Int) throws -> Int { - return try buffer.writeWithUnsafeMutableBytes(minimumWritableBytes: minimumCapacity) { pointer in + private mutating func inflatePart(to buffer: inout ByteBuffer, minimumCapacity: Int) throws -> InflateResult { + var rc = Z_OK + + let written = try buffer.writeWithUnsafeMutableBytes(minimumWritableBytes: minimumCapacity) { pointer in self.avail_out = UInt32(pointer.count) self.next_out = CNIOExtrasZlib_voidPtr_to_BytefPtr(pointer.baseAddress!) - let rc = inflate(&self, Z_NO_FLUSH) + rc = inflate(&self, Z_NO_FLUSH) guard rc == Z_OK || rc == Z_STREAM_END else { throw NIOHTTPDecompression.DecompressionError.inflationError(Int(rc)) } return pointer.count - Int(self.avail_out) } + + return InflateResult(written: written, complete: rc == Z_STREAM_END) } } + +struct InflateResult { + var written: Int + + var complete: Bool +} diff --git a/Sources/NIOHTTPCompression/HTTPRequestDecompressor.swift b/Sources/NIOHTTPCompression/HTTPRequestDecompressor.swift index 9336975c..bf7951a4 100644 --- a/Sources/NIOHTTPCompression/HTTPRequestDecompressor.swift +++ b/Sources/NIOHTTPCompression/HTTPRequestDecompressor.swift @@ -29,10 +29,12 @@ public final class NIOHTTPRequestDecompressor: ChannelDuplexHandler, RemovableCh private var decompressor: NIOHTTPDecompression.Decompressor private var compression: Compression? + private var decompressionComplete: Bool public init(limit: NIOHTTPDecompression.DecompressionLimit) { self.decompressor = NIOHTTPDecompression.Decompressor(limit: limit) self.compression = nil + self.decompressionComplete = false } public func channelRead(context: ChannelHandlerContext, data: NIOAny) { @@ -61,10 +63,13 @@ public final class NIOHTTPRequestDecompressor: ChannelDuplexHandler, RemovableCh return } - while part.readableBytes > 0 { + while part.readableBytes > 0 && !self.decompressionComplete { do { var buffer = context.channel.allocator.buffer(capacity: 16384) - try self.decompressor.decompress(part: &part, buffer: &buffer, compressedLength: compression.contentLength) + let result = try self.decompressor.decompress(part: &part, buffer: &buffer, compressedLength: compression.contentLength) + if result.complete { + self.decompressionComplete = true + } context.fireChannelRead(self.wrapInboundOut(.body(buffer))) } catch let error { @@ -72,10 +77,21 @@ public final class NIOHTTPRequestDecompressor: ChannelDuplexHandler, RemovableCh return } } + + if part.readableBytes > 0 { + context.fireErrorCaught(NIOHTTPDecompression.ExtraDecompressionError.invalidTrailingData) + } case .end: if self.compression != nil { + let wasDecompressionComplete = self.decompressionComplete + self.decompressor.deinitializeDecoder() self.compression = nil + self.decompressionComplete = false + + if !wasDecompressionComplete { + context.fireErrorCaught(NIOHTTPDecompression.ExtraDecompressionError.truncatedData) + } } context.fireChannelRead(data) diff --git a/Sources/NIOHTTPCompression/HTTPResponseDecompressor.swift b/Sources/NIOHTTPCompression/HTTPResponseDecompressor.swift index de6a9834..a98d2d58 100644 --- a/Sources/NIOHTTPCompression/HTTPResponseDecompressor.swift +++ b/Sources/NIOHTTPCompression/HTTPResponseDecompressor.swift @@ -33,9 +33,11 @@ public final class NIOHTTPResponseDecompressor: ChannelDuplexHandler, RemovableC private var compression: Compression? = nil private var decompressor: NIOHTTPDecompression.Decompressor + private var decompressionComplete: Bool public init(limit: NIOHTTPDecompression.DecompressionLimit) { self.decompressor = NIOHTTPDecompression.Decompressor(limit: limit) + self.decompressionComplete = false } public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { @@ -77,22 +79,36 @@ public final class NIOHTTPResponseDecompressor: ChannelDuplexHandler, RemovableC do { compression.compressedLength += part.readableBytes - while part.readableBytes > 0 { + while part.readableBytes > 0 && !self.decompressionComplete { var buffer = context.channel.allocator.buffer(capacity: 16384) - try self.decompressor.decompress(part: &part, buffer: &buffer, compressedLength: compression.compressedLength) + let result = try self.decompressor.decompress(part: &part, buffer: &buffer, compressedLength: compression.compressedLength) + if result.complete { + self.decompressionComplete = true + } context.fireChannelRead(self.wrapInboundOut(.body(buffer))) } // assign the changed local property back to the class state self.compression = compression + + if part.readableBytes > 0 { + context.fireErrorCaught(NIOHTTPDecompression.ExtraDecompressionError.invalidTrailingData) + } } catch { context.fireErrorCaught(error) } case .end: if self.compression != nil { + let wasDecompressionComplete = self.decompressionComplete + self.decompressor.deinitializeDecoder() self.compression = nil + self.decompressionComplete = false + + if !wasDecompressionComplete { + context.fireErrorCaught(NIOHTTPDecompression.ExtraDecompressionError.truncatedData) + } } context.fireChannelRead(data) } diff --git a/Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest+XCTest.swift b/Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest+XCTest.swift index 12649957..04a31a9c 100644 --- a/Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest+XCTest.swift +++ b/Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest+XCTest.swift @@ -30,6 +30,8 @@ extension HTTPRequestDecompressorTest { ("testDecompressionLimitRatio", testDecompressionLimitRatio), ("testDecompressionLimitSize", testDecompressionLimitSize), ("testDecompression", testDecompression), + ("testDecompressionTrailingData", testDecompressionTrailingData), + ("testDecompressionTruncatedInput", testDecompressionTruncatedInput), ] } } diff --git a/Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest.swift b/Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest.swift index d6733562..e6943840 100644 --- a/Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest.swift +++ b/Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest.swift @@ -119,9 +119,33 @@ class HTTPRequestDecompressorTest: XCTestCase { ) XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.body(compressed))) + XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.end(nil))) } + } + + func testDecompressionTrailingData() throws { + // Valid compressed data with some trailing garbage + let compressed = ByteBuffer(bytes: [120, 156, 99, 0, 0, 0, 1, 0, 1] + [1, 2, 3]) + + let channel = EmbeddedChannel() + try channel.pipeline.addHandler(NIOHTTPRequestDecompressor(limit: .none)).wait() + let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "\(compressed.readableBytes)")]) + try channel.writeInbound(HTTPServerRequestPart.head(.init(version: .init(major: 1, minor: 1), method: .POST, uri: "https://nio.swift.org/test", headers: headers))) + + XCTAssertThrowsError(try channel.writeInbound(HTTPServerRequestPart.body(compressed))) + } + + func testDecompressionTruncatedInput() throws { + // Truncated compressed data + let compressed = ByteBuffer(bytes: [120, 156, 99, 0]) - XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.end(nil))) + let channel = EmbeddedChannel() + try channel.pipeline.addHandler(NIOHTTPRequestDecompressor(limit: .none)).wait() + let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "\(compressed.readableBytes)")]) + try channel.writeInbound(HTTPServerRequestPart.head(.init(version: .init(major: 1, minor: 1), method: .POST, uri: "https://nio.swift.org/test", headers: headers))) + + XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.body(compressed))) + XCTAssertThrowsError(try channel.writeInbound(HTTPServerRequestPart.end(nil))) } private func compress(_ body: ByteBuffer, _ algorithm: String) -> ByteBuffer { diff --git a/Tests/NIOHTTPCompressionTests/HTTPResponseDecompressorTest+XCTest.swift b/Tests/NIOHTTPCompressionTests/HTTPResponseDecompressorTest+XCTest.swift index 30d6d459..9f844d11 100644 --- a/Tests/NIOHTTPCompressionTests/HTTPResponseDecompressorTest+XCTest.swift +++ b/Tests/NIOHTTPCompressionTests/HTTPResponseDecompressorTest+XCTest.swift @@ -37,6 +37,8 @@ extension HTTPResponseDecompressorTest { ("testDecompressionLimitRatioWithoutContentLenghtHeaderFails", testDecompressionLimitRatioWithoutContentLenghtHeaderFails), ("testDecompression", testDecompression), ("testDecompressionWithoutContentLength", testDecompressionWithoutContentLength), + ("testDecompressionTrailingData", testDecompressionTrailingData), + ("testDecompressionTruncatedInput", testDecompressionTruncatedInput), ] } } diff --git a/Tests/NIOHTTPCompressionTests/HTTPResponseDecompressorTest.swift b/Tests/NIOHTTPCompressionTests/HTTPResponseDecompressorTest.swift index 389a049b..a0a33df2 100644 --- a/Tests/NIOHTTPCompressionTests/HTTPResponseDecompressorTest.swift +++ b/Tests/NIOHTTPCompressionTests/HTTPResponseDecompressorTest.swift @@ -238,6 +238,31 @@ class HTTPResponseDecompressorTest: XCTestCase { XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.end(nil))) } + func testDecompressionTrailingData() throws { + // Valid compressed data with some trailing garbage + let compressed = ByteBuffer(bytes: [120, 156, 99, 0, 0, 0, 1, 0, 1] + [1, 2, 3]) + + let channel = EmbeddedChannel() + try channel.pipeline.addHandler(NIOHTTPResponseDecompressor(limit: .none)).wait() + let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "\(compressed.readableBytes)")]) + try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers))) + + XCTAssertThrowsError(try channel.writeInbound(HTTPClientResponsePart.body(compressed))) + } + + func testDecompressionTruncatedInput() throws { + // Truncated compressed data + let compressed = ByteBuffer(bytes: [120, 156, 99, 0]) + + let channel = EmbeddedChannel() + try channel.pipeline.addHandler(NIOHTTPResponseDecompressor(limit: .none)).wait() + let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "\(compressed.readableBytes)")]) + try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers))) + + XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.body(compressed))) + XCTAssertThrowsError(try channel.writeInbound(HTTPClientResponsePart.end(nil))) + } + private func compress(_ body: ByteBuffer, _ algorithm: String) -> ByteBuffer { var stream = z_stream()