Skip to content

Commit

Permalink
Backport #177 to 1.9 (#179)
Browse files Browse the repository at this point in the history
Motivation

Currently we don't confirm that the decompression has completed
successfully. This means that we can incorrectly spin forever attempting
to decompress past the end of a message, and that we can fail to notice
that a message is truncated. Neither of these is good.

Modifications

Propagate the message zlib gives us as to whether or not decompression
is done, and keep track of it.
Add some tests written by @vojtarylko to validate the behaviour.

Result

Correctly police the bounds of the messages.
  • Loading branch information
Lukasa authored Sep 16, 2022
1 parent 29e4c0a commit e800075
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 14 deletions.
56 changes: 47 additions & 9 deletions Sources/NIOHTTPCompression/HTTPDecompression.swift
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,30 @@ public enum NIOHTTPDecompression {
case initializationError(Int)
}

// Would have been public, but this is a backport and cannot add new API.
internal struct ExtraDecompressionError: Error, Hashable, CustomStringConvertible {
private var backing: Backing

private enum Backing {
case invalidTrailingData
case truncatedData
}

private init(_ backing: Backing) {
self.backing = backing
}

/// Decompression completed but there was invalid trailing data behind the compressed data.
static let invalidTrailingData = ExtraDecompressionError(.invalidTrailingData)

/// The decompressed data was incorrectly truncated.
static let truncatedData = ExtraDecompressionError(.truncatedData)

var description: String {
return String(describing: self.backing)
}
}

enum CompressionAlgorithm: String {
case gzip
case deflate
Expand Down Expand Up @@ -86,12 +110,15 @@ public enum NIOHTTPDecompression {
self.limit = limit
}

mutating func decompress(part: inout ByteBuffer, buffer: inout ByteBuffer, compressedLength: Int) throws {
self.inflated += try self.stream.inflatePart(input: &part, output: &buffer)
mutating func decompress(part: inout ByteBuffer, buffer: inout ByteBuffer, compressedLength: Int) throws -> InflateResult {
let result = try self.stream.inflatePart(input: &part, output: &buffer)
self.inflated += result.written

if self.limit.exceeded(compressed: compressedLength, decompressed: self.inflated) {
throw NIOHTTPDecompression.DecompressionError.limit
}

return result
}

mutating func initializeDecoder(encoding: NIOHTTPDecompression.CompressionAlgorithm) throws {
Expand All @@ -112,9 +139,10 @@ public enum NIOHTTPDecompression {
}

extension z_stream {
mutating func inflatePart(input: inout ByteBuffer, output: inout ByteBuffer) throws -> Int {
mutating func inflatePart(input: inout ByteBuffer, output: inout ByteBuffer) throws -> InflateResult {
let minimumCapacity = input.readableBytes * 2
var written = 0
var inflateResult = InflateResult(written: 0, complete: false)

try input.readWithUnsafeMutableReadableBytes { pointer in
self.avail_in = UInt32(pointer.count)
self.next_in = CNIOExtrasZlib_voidPtr_to_BytefPtr(pointer.baseAddress!)
Expand All @@ -126,24 +154,34 @@ extension z_stream {
self.next_out = nil
}

written += try self.inflatePart(to: &output, minimumCapacity: minimumCapacity)
inflateResult = try self.inflatePart(to: &output, minimumCapacity: minimumCapacity)

return pointer.count - Int(self.avail_in)
}
return written
return inflateResult
}

private mutating func inflatePart(to buffer: inout ByteBuffer, minimumCapacity: Int) throws -> Int {
return try buffer.writeWithUnsafeMutableBytes(minimumWritableBytes: minimumCapacity) { pointer in
private mutating func inflatePart(to buffer: inout ByteBuffer, minimumCapacity: Int) throws -> InflateResult {
var rc = Z_OK

let written = try buffer.writeWithUnsafeMutableBytes(minimumWritableBytes: minimumCapacity) { pointer in
self.avail_out = UInt32(pointer.count)
self.next_out = CNIOExtrasZlib_voidPtr_to_BytefPtr(pointer.baseAddress!)

let rc = inflate(&self, Z_NO_FLUSH)
rc = inflate(&self, Z_NO_FLUSH)
guard rc == Z_OK || rc == Z_STREAM_END else {
throw NIOHTTPDecompression.DecompressionError.inflationError(Int(rc))
}

return pointer.count - Int(self.avail_out)
}

return InflateResult(written: written, complete: rc == Z_STREAM_END)
}
}

struct InflateResult {
var written: Int

var complete: Bool
}
20 changes: 18 additions & 2 deletions Sources/NIOHTTPCompression/HTTPRequestDecompressor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@ public final class NIOHTTPRequestDecompressor: ChannelDuplexHandler, RemovableCh

private var decompressor: NIOHTTPDecompression.Decompressor
private var compression: Compression?
private var decompressionComplete: Bool

public init(limit: NIOHTTPDecompression.DecompressionLimit) {
self.decompressor = NIOHTTPDecompression.Decompressor(limit: limit)
self.compression = nil
self.decompressionComplete = false
}

public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
Expand Down Expand Up @@ -61,21 +63,35 @@ public final class NIOHTTPRequestDecompressor: ChannelDuplexHandler, RemovableCh
return
}

while part.readableBytes > 0 {
while part.readableBytes > 0 && !self.decompressionComplete {
do {
var buffer = context.channel.allocator.buffer(capacity: 16384)
try self.decompressor.decompress(part: &part, buffer: &buffer, compressedLength: compression.contentLength)
let result = try self.decompressor.decompress(part: &part, buffer: &buffer, compressedLength: compression.contentLength)
if result.complete {
self.decompressionComplete = true
}

context.fireChannelRead(self.wrapInboundOut(.body(buffer)))
} catch let error {
context.fireErrorCaught(error)
return
}
}

if part.readableBytes > 0 {
context.fireErrorCaught(NIOHTTPDecompression.ExtraDecompressionError.invalidTrailingData)
}
case .end:
if self.compression != nil {
let wasDecompressionComplete = self.decompressionComplete

self.decompressor.deinitializeDecoder()
self.compression = nil
self.decompressionComplete = false

if !wasDecompressionComplete {
context.fireErrorCaught(NIOHTTPDecompression.ExtraDecompressionError.truncatedData)
}
}

context.fireChannelRead(data)
Expand Down
20 changes: 18 additions & 2 deletions Sources/NIOHTTPCompression/HTTPResponseDecompressor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@ public final class NIOHTTPResponseDecompressor: ChannelDuplexHandler, RemovableC

private var compression: Compression? = nil
private var decompressor: NIOHTTPDecompression.Decompressor
private var decompressionComplete: Bool

public init(limit: NIOHTTPDecompression.DecompressionLimit) {
self.decompressor = NIOHTTPDecompression.Decompressor(limit: limit)
self.decompressionComplete = false
}

public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
Expand Down Expand Up @@ -77,22 +79,36 @@ public final class NIOHTTPResponseDecompressor: ChannelDuplexHandler, RemovableC

do {
compression.compressedLength += part.readableBytes
while part.readableBytes > 0 {
while part.readableBytes > 0 && !self.decompressionComplete {
var buffer = context.channel.allocator.buffer(capacity: 16384)
try self.decompressor.decompress(part: &part, buffer: &buffer, compressedLength: compression.compressedLength)
let result = try self.decompressor.decompress(part: &part, buffer: &buffer, compressedLength: compression.compressedLength)
if result.complete {
self.decompressionComplete = true
}
context.fireChannelRead(self.wrapInboundOut(.body(buffer)))
}

// assign the changed local property back to the class state
self.compression = compression

if part.readableBytes > 0 {
context.fireErrorCaught(NIOHTTPDecompression.ExtraDecompressionError.invalidTrailingData)
}
}
catch {
context.fireErrorCaught(error)
}
case .end:
if self.compression != nil {
let wasDecompressionComplete = self.decompressionComplete

self.decompressor.deinitializeDecoder()
self.compression = nil
self.decompressionComplete = false

if !wasDecompressionComplete {
context.fireErrorCaught(NIOHTTPDecompression.ExtraDecompressionError.truncatedData)
}
}
context.fireChannelRead(data)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ extension HTTPRequestDecompressorTest {
("testDecompressionLimitRatio", testDecompressionLimitRatio),
("testDecompressionLimitSize", testDecompressionLimitSize),
("testDecompression", testDecompression),
("testDecompressionTrailingData", testDecompressionTrailingData),
("testDecompressionTruncatedInput", testDecompressionTruncatedInput),
]
}
}
Expand Down
26 changes: 25 additions & 1 deletion Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,33 @@ class HTTPRequestDecompressorTest: XCTestCase {
)

XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.body(compressed)))
XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.end(nil)))
}
}

func testDecompressionTrailingData() throws {
// Valid compressed data with some trailing garbage
let compressed = ByteBuffer(bytes: [120, 156, 99, 0, 0, 0, 1, 0, 1] + [1, 2, 3])

let channel = EmbeddedChannel()
try channel.pipeline.addHandler(NIOHTTPRequestDecompressor(limit: .none)).wait()
let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "\(compressed.readableBytes)")])
try channel.writeInbound(HTTPServerRequestPart.head(.init(version: .init(major: 1, minor: 1), method: .POST, uri: "https://nio.swift.org/test", headers: headers)))

XCTAssertThrowsError(try channel.writeInbound(HTTPServerRequestPart.body(compressed)))
}

func testDecompressionTruncatedInput() throws {
// Truncated compressed data
let compressed = ByteBuffer(bytes: [120, 156, 99, 0])

XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.end(nil)))
let channel = EmbeddedChannel()
try channel.pipeline.addHandler(NIOHTTPRequestDecompressor(limit: .none)).wait()
let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "\(compressed.readableBytes)")])
try channel.writeInbound(HTTPServerRequestPart.head(.init(version: .init(major: 1, minor: 1), method: .POST, uri: "https://nio.swift.org/test", headers: headers)))

XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.body(compressed)))
XCTAssertThrowsError(try channel.writeInbound(HTTPServerRequestPart.end(nil)))
}

private func compress(_ body: ByteBuffer, _ algorithm: String) -> ByteBuffer {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ extension HTTPResponseDecompressorTest {
("testDecompressionLimitRatioWithoutContentLenghtHeaderFails", testDecompressionLimitRatioWithoutContentLenghtHeaderFails),
("testDecompression", testDecompression),
("testDecompressionWithoutContentLength", testDecompressionWithoutContentLength),
("testDecompressionTrailingData", testDecompressionTrailingData),
("testDecompressionTruncatedInput", testDecompressionTruncatedInput),
]
}
}
Expand Down
25 changes: 25 additions & 0 deletions Tests/NIOHTTPCompressionTests/HTTPResponseDecompressorTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,31 @@ class HTTPResponseDecompressorTest: XCTestCase {
XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.end(nil)))
}

func testDecompressionTrailingData() throws {
// Valid compressed data with some trailing garbage
let compressed = ByteBuffer(bytes: [120, 156, 99, 0, 0, 0, 1, 0, 1] + [1, 2, 3])

let channel = EmbeddedChannel()
try channel.pipeline.addHandler(NIOHTTPResponseDecompressor(limit: .none)).wait()
let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "\(compressed.readableBytes)")])
try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)))

XCTAssertThrowsError(try channel.writeInbound(HTTPClientResponsePart.body(compressed)))
}

func testDecompressionTruncatedInput() throws {
// Truncated compressed data
let compressed = ByteBuffer(bytes: [120, 156, 99, 0])

let channel = EmbeddedChannel()
try channel.pipeline.addHandler(NIOHTTPResponseDecompressor(limit: .none)).wait()
let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "\(compressed.readableBytes)")])
try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)))

XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.body(compressed)))
XCTAssertThrowsError(try channel.writeInbound(HTTPClientResponsePart.end(nil)))
}

private func compress(_ body: ByteBuffer, _ algorithm: String) -> ByteBuffer {
var stream = z_stream()

Expand Down

0 comments on commit e800075

Please sign in to comment.