diff --git a/Sources/NIOExtras/LengthFieldBasedFrameDecoder.swift b/Sources/NIOExtras/LengthFieldBasedFrameDecoder.swift index 4117015d..1ee82a69 100644 --- a/Sources/NIOExtras/LengthFieldBasedFrameDecoder.swift +++ b/Sources/NIOExtras/LengthFieldBasedFrameDecoder.swift @@ -46,6 +46,13 @@ extension ByteBuffer { } } +public enum NIOLengthFieldBasedFrameDecoderError: Error { + /// This error can be thrown by `LengthFieldBasedFrameDecoder` if the length field value is larger than `Int.max` + case lengthFieldValueTooLarge + /// This error can be thrown by `LengthFieldBasedFrameDecoder` if the length field value is larger than `LengthFieldBasedFrameDecoder.maxSupportedLengthFieldSize` + case lengthFieldValueLargerThanMaxSupportedSize +} + /// /// A decoder that splits the received `ByteBuffer` by the number of bytes specified in a fixed length header /// contained within the buffer. @@ -65,6 +72,8 @@ extension ByteBuffer { /// 'A' and 'E' will be the headers and will not be passed forward. /// public final class LengthFieldBasedFrameDecoder: ByteToMessageDecoder { + /// Maximum supported length field size in bytes of `LengthFieldBasedFrameDecoder` and is currently `Int32.max` + public static let maxSupportedLengthFieldSize: Int = Int(Int32.max) /// /// An enumeration to describe the length of a piece of data in bytes. /// @@ -120,11 +129,6 @@ public final class LengthFieldBasedFrameDecoder: ByteToMessageDecoder { /// - lengthFieldEndianness: The endianness of the field specifying the remaining length of the frame. /// public init(lengthFieldBitLength: NIOLengthFieldBitLength, lengthFieldEndianness: Endianness = .big) { - - // The value contained in the length field must be able to be represented by an integer type on the platform. - // ie. .eight == 64bit which would not fit into the Int type on a 32bit platform. - precondition(lengthFieldBitLength.length <= Int.bitWidth/8) - self.lengthFieldLength = lengthFieldBitLength self.lengthFieldEndianness = lengthFieldEndianness } @@ -166,7 +170,7 @@ public final class LengthFieldBasedFrameDecoder: ByteToMessageDecoder { private func readNextLengthFieldToState(buffer: inout ByteBuffer) throws { // Convert the length field to an integer specifying the length - guard let lengthFieldValue = self.readFrameLength(for: &buffer) else { + guard let lengthFieldValue = try self.readFrameLength(for: &buffer) else { return } @@ -198,19 +202,35 @@ public final class LengthFieldBasedFrameDecoder: ByteToMessageDecoder { /// - parameters: /// - buffer: The buffer containing the integer frame length. /// - private func readFrameLength(for buffer: inout ByteBuffer) -> Int? { - + private func readFrameLength(for buffer: inout ByteBuffer) throws -> Int? { + let frameLength: Int? switch self.lengthFieldLength.bitLength { case .bits8: - return buffer.readInteger(endianness: self.lengthFieldEndianness, as: UInt8.self).map { Int($0) } + frameLength = buffer.readInteger(endianness: self.lengthFieldEndianness, as: UInt8.self).map { Int($0) } case .bits16: - return buffer.readInteger(endianness: self.lengthFieldEndianness, as: UInt16.self).map { Int($0) } + frameLength = buffer.readInteger(endianness: self.lengthFieldEndianness, as: UInt16.self).map { Int($0) } case .bits24: - return buffer.read24UInt(endianness: self.lengthFieldEndianness).map { Int($0) } + frameLength = buffer.read24UInt(endianness: self.lengthFieldEndianness).map { Int($0) } case .bits32: - return buffer.readInteger(endianness: self.lengthFieldEndianness, as: UInt32.self).map { Int($0) } + frameLength = try buffer.readInteger(endianness: self.lengthFieldEndianness, as: UInt32.self).map { + guard let size = Int(exactly: $0) else { + throw NIOLengthFieldBasedFrameDecoderError.lengthFieldValueTooLarge + } + return size + } case .bits64: - return buffer.readInteger(endianness: self.lengthFieldEndianness, as: UInt64.self).map { Int($0) } + frameLength = try buffer.readInteger(endianness: self.lengthFieldEndianness, as: UInt64.self).map { + guard let size = Int(exactly: $0) else { + throw NIOLengthFieldBasedFrameDecoderError.lengthFieldValueTooLarge + } + return size + } + } + + if let frameLength = frameLength, + frameLength > LengthFieldBasedFrameDecoder.maxSupportedLengthFieldSize { + throw NIOLengthFieldBasedFrameDecoderError.lengthFieldValueLargerThanMaxSupportedSize } + return frameLength } } diff --git a/Tests/NIOExtrasTests/LengthFieldBasedFrameDecoderTest+XCTest.swift b/Tests/NIOExtrasTests/LengthFieldBasedFrameDecoderTest+XCTest.swift index dd48cef1..f576f286 100644 --- a/Tests/NIOExtrasTests/LengthFieldBasedFrameDecoderTest+XCTest.swift +++ b/Tests/NIOExtrasTests/LengthFieldBasedFrameDecoderTest+XCTest.swift @@ -45,6 +45,10 @@ extension LengthFieldBasedFrameDecoderTest { ("testRemoveHandlerWhenBufferIsNotEmpty", testRemoveHandlerWhenBufferIsNotEmpty), ("testCloseInChannelRead", testCloseInChannelRead), ("testBasicVerification", testBasicVerification), + ("testMaximumAllowedLengthWith32BitFieldLength", testMaximumAllowedLengthWith32BitFieldLength), + ("testMaliciousLengthWith32BitFieldLength", testMaliciousLengthWith32BitFieldLength), + ("testMaximumAllowedLengthWith64BitFieldLength", testMaximumAllowedLengthWith64BitFieldLength), + ("testMaliciousLengthWith64BitFieldLength", testMaliciousLengthWith64BitFieldLength), ] } } diff --git a/Tests/NIOExtrasTests/LengthFieldBasedFrameDecoderTest.swift b/Tests/NIOExtrasTests/LengthFieldBasedFrameDecoderTest.swift index 7cfb0fbc..886f411b 100644 --- a/Tests/NIOExtrasTests/LengthFieldBasedFrameDecoderTest.swift +++ b/Tests/NIOExtrasTests/LengthFieldBasedFrameDecoderTest.swift @@ -513,4 +513,58 @@ class LengthFieldBasedFrameDecoderTest: XCTestCase { }) } } + func testMaximumAllowedLengthWith32BitFieldLength() throws { + self.decoderUnderTest = .init(LengthFieldBasedFrameDecoder(lengthFieldLength: .four, + lengthFieldEndianness: .little)) + XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.decoderUnderTest).wait()) + + let dataLength = UInt32(Int32.max) + + var buffer = self.channel.allocator.buffer(capacity: 4) // 4 byte header + buffer.writeInteger(dataLength, endianness: .little, as: UInt32.self) + buffer.writeString(standardDataString) + + XCTAssertNoThrow(try self.channel.writeInbound(buffer)) + } + + func testMaliciousLengthWith32BitFieldLength() throws { + self.decoderUnderTest = .init(LengthFieldBasedFrameDecoder(lengthFieldLength: .four, + lengthFieldEndianness: .little)) + XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.decoderUnderTest).wait()) + + let dataLength = UInt32(Int32.max) + 1 + + var buffer = self.channel.allocator.buffer(capacity: 4) // 4 byte header + buffer.writeInteger(dataLength, endianness: .little, as: UInt32.self) + buffer.writeString(standardDataString) + + XCTAssertThrowsError(try self.channel.writeInbound(buffer)) + } + + func testMaximumAllowedLengthWith64BitFieldLength() throws { + self.decoderUnderTest = .init(LengthFieldBasedFrameDecoder(lengthFieldLength: .eight, + lengthFieldEndianness: .little)) + XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.decoderUnderTest).wait()) + + let dataLength = UInt64(Int32.max) + + var buffer = self.channel.allocator.buffer(capacity: 8) // 8 byte header + buffer.writeInteger(dataLength, endianness: .little, as: UInt64.self) + buffer.writeString(standardDataString) + + XCTAssertNoThrow(try self.channel.writeInbound(buffer)) + } + + func testMaliciousLengthWith64BitFieldLength() { + self.decoderUnderTest = .init(LengthFieldBasedFrameDecoder(lengthFieldLength: .eight, + lengthFieldEndianness: .little)) + XCTAssertNoThrow(try self.channel.pipeline.addHandler(self.decoderUnderTest).wait()) + + let dataLength = UInt64(Int32.max) + 1 + + var buffer = self.channel.allocator.buffer(capacity: 8) // 8 byte header + buffer.writeInteger(dataLength, endianness: .little, as: UInt64.self) + + XCTAssertThrowsError(try self.channel.writeInbound(buffer)) + } }