From ca622e6f112dc5fdf352007367420d58c834df8d Mon Sep 17 00:00:00 2001 From: Cory Benfield Date: Tue, 28 Jul 2020 18:26:58 +0100 Subject: [PATCH] Refactor the codecs for extension. Motivation: As part of the work in #214 we're going to need to update the HTTP2ToHTTP1 codecs. These need to be replaced for the new channel pipelines. The core of the logic will be identical in both cases, so let's start by factoring that logic out into some nice standalone objects that we can reuse. Modifications: - Pull out the base codecs into structures. - Rewrite the main codecs in terms of these new structures. Result: Easier extension points. --- Sources/NIOHTTP2/HTTP2ToHTTP1Codec.swift | 274 ++++++++++++++--------- 1 file changed, 166 insertions(+), 108 deletions(-) diff --git a/Sources/NIOHTTP2/HTTP2ToHTTP1Codec.swift b/Sources/NIOHTTP2/HTTP2ToHTTP1Codec.swift index befea1ec..5f65acbb 100644 --- a/Sources/NIOHTTP2/HTTP2ToHTTP1Codec.swift +++ b/Sources/NIOHTTP2/HTTP2ToHTTP1Codec.swift @@ -17,6 +17,84 @@ import NIOHTTP1 import NIOHPACK +fileprivate struct BaseClientCodec { + private let protocolString: String + private let normalizeHTTPHeaders: Bool + + private var headerStateMachine: HTTP2HeadersStateMachine = HTTP2HeadersStateMachine(mode: .client) + + /// Initializes a `BaseClientCodec`. + /// + /// - parameters: + /// - httpProtocol: The protocol (usually `"http"` or `"https"` that is used). + /// - normalizeHTTPHeaders: Whether to automatically normalize the HTTP headers to be suitable for HTTP/2. + /// The normalization will for example lower-case all heder names (as required by the + /// HTTP/2 spec) and remove headers that are unsuitable for HTTP/2 such as + /// headers related to HTTP/1's keep-alive behaviour. Unless you are sure that all your + /// headers conform to the HTTP/2 spec, you should leave this parameter set to `true`. + fileprivate init(httpProtocol: HTTP2ToHTTP1ClientCodec.HTTPProtocol, normalizeHTTPHeaders: Bool) { + self.normalizeHTTPHeaders = normalizeHTTPHeaders + + switch httpProtocol { + case .http: + self.protocolString = "http" + case .https: + self.protocolString = "https" + } + } + + mutating func processInboundData(_ data: HTTP2Frame.FramePayload) throws -> (first: HTTPClientResponsePart?, second: HTTPClientResponsePart?) { + switch data { + case .headers(let headerContent): + if case .trailer = try self.headerStateMachine.newHeaders(block: headerContent.headers) { + return (first: .end(HTTPHeaders(regularHeadersFrom: headerContent.headers)), second: nil) + } else { + let respHead = try HTTPResponseHead(http2HeaderBlock: headerContent.headers) + let first = HTTPClientResponsePart.head(respHead) + var second: HTTPClientResponsePart? = nil + if headerContent.endStream { + second = .end(nil) + } + return (first: first, second: second) + } + case .data(let content): + guard case .byteBuffer(let b) = content.data else { + preconditionFailure("Received DATA frame with non-bytebuffer IOData") + } + + let first = HTTPClientResponsePart.body(b) + var second: HTTPClientResponsePart? = nil + if content.endStream { + second = .end(nil) + } + return (first: first, second: second) + case .alternativeService, .rstStream, .priority, .windowUpdate, .settings, .pushPromise, .ping, .goAway, .origin: + // These don't have an HTTP/1 equivalent, so let's drop them. + return (first: nil, second: nil) + } + } + + mutating func processOutboundData(_ data: HTTPClientRequestPart, allocator: ByteBufferAllocator) throws -> HTTP2Frame.FramePayload { + switch data { + case .head(let head): + let h1Headers = try HTTPHeaders(requestHead: head, protocolString: self.protocolString) + let headerContent = HTTP2Frame.FramePayload.Headers(headers: HPACKHeaders(httpHeaders: h1Headers, + normalizeHTTPHeaders: self.normalizeHTTPHeaders)) + return .headers(headerContent) + case .body(let body): + return .data(HTTP2Frame.FramePayload.Data(data: body)) + case .end(let trailers): + if let trailers = trailers { + return .headers(.init(headers: HPACKHeaders(httpHeaders: trailers, + normalizeHTTPHeaders: self.normalizeHTTPHeaders), + endStream: true)) + } else { + return .data(.init(data: .byteBuffer(allocator.buffer(capacity: 0)), endStream: true)) + } + } + } +} + /// A simple channel handler that translates HTTP/2 concepts into HTTP/1 data types, /// and vice versa, for use on the client side. /// @@ -37,10 +115,7 @@ public final class HTTP2ToHTTP1ClientCodec: ChannelInboundHandler, ChannelOutbou } private let streamID: HTTP2StreamID - private let protocolString: String - private let normalizeHTTPHeaders: Bool - - private var headerStateMachine: HTTP2HeadersStateMachine = HTTP2HeadersStateMachine(mode: .client) + private var baseCodec: BaseClientCodec /// Initializes a `HTTP2ToHTTP1ClientCodec` for the given `HTTP2StreamID`. /// @@ -54,14 +129,7 @@ public final class HTTP2ToHTTP1ClientCodec: ChannelInboundHandler, ChannelOutbou /// headers conform to the HTTP/2 spec, you should leave this parameter set to `true`. public init(streamID: HTTP2StreamID, httpProtocol: HTTPProtocol, normalizeHTTPHeaders: Bool) { self.streamID = streamID - self.normalizeHTTPHeaders = normalizeHTTPHeaders - - switch httpProtocol { - case .http: - self.protocolString = "http" - case .https: - self.protocolString = "https" - } + self.baseCodec = BaseClientCodec(httpProtocol: httpProtocol, normalizeHTTPHeaders: normalizeHTTPHeaders) } /// Initializes a `HTTP2ToHTTP1ClientCodec` for the given `HTTP2StreamID`. @@ -75,67 +143,91 @@ public final class HTTP2ToHTTP1ClientCodec: ChannelInboundHandler, ChannelOutbou public func channelRead(context: ChannelHandlerContext, data: NIOAny) { let frame = self.unwrapInboundIn(data) + do { + let (first, second) = try self.baseCodec.processInboundData(frame.payload) + if let first = first { + context.fireChannelRead(self.wrapInboundOut(first)) + } + if let second = second { + context.fireChannelRead(self.wrapInboundOut(second)) + } + } catch { + context.fireErrorCaught(error) + } + } + + public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + let responsePart = self.unwrapOutboundIn(data) + + do { + let transformedPayload = try self.baseCodec.processOutboundData(responsePart, allocator: context.channel.allocator) + let part = HTTP2Frame(streamID: self.streamID, payload: transformedPayload) + context.write(self.wrapOutboundOut(part), promise: promise) + } catch { + promise?.fail(error) + context.fireErrorCaught(error) + } + } +} + + +fileprivate struct BaseServerCodec { + private let normalizeHTTPHeaders: Bool + private var headerStateMachine: HTTP2HeadersStateMachine = HTTP2HeadersStateMachine(mode: .server) + + init(normalizeHTTPHeaders: Bool) { + self.normalizeHTTPHeaders = normalizeHTTPHeaders + } - switch frame.payload { + mutating func processInboundData(_ data: HTTP2Frame.FramePayload) throws -> (first: HTTPServerRequestPart?, second: HTTPServerRequestPart?) { + switch data { case .headers(let headerContent): - do { - if case .trailer = try self.headerStateMachine.newHeaders(block: headerContent.headers) { - context.fireChannelRead(self.wrapInboundOut(.end(HTTPHeaders(regularHeadersFrom: headerContent.headers)))) - } else { - let respHead = try HTTPResponseHead(http2HeaderBlock: headerContent.headers) - context.fireChannelRead(self.wrapInboundOut(.head(respHead))) - if headerContent.endStream { - context.fireChannelRead(self.wrapInboundOut(.end(nil))) - } + if case .trailer = try self.headerStateMachine.newHeaders(block: headerContent.headers) { + return (first: .end(HTTPHeaders(regularHeadersFrom: headerContent.headers)), second: nil) + } else { + let reqHead = try HTTPRequestHead(http2HeaderBlock: headerContent.headers) + + let first = HTTPServerRequestPart.head(reqHead) + var second: HTTPServerRequestPart? = nil + if headerContent.endStream { + second = .end(nil) } - } catch { - context.fireErrorCaught(error) + return (first: first, second: second) } - case .data(let content): - guard case .byteBuffer(let b) = content.data else { - preconditionFailure("Received DATA frame with non-bytebuffer IOData") + case .data(let dataContent): + guard case .byteBuffer(let b) = dataContent.data else { + preconditionFailure("Received non-byteBuffer IOData from network") } - - context.fireChannelRead(self.wrapInboundOut(.body(b))) - if content.endStream { - context.fireChannelRead(self.wrapInboundOut(.end(nil))) + let first = HTTPServerRequestPart.body(b) + var second: HTTPServerRequestPart? = nil + if dataContent.endStream { + second = .end(nil) } - case .alternativeService, .rstStream, .priority, .windowUpdate, .settings, .pushPromise, .ping, .goAway, .origin: - // These don't have an HTTP/1 equivalent, so let's drop them. - return + return (first: first, second: second) + default: + // Any other frame type is ignored. + return (first: nil, second: nil) } } - public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { - let responsePart = self.unwrapOutboundIn(data) - switch responsePart { + mutating func processOutboundData(_ data: HTTPServerResponsePart, allocator: ByteBufferAllocator) throws -> HTTP2Frame.FramePayload { + switch data { case .head(let head): - do { - let h1Headers = try HTTPHeaders(requestHead: head, protocolString: self.protocolString) - let headerContent = HTTP2Frame.FramePayload.Headers(headers: HPACKHeaders(httpHeaders: h1Headers, - normalizeHTTPHeaders: self.normalizeHTTPHeaders)) - let frame = HTTP2Frame(streamID: self.streamID, payload: .headers(headerContent)) - context.write(self.wrapOutboundOut(frame), promise: promise) - } catch { - promise?.fail(error) - context.fireErrorCaught(error) - } + let h1 = HTTPHeaders(responseHead: head) + let payload = HTTP2Frame.FramePayload.Headers(headers: HPACKHeaders(httpHeaders: h1, + normalizeHTTPHeaders: self.normalizeHTTPHeaders)) + return .headers(payload) case .body(let body): let payload = HTTP2Frame.FramePayload.Data(data: body) - let frame = HTTP2Frame(streamID: self.streamID, payload: .data(payload)) - context.write(self.wrapOutboundOut(frame), promise: promise) + return .data(payload) case .end(let trailers): - let payload: HTTP2Frame.FramePayload if let trailers = trailers { - payload = .headers(.init(headers: HPACKHeaders(httpHeaders: trailers, - normalizeHTTPHeaders: self.normalizeHTTPHeaders), - endStream: true)) + return .headers(.init(headers: HPACKHeaders(httpHeaders: trailers, + normalizeHTTPHeaders: self.normalizeHTTPHeaders), + endStream: true)) } else { - payload = .data(.init(data: .byteBuffer(context.channel.allocator.buffer(capacity: 0)), endStream: true)) + return .data(.init(data: .byteBuffer(allocator.buffer(capacity: 0)), endStream: true)) } - - let frame = HTTP2Frame(streamID: self.streamID, payload: payload) - context.write(self.wrapOutboundOut(frame), promise: promise) } } } @@ -155,9 +247,7 @@ public final class HTTP2ToHTTP1ServerCodec: ChannelInboundHandler, ChannelOutbou public typealias OutboundOut = HTTP2Frame private let streamID: HTTP2StreamID - private let normalizeHTTPHeaders: Bool - - private var headerStateMachine: HTTP2HeadersStateMachine = HTTP2HeadersStateMachine(mode: .server) + private var baseCodec: BaseServerCodec /// Initializes a `HTTP2ToHTTP1ServerCodec` for the given `HTTP2StreamID`. /// @@ -170,7 +260,7 @@ public final class HTTP2ToHTTP1ServerCodec: ChannelInboundHandler, ChannelOutbou /// headers conform to the HTTP/2 spec, you should leave this parameter set to `true`. public init(streamID: HTTP2StreamID, normalizeHTTPHeaders: Bool) { self.streamID = streamID - self.normalizeHTTPHeaders = normalizeHTTPHeaders + self.baseCodec = BaseServerCodec(normalizeHTTPHeaders: normalizeHTTPHeaders) } public convenience init(streamID: HTTP2StreamID) { @@ -180,61 +270,29 @@ public final class HTTP2ToHTTP1ServerCodec: ChannelInboundHandler, ChannelOutbou public func channelRead(context: ChannelHandlerContext, data: NIOAny) { let frame = self.unwrapInboundIn(data) - switch frame.payload { - case .headers(let headerContent): - do { - if case .trailer = try self.headerStateMachine.newHeaders(block: headerContent.headers) { - context.fireChannelRead(self.wrapInboundOut(.end(HTTPHeaders(regularHeadersFrom: headerContent.headers)))) - } else { - let reqHead = try HTTPRequestHead(http2HeaderBlock: headerContent.headers) - context.fireChannelRead(self.wrapInboundOut(.head(reqHead))) - if headerContent.endStream { - context.fireChannelRead(self.wrapInboundOut(.end(nil))) - } - } - } catch { - context.fireErrorCaught(error) + do { + let (first, second) = try self.baseCodec.processInboundData(frame.payload) + if let first = first { + context.fireChannelRead(self.wrapInboundOut(first)) } - case .data(let dataContent): - guard case .byteBuffer(let b) = dataContent.data else { - preconditionFailure("Received non-byteBuffer IOData from network") + if let second = second { + context.fireChannelRead(self.wrapInboundOut(second)) } - context.fireChannelRead(self.wrapInboundOut(.body(b))) - if dataContent.endStream { - context.fireChannelRead(self.wrapInboundOut(.end(nil))) - } - default: - // Any other frame type is ignored. - break + } catch { + context.fireErrorCaught(error) } } public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { let responsePart = self.unwrapOutboundIn(data) - switch responsePart { - case .head(let head): - let h1 = HTTPHeaders(responseHead: head) - let payload = HTTP2Frame.FramePayload.Headers(headers: HPACKHeaders(httpHeaders: h1, - normalizeHTTPHeaders: self.normalizeHTTPHeaders)) - let frame = HTTP2Frame(streamID: self.streamID, payload: .headers(payload)) - context.write(self.wrapOutboundOut(frame), promise: promise) - case .body(let body): - let payload = HTTP2Frame.FramePayload.Data(data: body) - let frame = HTTP2Frame(streamID: self.streamID, payload: .data(payload)) - context.write(self.wrapOutboundOut(frame), promise: promise) - case .end(let trailers): - let payload: HTTP2Frame.FramePayload - - if let trailers = trailers { - payload = .headers(.init(headers: HPACKHeaders(httpHeaders: trailers, - normalizeHTTPHeaders: self.normalizeHTTPHeaders), - endStream: true)) - } else { - payload = .data(.init(data: .byteBuffer(context.channel.allocator.buffer(capacity: 0)), endStream: true)) - } - let frame = HTTP2Frame(streamID: self.streamID, payload: payload) - context.write(self.wrapOutboundOut(frame), promise: promise) + do { + let transformedPayload = try self.baseCodec.processOutboundData(responsePart, allocator: context.channel.allocator) + let part = HTTP2Frame(streamID: self.streamID, payload: transformedPayload) + context.write(self.wrapOutboundOut(part), promise: promise) + } catch { + promise?.fail(error) + context.fireErrorCaught(error) } } }