diff --git a/.changesets/fix_geal_fix_defer_compression.md b/.changesets/fix_geal_fix_defer_compression.md new file mode 100644 index 0000000000..2873379764 --- /dev/null +++ b/.changesets/fix_geal_fix_defer_compression.md @@ -0,0 +1,7 @@ +### Fix compression for deferred responses ([Issue #1572](https://github.com/apollographql/router/issues/1572)) + +We replace tower-http's `CompressionLayer` with a custom stream transformation. This is necessary because tower-http relies on async-compression, which buffers data until the end of the stream before writing it out, to achieve better compression. That behaviour is incompatible with the multipart protocol used by `@defer`, which requires each chunk to be sent as soon as possible, so chunks must be compressed independently. + +This change extracts parts of async-compression's codec module, which is not yet public, and builds a streaming wrapper on top of it that flushes the compressed data after every response in the stream. + +By [@Geal](https://github.com/Geal) in https://github.com/apollographql/router/pull/2986 \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 5b4a1d983a..3e18d4460c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -286,6 +286,7 @@ dependencies = [ "axum", "backtrace", "base64 0.20.0", + "brotli", "buildstructor 0.5.2", "bytes", "ci_info", @@ -401,6 +402,8 @@ dependencies = [ "wiremock", "wsl", "yaml-rust", + "zstd", + "zstd-safe", ] [[package]] @@ -7073,3 +7076,33 @@ dependencies = [ "quote", "syn 2.0.13", ] + +[[package]] +name = "zstd" +version = "0.12.3+zstd.1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76eea132fb024e0e13fd9c2f5d5d595d8a967aa72382ac2f9d39fcc95afd0806" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "6.0.5+zstd.1.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d56d9e60b4b1758206c238a10165fbcae3ca37b01744e394c463463f6529d23b" +dependencies = [ + "libc", + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.8+zstd.1.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5556e6ee25d32df2586c098bbfa278803692a20d0ab9565e049480d52707ec8c" +dependencies = [ + "cc", + "libc", + "pkg-config", +] diff --git a/apollo-router/Cargo.toml b/apollo-router/Cargo.toml index 4b7c34a917..8b6411dfd7 100644 --- a/apollo-router/Cargo.toml +++ b/apollo-router/Cargo.toml @@ -203,6 +203,10 @@ yaml-rust = "0.4.5" wsl = "0.1.0" tokio-rustls = "0.23.4" http-serde = "1.1.2" +memchr = "2.5.0" +brotli = "3.3.4" +zstd = "0.12.3" +zstd-safe = "6.0.5" [target.'cfg(macos)'.dependencies] uname = "0.1.1" diff --git a/apollo-router/src/axum_factory/axum_http_server_factory.rs b/apollo-router/src/axum_factory/axum_http_server_factory.rs index 74c0ac1601..398f686a4a 100644 --- a/apollo-router/src/axum_factory/axum_http_server_factory.rs +++ b/apollo-router/src/axum_factory/axum_http_server_factory.rs @@ -19,6 +19,9 @@ use futures::channel::oneshot; use futures::future::join; use futures::future::join_all; use futures::prelude::*; +use http::header::ACCEPT_ENCODING; +use http::header::CONTENT_ENCODING; +use http::HeaderValue; use http::Request; use http_body::combinators::UnsyncBoxBody; use hyper::Body; @@ -32,10 +35,6 @@ use tokio_rustls::TlsAcceptor; use tower::service_fn; use tower::BoxError; use tower::ServiceExt; -use tower_http::compression::predicate::NotForContentType; -use tower_http::compression::CompressionLayer; -use tower_http::compression::DefaultPredicate; -use tower_http::compression::Predicate; use 
tower_http::trace::TraceLayer; use super::listeners::ensure_endpoints_consistency; @@ -45,6 +44,7 @@ use super::listeners::ListenersAndRouters; use super::utils::decompress_request_body; use super::utils::PropagatingMakeSpan; use super::ListenAddrAndRouter; +use crate::axum_factory::compression::Compressor; use crate::axum_factory::listeners::get_extra_listeners; use crate::axum_factory::listeners::serve_router_on_listen_addr; use crate::configuration::Configuration; @@ -329,12 +329,7 @@ where )) .layer(TraceLayer::new_for_http().make_span_with(PropagatingMakeSpan { entitlement })) .layer(Extension(service_factory)) - .layer(cors) - // Compress the response body, except for multipart responses such as with `@defer`. - // This is a work-around for https://github.com/apollographql/router/issues/1572 - .layer(CompressionLayer::new().compress_when( - DefaultPredicate::new().and(NotForContentType::const_new("multipart/")), - )); + .layer(cors); let route = endpoints_on_main_listener .into_iter() @@ -434,6 +429,11 @@ async fn handle_graphql( let request: router::Request = http_request.into(); let context = request.context.clone(); + let accept_encoding = request + .router_request + .headers() + .get(ACCEPT_ENCODING) + .cloned(); let res = service.oneshot(request).await; let dur = context.busy_time().await; @@ -467,7 +467,24 @@ async fn handle_graphql( } Ok(response) => { tracing::info!(counter.apollo_router_session_count_active = -1,); - response.response.into_response() + let (mut parts, body) = response.response.into_parts(); + + let opt_compressor = accept_encoding + .as_ref() + .and_then(|value| value.to_str().ok()) + .and_then(|v| Compressor::new(v.split(',').map(|s| s.trim()))); + let body = match opt_compressor { + None => body, + Some(compressor) => { + parts.headers.insert( + CONTENT_ENCODING, + HeaderValue::from_static(compressor.content_encoding()), + ); + Body::wrap_stream(compressor.process(body)) + } + }; + + http::Response::from_parts(parts, body).into_response() } } } diff --git a/apollo-router/src/axum_factory/compression/codec/brotli/encoder.rs b/apollo-router/src/axum_factory/compression/codec/brotli/encoder.rs new file mode 100644 index 0000000000..ef877335ac --- /dev/null +++ b/apollo-router/src/axum_factory/compression/codec/brotli/encoder.rs @@ -0,0 +1,112 @@ +// All code from this module is extracted from https://github.com/Nemo157/async-compression and is under MIT or Apache-2 licence +// it will be removed when we find a long lasting solution to https://github.com/Nemo157/async-compression/issues/154 +use std::fmt; +use std::io::Error; +use std::io::ErrorKind; +use std::io::Result; + +use brotli::enc::backward_references::BrotliEncoderParams; +use brotli::enc::encode::BrotliEncoderCompressStream; +use brotli::enc::encode::BrotliEncoderCreateInstance; +use brotli::enc::encode::BrotliEncoderHasMoreOutput; +use brotli::enc::encode::BrotliEncoderIsFinished; +use brotli::enc::encode::BrotliEncoderOperation; +use brotli::enc::encode::BrotliEncoderStateStruct; +use brotli::enc::StandardAlloc; + +use crate::axum_factory::compression::codec::Encode; +use crate::axum_factory::compression::util::PartialBuffer; + +pub(crate) struct BrotliEncoder { + state: BrotliEncoderStateStruct, +} + +impl BrotliEncoder { + pub(crate) fn new(params: BrotliEncoderParams) -> Self { + let mut state = BrotliEncoderCreateInstance(StandardAlloc::default()); + state.params = params; + Self { state } + } + + fn encode( + &mut self, + input: &mut PartialBuffer>, + output: &mut PartialBuffer + 
AsMut<[u8]>>, + op: BrotliEncoderOperation, + ) -> Result<()> { + let in_buf = input.unwritten(); + let out_buf = output.unwritten_mut(); + + let mut input_len = 0; + let mut output_len = 0; + + if BrotliEncoderCompressStream( + &mut self.state, + op, + &mut in_buf.len(), + in_buf, + &mut input_len, + &mut out_buf.len(), + out_buf, + &mut output_len, + &mut None, + &mut |_, _, _, _| (), + ) <= 0 + { + return Err(Error::new(ErrorKind::Other, "brotli error")); + } + + input.advance(input_len); + output.advance(output_len); + + Ok(()) + } +} + +impl Encode for BrotliEncoder { + fn encode( + &mut self, + input: &mut PartialBuffer>, + output: &mut PartialBuffer + AsMut<[u8]>>, + ) -> Result<()> { + self.encode( + input, + output, + BrotliEncoderOperation::BROTLI_OPERATION_PROCESS, + ) + } + + fn flush( + &mut self, + output: &mut PartialBuffer + AsMut<[u8]>>, + ) -> Result { + self.encode( + &mut PartialBuffer::new(&[][..]), + output, + BrotliEncoderOperation::BROTLI_OPERATION_FLUSH, + )?; + + Ok(BrotliEncoderHasMoreOutput(&self.state) == 0) + } + + fn finish( + &mut self, + output: &mut PartialBuffer + AsMut<[u8]>>, + ) -> Result { + self.encode( + &mut PartialBuffer::new(&[][..]), + output, + BrotliEncoderOperation::BROTLI_OPERATION_FINISH, + )?; + + Ok(BrotliEncoderIsFinished(&self.state) == 1) + } +} + +impl fmt::Debug for BrotliEncoder { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("BrotliEncoder") + .field("compress", &"") + .finish() + } +} diff --git a/apollo-router/src/axum_factory/compression/codec/brotli/mod.rs b/apollo-router/src/axum_factory/compression/codec/brotli/mod.rs new file mode 100644 index 0000000000..1e71412652 --- /dev/null +++ b/apollo-router/src/axum_factory/compression/codec/brotli/mod.rs @@ -0,0 +1,5 @@ +// All code from this module is extracted from https://github.com/Nemo157/async-compression and is under MIT or Apache-2 licence +// it will be removed when we find a long lasting solution to https://github.com/Nemo157/async-compression/issues/154 +mod encoder; + +pub(crate) use self::encoder::BrotliEncoder; diff --git a/apollo-router/src/axum_factory/compression/codec/deflate/encoder.rs b/apollo-router/src/axum_factory/compression/codec/deflate/encoder.rs new file mode 100644 index 0000000000..88cac88903 --- /dev/null +++ b/apollo-router/src/axum_factory/compression/codec/deflate/encoder.rs @@ -0,0 +1,46 @@ +// All code from this module is extracted from https://github.com/Nemo157/async-compression and is under MIT or Apache-2 licence +// it will be removed when we find a long lasting solution to https://github.com/Nemo157/async-compression/issues/154 +use std::io::Result; + +use flate2::Compression; + +use crate::axum_factory::compression::codec::Encode; +use crate::axum_factory::compression::codec::FlateEncoder; +use crate::axum_factory::compression::util::PartialBuffer; + +#[derive(Debug)] +pub(crate) struct DeflateEncoder { + inner: FlateEncoder, +} + +impl DeflateEncoder { + pub(crate) fn new(level: Compression) -> Self { + Self { + inner: FlateEncoder::new(level, false), + } + } +} + +impl Encode for DeflateEncoder { + fn encode( + &mut self, + input: &mut PartialBuffer>, + output: &mut PartialBuffer + AsMut<[u8]>>, + ) -> Result<()> { + self.inner.encode(input, output) + } + + fn flush( + &mut self, + output: &mut PartialBuffer + AsMut<[u8]>>, + ) -> Result { + self.inner.flush(output) + } + + fn finish( + &mut self, + output: &mut PartialBuffer + AsMut<[u8]>>, + ) -> Result { + self.inner.finish(output) + } +} diff --git 
a/apollo-router/src/axum_factory/compression/codec/deflate/mod.rs b/apollo-router/src/axum_factory/compression/codec/deflate/mod.rs new file mode 100644 index 0000000000..5a2d24c4be --- /dev/null +++ b/apollo-router/src/axum_factory/compression/codec/deflate/mod.rs @@ -0,0 +1,5 @@ +// All code from this module is extracted from https://github.com/Nemo157/async-compression and is under MIT or Apache-2 licence +// it will be removed when we find a long lasting solution to https://github.com/Nemo157/async-compression/issues/154 +mod encoder; + +pub(crate) use self::encoder::DeflateEncoder; diff --git a/apollo-router/src/axum_factory/compression/codec/flate/encoder.rs b/apollo-router/src/axum_factory/compression/codec/flate/encoder.rs new file mode 100644 index 0000000000..e264b874ff --- /dev/null +++ b/apollo-router/src/axum_factory/compression/codec/flate/encoder.rs @@ -0,0 +1,110 @@ +// All code from this module is extracted from https://github.com/Nemo157/async-compression and is under MIT or Apache-2 licence +// it will be removed when we find a long lasting solution to https://github.com/Nemo157/async-compression/issues/154 +use std::io::Error; +use std::io::ErrorKind; +use std::io::Result; + +use flate2::Compress; +use flate2::Compression; +use flate2::FlushCompress; +use flate2::Status; + +use crate::axum_factory::compression::codec::Encode; +use crate::axum_factory::compression::util::PartialBuffer; + +#[derive(Debug)] +pub(crate) struct FlateEncoder { + compress: Compress, + flushed: bool, +} + +impl FlateEncoder { + pub(crate) fn new(level: Compression, zlib_header: bool) -> Self { + Self { + compress: Compress::new(level, zlib_header), + flushed: true, + } + } + + fn encode( + &mut self, + input: &mut PartialBuffer>, + output: &mut PartialBuffer + AsMut<[u8]>>, + flush: FlushCompress, + ) -> Result { + let prior_in = self.compress.total_in(); + let prior_out = self.compress.total_out(); + + let status = self + .compress + .compress(input.unwritten(), output.unwritten_mut(), flush)?; + + input.advance((self.compress.total_in() - prior_in) as usize); + output.advance((self.compress.total_out() - prior_out) as usize); + + Ok(status) + } +} + +impl Encode for FlateEncoder { + fn encode( + &mut self, + input: &mut PartialBuffer>, + output: &mut PartialBuffer + AsMut<[u8]>>, + ) -> Result<()> { + self.flushed = false; + match self.encode(input, output, FlushCompress::None)? { + Status::Ok => Ok(()), + Status::StreamEnd => unreachable!(), + Status::BufError => Err(Error::new(ErrorKind::Other, "unexpected BufError")), + } + } + + fn flush( + &mut self, + output: &mut PartialBuffer + AsMut<[u8]>>, + ) -> Result { + // We need to keep track of whether we've already flushed otherwise we'll just keep writing + // out sync blocks continuously and probably never complete flushing. + if self.flushed { + return Ok(true); + } + + self.encode( + &mut PartialBuffer::new(&[][..]), + output, + FlushCompress::Sync, + )?; + + loop { + let old_len = output.written().len(); + self.encode( + &mut PartialBuffer::new(&[][..]), + output, + FlushCompress::None, + )?; + if output.written().len() == old_len { + break; + } + } + + self.flushed = true; + Ok(!output.unwritten().is_empty()) + } + + fn finish( + &mut self, + output: &mut PartialBuffer + AsMut<[u8]>>, + ) -> Result { + self.flushed = false; + match self.encode( + &mut PartialBuffer::new(&[][..]), + output, + FlushCompress::Finish, + )? 
{ + Status::Ok => Ok(false), + Status::StreamEnd => Ok(true), + Status::BufError => Err(Error::new(ErrorKind::Other, "unexpected BufError")), + } + } +} diff --git a/apollo-router/src/axum_factory/compression/codec/flate/mod.rs b/apollo-router/src/axum_factory/compression/codec/flate/mod.rs new file mode 100644 index 0000000000..215623803c --- /dev/null +++ b/apollo-router/src/axum_factory/compression/codec/flate/mod.rs @@ -0,0 +1,5 @@ +// All code from this module is extracted from https://github.com/Nemo157/async-compression and is under MIT or Apache-2 licence +// it will be removed when we find a long lasting solution to https://github.com/Nemo157/async-compression/issues/154 +mod encoder; + +pub(crate) use self::encoder::FlateEncoder; diff --git a/apollo-router/src/axum_factory/compression/codec/gzip/encoder.rs b/apollo-router/src/axum_factory/compression/codec/gzip/encoder.rs new file mode 100644 index 0000000000..9203c9103b --- /dev/null +++ b/apollo-router/src/axum_factory/compression/codec/gzip/encoder.rs @@ -0,0 +1,170 @@ +// All code from this module is extracted from https://github.com/Nemo157/async-compression and is under MIT or Apache-2 licence +// it will be removed when we find a long lasting solution to https://github.com/Nemo157/async-compression/issues/154 +use std::io::Result; + +use flate2::Compression; +use flate2::Crc; + +use crate::axum_factory::compression::codec::Encode; +use crate::axum_factory::compression::codec::FlateEncoder; +use crate::axum_factory::compression::util::PartialBuffer; + +#[derive(Debug)] +enum State { + Header(PartialBuffer>), + Encoding, + Footer(PartialBuffer>), + Done, +} + +#[derive(Debug)] +pub(crate) struct GzipEncoder { + inner: FlateEncoder, + crc: Crc, + state: State, +} + +fn header(level: Compression) -> Vec { + let level_byte = if level.level() >= Compression::best().level() { + 0x02 + } else if level.level() <= Compression::fast().level() { + 0x04 + } else { + 0x00 + }; + + vec![0x1f, 0x8b, 0x08, 0, 0, 0, 0, 0, level_byte, 0xff] +} + +impl GzipEncoder { + pub(crate) fn new(level: Compression) -> Self { + Self { + inner: FlateEncoder::new(level, false), + crc: Crc::new(), + state: State::Header(header(level).into()), + } + } + + fn footer(&mut self) -> Vec { + let mut output = Vec::with_capacity(8); + + output.extend(&self.crc.sum().to_le_bytes()); + output.extend(&self.crc.amount().to_le_bytes()); + + output + } +} + +impl Encode for GzipEncoder { + fn encode( + &mut self, + input: &mut PartialBuffer>, + output: &mut PartialBuffer + AsMut<[u8]>>, + ) -> Result<()> { + loop { + match &mut self.state { + State::Header(header) => { + output.copy_unwritten_from(&mut *header); + + if header.unwritten().is_empty() { + self.state = State::Encoding; + } + } + + State::Encoding => { + let prior_written = input.written().len(); + self.inner.encode(input, output)?; + self.crc.update(&input.written()[prior_written..]); + } + + State::Footer(_) | State::Done => panic!("encode after complete"), + }; + + if input.unwritten().is_empty() || output.unwritten().is_empty() { + return Ok(()); + } + } + } + + fn flush( + &mut self, + output: &mut PartialBuffer + AsMut<[u8]>>, + ) -> Result { + loop { + let done = match &mut self.state { + State::Header(header) => { + output.copy_unwritten_from(&mut *header); + + if header.unwritten().is_empty() { + self.state = State::Encoding; + } + false + } + + State::Encoding => self.inner.flush(output)?, + + State::Footer(footer) => { + output.copy_unwritten_from(&mut *footer); + + if 
footer.unwritten().is_empty() { + self.state = State::Done; + true + } else { + false + } + } + + State::Done => true, + }; + + if done { + return Ok(true); + } + + if output.unwritten().is_empty() { + return Ok(false); + } + } + } + + fn finish( + &mut self, + output: &mut PartialBuffer + AsMut<[u8]>>, + ) -> Result { + loop { + match &mut self.state { + State::Header(header) => { + output.copy_unwritten_from(&mut *header); + + if header.unwritten().is_empty() { + self.state = State::Encoding; + } + } + + State::Encoding => { + if self.inner.finish(output)? { + self.state = State::Footer(self.footer().into()); + } + } + + State::Footer(footer) => { + output.copy_unwritten_from(&mut *footer); + + if footer.unwritten().is_empty() { + self.state = State::Done; + } + } + + State::Done => {} + }; + + if let State::Done = self.state { + return Ok(true); + } + + if output.unwritten().is_empty() { + return Ok(false); + } + } + } +} diff --git a/apollo-router/src/axum_factory/compression/codec/gzip/header.rs b/apollo-router/src/axum_factory/compression/codec/gzip/header.rs new file mode 100644 index 0000000000..754dcaa012 --- /dev/null +++ b/apollo-router/src/axum_factory/compression/codec/gzip/header.rs @@ -0,0 +1,167 @@ +#![allow(dead_code)] +// All code from this module is extracted from https://github.com/Nemo157/async-compression and is under MIT or Apache-2 licence +// it will be removed when we find a long lasting solution to https://github.com/Nemo157/async-compression/issues/154 +use std::io::Error; +use std::io::ErrorKind; +use std::io::Result; + +use crate::axum_factory::compression::util::PartialBuffer; + +#[derive(Debug, Default)] +struct Flags { + ascii: bool, + crc: bool, + extra: bool, + filename: bool, + comment: bool, +} + +#[derive(Debug, Default)] +pub(super) struct Header { + flags: Flags, +} + +#[derive(Debug)] +enum State { + Fixed(PartialBuffer<[u8; 10]>), + ExtraLen(PartialBuffer<[u8; 2]>), + Extra(PartialBuffer>), + Filename(Vec), + Comment(Vec), + Crc(PartialBuffer<[u8; 2]>), + Done, +} + +impl Default for State { + fn default() -> Self { + State::Fixed(<_>::default()) + } +} + +#[derive(Debug, Default)] +pub(super) struct Parser { + state: State, + header: Header, +} + +impl Header { + fn parse(input: &[u8; 10]) -> Result { + if input[0..3] != [0x1f, 0x8b, 0x08] { + return Err(Error::new(ErrorKind::InvalidData, "Invalid gzip header")); + } + + let flag = input[3]; + + let flags = Flags { + ascii: (flag & 0b0000_0001) != 0, + crc: (flag & 0b0000_0010) != 0, + extra: (flag & 0b0000_0100) != 0, + filename: (flag & 0b0000_1000) != 0, + comment: (flag & 0b0001_0000) != 0, + }; + + Ok(Header { flags }) + } +} + +impl Parser { + pub(super) fn input( + &mut self, + input: &mut PartialBuffer>, + ) -> Result> { + loop { + match &mut self.state { + State::Fixed(data) => { + data.copy_unwritten_from(input); + + if data.unwritten().is_empty() { + self.header = Header::parse(&data.take().into_inner())?; + self.state = State::ExtraLen(<_>::default()); + } else { + return Ok(None); + } + } + + State::ExtraLen(data) => { + if !self.header.flags.extra { + self.state = State::Filename(<_>::default()); + continue; + } + + data.copy_unwritten_from(input); + + if data.unwritten().is_empty() { + let len = u16::from_le_bytes(data.take().into_inner()); + self.state = State::Extra(vec![0; usize::from(len)].into()); + } else { + return Ok(None); + } + } + + State::Extra(data) => { + data.copy_unwritten_from(input); + + if data.unwritten().is_empty() { + self.state = 
State::Filename(<_>::default()); + } else { + return Ok(None); + } + } + + State::Filename(data) => { + if !self.header.flags.filename { + self.state = State::Comment(<_>::default()); + continue; + } + + if let Some(len) = memchr::memchr(0, input.unwritten()) { + data.extend_from_slice(&input.unwritten()[..len]); + input.advance(len + 1); + self.state = State::Comment(<_>::default()); + } else { + data.extend_from_slice(input.unwritten()); + input.advance(input.unwritten().len()); + return Ok(None); + } + } + + State::Comment(data) => { + if !self.header.flags.comment { + self.state = State::Crc(<_>::default()); + continue; + } + + if let Some(len) = memchr::memchr(0, input.unwritten()) { + data.extend_from_slice(&input.unwritten()[..len]); + input.advance(len + 1); + self.state = State::Crc(<_>::default()); + } else { + data.extend_from_slice(input.unwritten()); + input.advance(input.unwritten().len()); + return Ok(None); + } + } + + State::Crc(data) => { + if !self.header.flags.crc { + self.state = State::Done; + return Ok(Some(std::mem::take(&mut self.header))); + } + + data.copy_unwritten_from(input); + + if data.unwritten().is_empty() { + self.state = State::Done; + return Ok(Some(std::mem::take(&mut self.header))); + } else { + return Ok(None); + } + } + + State::Done => { + panic!("parser used after done"); + } + }; + } + } +} diff --git a/apollo-router/src/axum_factory/compression/codec/gzip/mod.rs b/apollo-router/src/axum_factory/compression/codec/gzip/mod.rs new file mode 100644 index 0000000000..77d5604f1c --- /dev/null +++ b/apollo-router/src/axum_factory/compression/codec/gzip/mod.rs @@ -0,0 +1,6 @@ +// All code from this module is extracted from https://github.com/Nemo157/async-compression and is under MIT or Apache-2 licence +// it will be removed when we find a long lasting solution to https://github.com/Nemo157/async-compression/issues/154 +mod encoder; +mod header; + +pub(crate) use self::encoder::GzipEncoder; diff --git a/apollo-router/src/axum_factory/compression/codec/mod.rs b/apollo-router/src/axum_factory/compression/codec/mod.rs new file mode 100644 index 0000000000..71e801b313 --- /dev/null +++ b/apollo-router/src/axum_factory/compression/codec/mod.rs @@ -0,0 +1,36 @@ +// All code from this module is extracted from https://github.com/Nemo157/async-compression and is under MIT or Apache-2 licence +// it will be removed when we find a long lasting solution to https://github.com/Nemo157/async-compression/issues/154 +use std::io::Result; + +use super::util::PartialBuffer; + +mod brotli; +mod deflate; +mod flate; +mod gzip; +//mod zlib; +mod zstd; + +pub(crate) use self::brotli::BrotliEncoder; +pub(crate) use self::deflate::DeflateEncoder; +pub(crate) use self::flate::FlateEncoder; +pub(crate) use self::gzip::GzipEncoder; +pub(crate) use self::zstd::ZstdEncoder; + +pub(crate) trait Encode { + fn encode( + &mut self, + input: &mut PartialBuffer>, + output: &mut PartialBuffer + AsMut<[u8]>>, + ) -> Result<()>; + + /// Returns whether the internal buffers are flushed + fn flush(&mut self, output: &mut PartialBuffer + AsMut<[u8]>>) + -> Result; + + /// Returns whether the internal buffers are flushed and the end of the stream is written + fn finish( + &mut self, + output: &mut PartialBuffer + AsMut<[u8]>>, + ) -> Result; +} diff --git a/apollo-router/src/axum_factory/compression/codec/zstd/encoder.rs b/apollo-router/src/axum_factory/compression/codec/zstd/encoder.rs new file mode 100644 index 0000000000..fe6230bf72 --- /dev/null +++ 
b/apollo-router/src/axum_factory/compression/codec/zstd/encoder.rs @@ -0,0 +1,61 @@ +// All code from this module is extracted from https://github.com/Nemo157/async-compression and is under MIT or Apache-2 licence +// it will be removed when we find a long lasting solution to https://github.com/Nemo157/async-compression/issues/154 +use std::io::Result; + +use zstd::stream::raw::Encoder; +use zstd::stream::raw::Operation; + +use crate::axum_factory::compression::codec::Encode; +use crate::axum_factory::compression::unshared::Unshared; +use crate::axum_factory::compression::util::PartialBuffer; + +#[derive(Debug)] +pub(crate) struct ZstdEncoder { + encoder: Unshared>, +} + +impl ZstdEncoder { + pub(crate) fn new(level: i32) -> Self { + Self { + encoder: Unshared::new(Encoder::new(level).unwrap()), + } + } +} + +impl Encode for ZstdEncoder { + fn encode( + &mut self, + input: &mut PartialBuffer>, + output: &mut PartialBuffer + AsMut<[u8]>>, + ) -> Result<()> { + let status = self + .encoder + .get_mut() + .run_on_buffers(input.unwritten(), output.unwritten_mut())?; + input.advance(status.bytes_read); + output.advance(status.bytes_written); + Ok(()) + } + + fn flush( + &mut self, + output: &mut PartialBuffer + AsMut<[u8]>>, + ) -> Result { + let mut out_buf = zstd_safe::OutBuffer::around(output.unwritten_mut()); + let bytes_left = self.encoder.get_mut().flush(&mut out_buf)?; + let len = out_buf.as_slice().len(); + output.advance(len); + Ok(bytes_left == 0) + } + + fn finish( + &mut self, + output: &mut PartialBuffer + AsMut<[u8]>>, + ) -> Result { + let mut out_buf = zstd_safe::OutBuffer::around(output.unwritten_mut()); + let bytes_left = self.encoder.get_mut().finish(&mut out_buf, true)?; + let len = out_buf.as_slice().len(); + output.advance(len); + Ok(bytes_left == 0) + } +} diff --git a/apollo-router/src/axum_factory/compression/codec/zstd/mod.rs b/apollo-router/src/axum_factory/compression/codec/zstd/mod.rs new file mode 100644 index 0000000000..a99dd85331 --- /dev/null +++ b/apollo-router/src/axum_factory/compression/codec/zstd/mod.rs @@ -0,0 +1,5 @@ +// All code from this module is extracted from https://github.com/Nemo157/async-compression and is under MIT or Apache-2 licence +// it will be removed when we find a long lasting solution to https://github.com/Nemo157/async-compression/issues/154 +mod encoder; + +pub(crate) use self::encoder::ZstdEncoder; diff --git a/apollo-router/src/axum_factory/compression/mod.rs b/apollo-router/src/axum_factory/compression/mod.rs new file mode 100644 index 0000000000..38eb0c2e6d --- /dev/null +++ b/apollo-router/src/axum_factory/compression/mod.rs @@ -0,0 +1,183 @@ +use brotli::enc::BrotliEncoderParams; +use bytes::Bytes; +use bytes::BytesMut; +use flate2::Compression; +use futures::Stream; +use futures::StreamExt; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; +use tower::BoxError; + +use self::codec::BrotliEncoder; +use self::codec::DeflateEncoder; +use self::codec::Encode; +use self::codec::GzipEncoder; +use self::codec::ZstdEncoder; +use self::util::PartialBuffer; + +pub(crate) mod codec; +pub(crate) mod unshared; +pub(crate) mod util; + +pub(crate) enum Compressor { + Deflate(DeflateEncoder), + Gzip(GzipEncoder), + Brotli(Box), + Zstd(ZstdEncoder), +} + +impl Compressor { + pub(crate) fn new<'a, It: 'a>(it: It) -> Option + where + It: Iterator, + { + for s in it { + match s { + "gzip" => return Some(Compressor::Gzip(GzipEncoder::new(Compression::fast()))), + "deflate" => { + return Some(Compressor::Deflate( + 
DeflateEncoder::new(Compression::fast()), + )) + } + // FIXME: find the "fast" brotli encoder params + "br" => { + return Some(Compressor::Brotli(Box::new(BrotliEncoder::new( + BrotliEncoderParams::default(), + )))) + } + "zstd" => { + return Some(Compressor::Zstd(ZstdEncoder::new(zstd_safe::min_c_level()))) + } + _ => {} + } + } + None + } + + pub(crate) fn content_encoding(&self) -> &'static str { + match self { + Compressor::Deflate(_) => "deflate", + Compressor::Gzip(_) => "gzip", + Compressor::Brotli(_) => "br", + Compressor::Zstd(_) => "zstd", + } + } + + pub(crate) fn process( + mut self, + mut stream: hyper::Body, + ) -> impl Stream> +where { + let (tx, rx) = mpsc::channel(10); + + tokio::task::spawn(async move { + while let Some(data) = stream.next().await { + match data { + Err(e) => { + if (tx.send(Err(e.into())).await).is_err() { + return; + } + } + Ok(data) => { + let mut buf = BytesMut::zeroed(1024); + let mut written = 0usize; + + let mut partial_input = PartialBuffer::new(&*data); + loop { + let mut partial_output = PartialBuffer::new(&mut buf); + partial_output.advance(written); + + if let Err(e) = self.encode(&mut partial_input, &mut partial_output) { + let _ = tx.send(Err(e.into())).await; + return; + } + + written += partial_output.written().len(); + + if !partial_input.unwritten().is_empty() { + // there was not enough space in the output buffer to compress everything, + // so we resize and add more data + if partial_output.unwritten().is_empty() { + let _ = partial_output.into_inner(); + buf.reserve(written); + } + } else { + match self.flush(&mut partial_output) { + Err(e) => { + let _ = tx.send(Err(e.into())).await; + return; + } + Ok(_) => { + let len = partial_output.written().len(); + let _ = partial_output.into_inner(); + buf.resize(len, 0); + if (tx.send(Ok(buf.freeze())).await).is_err() { + return; + } + break; + } + } + } + } + } + } + } + + let buf = BytesMut::zeroed(64); + let mut partial_output = PartialBuffer::new(buf); + + match self.finish(&mut partial_output) { + Err(e) => { + let _ = tx.send(Err(e.into())).await; + } + Ok(_) => { + let len = partial_output.written().len(); + + let mut buf = partial_output.into_inner(); + buf.resize(len, 0); + let _ = tx.send(Ok(buf.freeze())).await; + } + } + }); + ReceiverStream::new(rx) + } +} + +impl Encode for Compressor { + fn encode( + &mut self, + input: &mut PartialBuffer>, + output: &mut PartialBuffer + AsMut<[u8]>>, + ) -> std::io::Result<()> { + match self { + Compressor::Deflate(e) => e.encode(input, output), + Compressor::Gzip(e) => e.encode(input, output), + Compressor::Brotli(e) => e.encode(input, output), + Compressor::Zstd(e) => e.encode(input, output), + } + } + + fn flush( + &mut self, + output: &mut PartialBuffer + AsMut<[u8]>>, + ) -> std::io::Result { + match self { + Compressor::Deflate(e) => e.flush(output), + Compressor::Gzip(e) => e.flush(output), + Compressor::Brotli(e) => e.flush(output), + Compressor::Zstd(e) => e.flush(output), + } + } + + fn finish( + &mut self, + output: &mut PartialBuffer + AsMut<[u8]>>, + ) -> std::io::Result { + match self { + Compressor::Deflate(e) => e.finish(output), + Compressor::Gzip(e) => e.finish(output), + Compressor::Brotli(e) => e.finish(output), + Compressor::Zstd(e) => e.finish(output), + } + } +} diff --git a/apollo-router/src/axum_factory/compression/unshared.rs b/apollo-router/src/axum_factory/compression/unshared.rs new file mode 100644 index 0000000000..b4b244f168 --- /dev/null +++ b/apollo-router/src/axum_factory/compression/unshared.rs @@ -0,0 
+1,42 @@ +// All code from this module is extracted from https://github.com/Nemo157/async-compression and is under MIT or Apache-2 licence +// it will be removed when we find a long lasting solution to https://github.com/Nemo157/async-compression/issues/154 +#![allow(dead_code)] // unused without any features + +use core::fmt::Debug; +use core::fmt::{self}; + +/// Wraps a type and only allows unique borrowing, the main usecase is to wrap a `!Sync` type and +/// implement `Sync` for it as this type blocks having multiple shared references to the inner +/// value. +/// +/// # Safety +/// +/// We must be careful when accessing `inner`, there must be no way to create a shared reference to +/// it from a shared reference to an `Unshared`, as that would allow creating shared references on +/// multiple threads. +/// +/// As an example deriving or implementing `Clone` is impossible, two threads could attempt to +/// clone a shared `Unshared` reference which would result in accessing the same inner value +/// concurrently. +pub(crate) struct Unshared<T> { + inner: T, +} + +impl<T> Unshared<T> { + pub(crate) fn new(inner: T) -> Self { + Unshared { inner } + } + + pub(crate) fn get_mut(&mut self) -> &mut T { + &mut self.inner + } +} + +/// Safety: See comments on main docs for `Unshared` +unsafe impl<T> Sync for Unshared<T> {} + +impl<T> Debug for Unshared<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct(core::any::type_name::<T>()).finish() + } +} diff --git a/apollo-router/src/axum_factory/compression/util.rs b/apollo-router/src/axum_factory/compression/util.rs new file mode 100644 index 0000000000..609667b217 --- /dev/null +++ b/apollo-router/src/axum_factory/compression/util.rs @@ -0,0 +1,64 @@ +#![allow(dead_code)] +// All code from this module is extracted from https://github.com/Nemo157/async-compression and is under MIT or Apache-2 licence +// it will be removed when we find a long lasting solution to https://github.com/Nemo157/async-compression/issues/154 +pub(crate) fn _assert_send<T: Send>() {} +pub(crate) fn _assert_sync<T: Sync>() {} + +#[derive(Debug, Default)] +pub(crate) struct PartialBuffer<B: AsRef<[u8]>> { + buffer: B, + index: usize, +} + +impl<B: AsRef<[u8]>> PartialBuffer<B> { + pub(crate) fn new(buffer: B) -> Self { + Self { buffer, index: 0 } + } + + pub(crate) fn written(&self) -> &[u8] { + &self.buffer.as_ref()[..self.index] + } + + pub(crate) fn unwritten(&self) -> &[u8] { + &self.buffer.as_ref()[self.index..] + } + + pub(crate) fn advance(&mut self, amount: usize) { + self.index += amount; + } + + pub(crate) fn get_mut(&mut self) -> &mut B { + &mut self.buffer + } + + pub(crate) fn into_inner(self) -> B { + self.buffer + } +} + +impl<B: AsRef<[u8]> + AsMut<[u8]>> PartialBuffer<B> { + pub(crate) fn unwritten_mut(&mut self) -> &mut [u8] { + &mut self.buffer.as_mut()[self.index..] + } + + pub(crate) fn copy_unwritten_from<C: AsRef<[u8]>>(&mut self, other: &mut PartialBuffer<C>) { + let len = std::cmp::min(self.unwritten().len(), other.unwritten().len()); + + self.unwritten_mut()[..len].copy_from_slice(&other.unwritten()[..len]); + + self.advance(len); + other.advance(len); + } +} + +impl<B: AsRef<[u8]> + Default> PartialBuffer<B> { + pub(crate) fn take(&mut self) -> Self { + std::mem::replace(self, Self::new(B::default())) + } +} + +impl<B: AsRef<[u8]> + AsMut<[u8]>> From<B> for PartialBuffer<B> { + fn from(buffer: B) -> Self { + Self::new(buffer) + } +} diff --git a/apollo-router/src/axum_factory/mod.rs b/apollo-router/src/axum_factory/mod.rs index 5f6668794d..78234133dd 100644 --- a/apollo-router/src/axum_factory/mod.rs +++ b/apollo-router/src/axum_factory/mod.rs @@ -1,5 +1,6 @@ //! 
axum factory is useful to create an [`AxumHttpServerFactory`] which implements [`crate::http_server_factory::HttpServerFactory`] mod axum_http_server_factory; +mod compression; mod listeners; #[cfg(test)] pub(crate) mod tests;
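The changeset above describes compressing and flushing each response chunk as soon as it arrives instead of buffering the whole stream. The sketch below is a minimal, crate-internal test of that behaviour; it is not part of this diff. It assumes the `Compressor` API added above and a tokio test runtime, and it uses hyper's `Body::channel` plus flate2's `GzDecoder` purely for the round-trip check; the test name is made up for illustration.

```rust
// Illustrative sketch (not part of this diff): exercises the crate-internal
// Compressor added above, feeding it one chunk at a time like @defer does.
#[tokio::test]
async fn gzip_compressor_flushes_each_chunk() {
    use std::io::Read;

    use futures::StreamExt;

    use crate::axum_factory::compression::Compressor;

    // A channel-backed body lets us push chunks one at a time.
    let (mut sender, body) = hyper::Body::channel();
    let compressor = Compressor::new("gzip".split(',').map(|s| s.trim()))
        .expect("gzip is a supported encoding");
    assert_eq!(compressor.content_encoding(), "gzip");
    let mut compressed = compressor.process(body);

    // The first part must be compressed and emitted before the body is closed.
    sender.send_data("first part".into()).await.unwrap();
    let first = compressed.next().await.expect("a chunk").expect("no error");
    assert!(!first.is_empty());

    // Send a second part, then close the body so the encoder writes its trailer.
    sender.send_data("second part".into()).await.unwrap();
    drop(sender);

    let mut whole = first.to_vec();
    while let Some(chunk) = compressed.next().await {
        whole.extend_from_slice(&chunk.expect("no error"));
    }

    // The concatenated chunks still form one valid gzip stream.
    let mut decompressed = String::new();
    flate2::read::GzDecoder::new(&whole[..])
        .read_to_string(&mut decompressed)
        .unwrap();
    assert_eq!(decompressed, "first partsecond part");
}
```

The key assertion is that a compressed chunk is already available while the body is still open, which is what the multipart `@defer` protocol needs, while the final concatenation check confirms that per-chunk flushing still yields a single valid gzip stream.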