diff --git a/Sources/NIOHTTP2/HTTP2Frame.swift b/Sources/NIOHTTP2/HTTP2Frame.swift index 34c83d10..244dbd7a 100644 --- a/Sources/NIOHTTP2/HTTP2Frame.swift +++ b/Sources/NIOHTTP2/HTTP2Frame.swift @@ -253,6 +253,10 @@ public struct HTTP2Frame { } extension HTTP2Frame: HTTP2FrameConvertible, HTTP2FramePayloadConvertible { + init(http2Frame: HTTP2Frame) { + self = http2Frame + } + func makeHTTP2Frame(streamID: HTTP2StreamID) -> HTTP2Frame { assert(self.streamID == streamID, "streamID does not match") return self @@ -264,6 +268,10 @@ extension HTTP2Frame.FramePayload: HTTP2FrameConvertible, HTTP2FramePayloadConve return self } + init(http2Frame: HTTP2Frame) { + self = http2Frame.payload + } + func makeHTTP2Frame(streamID: HTTP2StreamID) -> HTTP2Frame { return HTTP2Frame(streamID: streamID, payload: self) } diff --git a/Sources/NIOHTTP2/HTTP2FrameConvertible.swift b/Sources/NIOHTTP2/HTTP2FrameConvertible.swift index 113854c3..181ed9ba 100644 --- a/Sources/NIOHTTP2/HTTP2FrameConvertible.swift +++ b/Sources/NIOHTTP2/HTTP2FrameConvertible.swift @@ -13,6 +13,9 @@ //===----------------------------------------------------------------------===// protocol HTTP2FrameConvertible { + /// Initialize `Self` from an `HTTP2Frame`. + init(http2Frame: HTTP2Frame) + /// Makes an `HTTPFrame` with the given `streamID`. /// /// - Parameter streamID: The `streamID` to use when constructing the frame. @@ -23,3 +26,39 @@ protocol HTTP2FramePayloadConvertible { /// Makes a `HTTP2Frame.FramePayload`. var payload: HTTP2Frame.FramePayload { get } } + +extension HTTP2FrameConvertible where Self: HTTP2FramePayloadConvertible { + /// A shorthand heuristic for how many bytes we assume a frame consumes on the wire. + /// + /// Here we concern ourselves only with per-stream frames: that is, `HEADERS`, `DATA`, + /// `WINDOW_UDPATE`, `RST_STREAM`, and I guess `PRIORITY`. As a simple heuristic we + /// hard code fixed lengths for fixed length frames, use a calculated length for + /// variable length frames, and just ignore encoded headers because it's not worth doing a better + /// job. + var estimatedFrameSize: Int { + let frameHeaderSize = 9 + + switch self.payload { + case .data(let d): + let paddingBytes = d.paddingBytes.map { $0 + 1 } ?? 0 + return d.data.readableBytes + paddingBytes + frameHeaderSize + case .headers(let h): + let paddingBytes = h.paddingBytes.map { $0 + 1 } ?? 0 + return paddingBytes + frameHeaderSize + case .priority: + return frameHeaderSize + 5 + case .pushPromise(let p): + // Like headers, this is variably size, and we just ignore the encoded headers because + // it's not worth having a heuristic. + let paddingBytes = p.paddingBytes.map { $0 + 1 } ?? 0 + return paddingBytes + frameHeaderSize + case .rstStream: + return frameHeaderSize + 4 + case .windowUpdate: + return frameHeaderSize + 4 + default: + // Unknown or unexpected control frame: say 9 bytes. + return frameHeaderSize + } + } +} diff --git a/Sources/NIOHTTP2/HTTP2StreamChannel.swift b/Sources/NIOHTTP2/HTTP2StreamChannel.swift index 358954df..ca05ba20 100644 --- a/Sources/NIOHTTP2/HTTP2StreamChannel.swift +++ b/Sources/NIOHTTP2/HTTP2StreamChannel.swift @@ -121,8 +121,13 @@ private enum StreamChannelState { } } +/// An `HTTP2StreamChannel` which deals in `HTTPFrame`s. +typealias HTTP2FrameBasedStreamChannel = HTTP2StreamChannel -final class HTTP2StreamChannel: Channel, ChannelCore { +/// An `HTTP2StreamChannel` which reads and writes `HTTPFrame.FramePayload`s. +typealias HTTP2PayloadBasedStreamChannel = HTTP2StreamChannel + +final class HTTP2StreamChannel: Channel, ChannelCore { internal init(allocator: ByteBufferAllocator, parent: Channel, multiplexer: HTTP2StreamMultiplexer, @@ -350,7 +355,7 @@ final class HTTP2StreamChannel: Channel, ChannelCore { /// In the future this buffer will be used to manage interactions with read() and even, one day, /// with flow control. For now, though, all this does is hold frames until we have set the /// channel up. - private var pendingReads: CircularBuffer = CircularBuffer(initialCapacity: 8) + private var pendingReads: CircularBuffer = CircularBuffer(initialCapacity: 8) /// Whether `autoRead` is enabled. By default, all `HTTP2StreamChannel` objects inherit their `autoRead` /// state from their parent. @@ -364,7 +369,7 @@ final class HTTP2StreamChannel: Channel, ChannelCore { /// /// To correctly respect flushes, we deliberately withold frames from the parent channel until this /// stream is flushed, at which time we deliver them all. This buffer holds the pending ones. - private var pendingWrites: MarkedCircularBuffer<(HTTP2Frame, EventLoopPromise?)> = MarkedCircularBuffer(initialCapacity: 8) + private var pendingWrites: MarkedCircularBuffer<(Message, EventLoopPromise?)> = MarkedCircularBuffer(initialCapacity: 8) /// A list node used to hold stream channels. internal var streamChannelListNode: StreamChannelListNode = StreamChannelListNode() @@ -390,13 +395,13 @@ final class HTTP2StreamChannel: Channel, ChannelCore { return } - let frame = self.unwrapData(data, as: HTTP2Frame.self) + let outbound = self.unwrapData(data, as: Message.self) // We need a promise to attach our flow control callback to. // Regardless of whether the write succeeded or failed, we don't count // the bytes any longer. let promise = promise ?? self.eventLoop.makePromise() - let writeSize = frame.bufferBytes + let writeSize = outbound.estimatedFrameSize // Right now we deal with this math by just attaching a callback to all promises. This is going // to be annoyingly expensive, but for now it's the most straightforward approach. @@ -405,7 +410,7 @@ final class HTTP2StreamChannel: Channel, ChannelCore { self.changeWritability(to: value) } } - self.pendingWrites.append((frame, promise)) + self.pendingWrites.append((outbound, promise)) // Ok, we can make an outcall now, which means we can safely deal with the flow control. if case .changed(newValue: let value) = self.writabilityManager.bufferedBytes(writeSize) { @@ -511,7 +516,11 @@ final class HTTP2StreamChannel: Channel, ChannelCore { self.eventLoop.execute { self.removeHandlers(channel: self) self.closePromise.succeed(()) - self.multiplexer.childChannelClosed(MultiplexerAbstractChannel(self)) + if let streamID = self.streamID { + self.multiplexer.childChannelClosed(streamID: streamID) + } else { + self.multiplexer.childChannelClosed(channelID: ObjectIdentifier(self)) + } } } @@ -532,7 +541,11 @@ final class HTTP2StreamChannel: Channel, ChannelCore { self.eventLoop.execute { self.removeHandlers(channel: self) self.closePromise.fail(error) - self.multiplexer.childChannelClosed(MultiplexerAbstractChannel(self)) + if let streamID = self.streamID { + self.multiplexer.childChannelClosed(streamID: streamID) + } else { + self.multiplexer.childChannelClosed(channelID: ObjectIdentifier(self)) + } } } @@ -606,9 +619,16 @@ private extension HTTP2StreamChannel { return } + // Get a streamID from the multiplexer if we haven't got one already. + if self.streamID == nil { + self.streamID = self.multiplexer.requestStreamID() + } + while self.pendingWrites.hasMark { - let write = self.pendingWrites.removeFirst() - self.receiveOutboundFrame(write.0, promise: write.1) + let (outbound, promise) = self.pendingWrites.removeFirst() + // This unwrap is okay: we just ensured that `self.streamID` was set above. + let frame = outbound.makeHTTP2Frame(streamID: self.streamID!) + self.receiveOutboundFrame(frame, promise: promise) } self.multiplexer.childChannelFlush() } @@ -634,10 +654,12 @@ internal extension HTTP2StreamChannel { return } + let message = Message(http2Frame: frame) + if self.unsatisfiedRead { // We don't need to account for this frame in the window manager: it's being delivered // straight into the pipeline. - self.pipeline.fireChannelRead(NIOAny(frame)) + self.pipeline.fireChannelRead(NIOAny(message)) } else { // Record the size of the frame so that when we receive a window update event our // calculation on whether we emit a WINDOW_UPDATE frame is based on the bytes we have @@ -645,7 +667,7 @@ internal extension HTTP2StreamChannel { if case .data(let dataPayload) = frame.payload { self.windowManager.bufferedFrameReceived(size: dataPayload.data.readableBytes) } - self.pendingReads.append(frame) + self.pendingReads.append(message) } } @@ -744,39 +766,3 @@ extension HTTP2StreamChannel { return "HTTP2StreamChannel(streamID: \(String(describing: self.streamID)), isActive: \(self.isActive), isWritable: \(self.isWritable))" } } - -extension HTTP2Frame { - /// A shorthand heuristic for how many bytes we assume a frame consumes on the wire. - /// - /// Here we concern ourselves only with per-stream frames: that is, `HEADERS`, `DATA`, - /// `WINDOW_UDPATE`, `RST_STREAM`, and I guess `PRIORITY`. As a simple heuristic we - /// hard code fixed lengths for fixed length frames, use a calculated length for - /// variable length frames, and just ignore encoded headers because it's not worth doing a better - /// job. - fileprivate var bufferBytes: Int { - let frameHeaderSize = 9 - - switch self.payload { - case .data(let d): - let paddingBytes = d.paddingBytes.map { $0 + 1 } ?? 0 - return d.data.readableBytes + paddingBytes + frameHeaderSize - case .headers(let h): - let paddingBytes = h.paddingBytes.map { $0 + 1 } ?? 0 - return paddingBytes + frameHeaderSize - case .priority: - return frameHeaderSize + 5 - case .pushPromise(let p): - // Like headers, this is variably size, and we just ignore the encoded headers because - // it's not worth having a heuristic. - let paddingBytes = p.paddingBytes.map { $0 + 1 } ?? 0 - return paddingBytes + frameHeaderSize - case .rstStream: - return frameHeaderSize + 4 - case .windowUpdate: - return frameHeaderSize + 4 - default: - // Unknown or unexpected control frame: say 9 bytes. - return frameHeaderSize - } - } -} diff --git a/Sources/NIOHTTP2/HTTP2StreamMultiplexer.swift b/Sources/NIOHTTP2/HTTP2StreamMultiplexer.swift index c9b0d77a..62d9db8e 100644 --- a/Sources/NIOHTTP2/HTTP2StreamMultiplexer.swift +++ b/Sources/NIOHTTP2/HTTP2StreamMultiplexer.swift @@ -318,12 +318,12 @@ extension HTTP2StreamMultiplexer { // MARK:- Child to parent calls extension HTTP2StreamMultiplexer { - internal func childChannelClosed(_ channel: MultiplexerAbstractChannel) { - if let streamID = channel.streamID { - self.streams.removeValue(forKey: streamID) - } else { - preconditionFailure("Child channels always have stream IDs right now.") - } + internal func childChannelClosed(streamID: HTTP2StreamID) { + self.streams.removeValue(forKey: streamID) + } + + internal func childChannelClosed(channelID: ObjectIdentifier) { + preconditionFailure("We don't currently support closing channels by 'channelID'") } internal func childChannelWrite(_ frame: HTTP2Frame, promise: EventLoopPromise?) { diff --git a/Sources/NIOHTTP2/MultiplexerAbstractChannel.swift b/Sources/NIOHTTP2/MultiplexerAbstractChannel.swift index ec345427..bd14ac63 100644 --- a/Sources/NIOHTTP2/MultiplexerAbstractChannel.swift +++ b/Sources/NIOHTTP2/MultiplexerAbstractChannel.swift @@ -41,15 +41,12 @@ struct MultiplexerAbstractChannel { outboundBytesHighWatermark: outboundBytesHighWatermark, outboundBytesLowWatermark: outboundBytesLowWatermark)) } - - init(_ channel: HTTP2StreamChannel) { - self.baseChannel = .frameBased(channel) - } } extension MultiplexerAbstractChannel { enum BaseChannel { - case frameBased(HTTP2StreamChannel) + case frameBased(HTTP2FrameBasedStreamChannel) + case payloadBased(HTTP2PayloadBasedStreamChannel) } } @@ -59,6 +56,17 @@ extension MultiplexerAbstractChannel { switch self.baseChannel { case .frameBased(let base): return base.streamID + case .payloadBased(let base): + return base.streamID + } + } + + var channelID: ObjectIdentifier { + switch self.baseChannel { + case .frameBased(let base): + return ObjectIdentifier(base) + case .payloadBased(let base): + return ObjectIdentifier(base) } } @@ -66,6 +74,8 @@ extension MultiplexerAbstractChannel { switch self.baseChannel { case .frameBased(let base): return base.inList + case .payloadBased(let base): + return base.inList } } @@ -74,12 +84,16 @@ extension MultiplexerAbstractChannel { switch self.baseChannel { case .frameBased(let base): return base.streamChannelListNode + case .payloadBased(let base): + return base.streamChannelListNode } } nonmutating set { switch self.baseChannel { case .frameBased(let base): base.streamChannelListNode = newValue + case .payloadBased(let base): + base.streamChannelListNode = newValue } } } @@ -88,6 +102,8 @@ extension MultiplexerAbstractChannel { switch self.baseChannel { case .frameBased(let base): base.configure(initializer: initializer, userPromise: promise) + case .payloadBased: + fatalError("Can't configure a payload based channel with this initializer.") } } @@ -95,6 +111,8 @@ extension MultiplexerAbstractChannel { switch self.baseChannel { case .frameBased(let base): base.performActivation() + case .payloadBased(let base): + base.performActivation() } } @@ -102,6 +120,8 @@ extension MultiplexerAbstractChannel { switch self.baseChannel { case .frameBased(let base): base.networkActivationReceived() + case .payloadBased(let base): + base.networkActivationReceived() } } @@ -109,6 +129,8 @@ extension MultiplexerAbstractChannel { switch self.baseChannel { case .frameBased(let base): base.receiveInboundFrame(frame) + case .payloadBased(let base): + base.receiveInboundFrame(frame) } } @@ -116,6 +138,8 @@ extension MultiplexerAbstractChannel { switch self.baseChannel { case .frameBased(let base): base.receiveParentChannelReadComplete() + case .payloadBased(let base): + base.receiveParentChannelReadComplete() } } @@ -123,6 +147,8 @@ extension MultiplexerAbstractChannel { switch self.baseChannel { case .frameBased(let base): base.initialWindowSizeChanged(delta: delta) + case .payloadBased(let base): + base.initialWindowSizeChanged(delta: delta) } } @@ -130,6 +156,8 @@ extension MultiplexerAbstractChannel { switch self.baseChannel { case .frameBased(let base): base.receiveWindowUpdatedEvent(windowSize) + case .payloadBased(let base): + base.receiveWindowUpdatedEvent(windowSize) } } @@ -137,6 +165,8 @@ extension MultiplexerAbstractChannel { switch self.baseChannel { case .frameBased(let base): base.parentChannelWritabilityChanged(newValue: newValue) + case .payloadBased(let base): + base.parentChannelWritabilityChanged(newValue: newValue) } } @@ -144,6 +174,8 @@ extension MultiplexerAbstractChannel { switch self.baseChannel { case .frameBased(let base): base.receiveStreamClosed(reason) + case .payloadBased(let base): + base.receiveStreamClosed(reason) } } } @@ -153,6 +185,10 @@ extension MultiplexerAbstractChannel: Equatable { switch (lhs.baseChannel, rhs.baseChannel) { case (.frameBased(let lhs), .frameBased(let rhs)): return lhs === rhs + case (.payloadBased(let lhs), .payloadBased(let rhs)): + return lhs === rhs + case (.frameBased, .payloadBased), (.payloadBased, .frameBased): + return false } } } @@ -162,6 +198,8 @@ extension MultiplexerAbstractChannel: Hashable { switch self.baseChannel { case .frameBased(let base): hasher.combine(ObjectIdentifier(base)) + case .payloadBased(let base): + hasher.combine(ObjectIdentifier(base)) } } }