diff --git a/.github/workflows/bencher.yml b/.github/workflows/bencher.yml index 51d41ded..3c44eea0 100644 --- a/.github/workflows/bencher.yml +++ b/.github/workflows/bencher.yml @@ -31,7 +31,6 @@ jobs: BENCHER_API_TOKEN: ${{ secrets.BENCHER_API_TOKEN }} BENCHER_CMD: cargo bench --all-features BENCHER_PROJECT: socketioxide - RUSTFLAGS: --cfg=socketioxide_test benchmark_pr: if: github.event_name == 'workflow_dispatch' && github.event.pull_request.head.repo.full_name == github.repository @@ -62,4 +61,3 @@ jobs: BENCHER_API_TOKEN: ${{ secrets.BENCHER_API_TOKEN }} BENCHER_CMD: cargo bench --all-features BENCHER_PROJECT: socketioxide - RUSTFLAGS: --cfg=socketioxide_test \ No newline at end of file diff --git a/.github/workflows/github-ci.yml b/.github/workflows/github-ci.yml index 664c4780..7d295d73 100644 --- a/.github/workflows/github-ci.yml +++ b/.github/workflows/github-ci.yml @@ -12,18 +12,16 @@ jobs: format: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@master - with: - toolchain: stable - components: rustfmt - - run: cargo fmt --all -- --check - env: - RUSTFLAGS: --cfg=socketioxide_test - + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: stable + components: rustfmt + - run: cargo fmt --all -- --check + test: runs-on: ubuntu-latest - + steps: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master @@ -39,8 +37,6 @@ jobs: target/ key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} - run: cargo test --tests --all-features --workspace - env: - RUSTFLAGS: --cfg=socketioxide_test udeps: runs-on: ubuntu-latest steps: @@ -85,13 +81,11 @@ jobs: key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}-${{ steps.msrv.outputs.version }} - uses: dtolnay/rust-toolchain@master with: - toolchain: ${{ steps.msrv.outputs.version }} - components: rustfmt, clippy + toolchain: ${{ steps.msrv.outputs.version }} + components: rustfmt, clippy - name: check crates run: cargo check -p socketioxide -p engineioxide --all-features - env: - RUSTFLAGS: --cfg=socketioxide_test feature_set: runs-on: ubuntu-latest @@ -183,20 +177,18 @@ jobs: run: cargo install clippy-sarif sarif-fmt || true - name: Run rust-clippy - run: - cargo clippy + run: cargo clippy --all-features + --tests --message-format=json | clippy-sarif | tee rust-clippy-results.sarif | sarif-fmt continue-on-error: true - env: - RUSTFLAGS: --cfg=socketioxide_test - name: Upload analysis results to GitHub uses: github/codeql-action/upload-sarif@v3 with: sarif_file: rust-clippy-results.sarif wait-for-processing: true - + engine_io: runs-on: ubuntu-latest needs: [test] @@ -226,7 +218,7 @@ jobs: ~/.cargo/git/db/ target/ key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}-release - - name: Install deps & run tests + - name: Install deps & run tests run: | cd engine.io-protocol/test-suite && npm install && cd ../.. cargo build -p engineioxide-e2e --bin engineioxide-e2e --features ${{ matrix.engineio-version }} --release @@ -236,7 +228,7 @@ jobs: run: cat server.txt - name: Client output if: always() - run: cat client.txt + run: cat client.txt socket_io: runs-on: ubuntu-latest @@ -266,7 +258,7 @@ jobs: ~/.cargo/git/db/ target/ key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}-release - - name: Install deps & run tests + - name: Install deps & run tests run: | cd socket.io-protocol/test-suite && npm install && cd ../.. cargo build -p socketioxide-e2e --bin socketioxide-e2e --features ${{ matrix.socketio-version }} --release @@ -276,4 +268,4 @@ jobs: run: cat server.txt - name: Client output if: always() - run: cat client.txt \ No newline at end of file + run: cat client.txt diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4d9068f1..01a23ed8 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -16,7 +16,7 @@ today! As a contributor, here are the guidelines we would like you to follow: ## Got a Question or Problem? -Please open a discussion on the [Q&A discussions](https://github.com/totodore/socketioxide/discussions) page. +Please open a discussion on the [Q&A discussions](https://github.com/totodore/socketioxide/discussions) page. We want to keep GitHub Issues for bugs and feature requests. If you open an issue it will be moved to Discussions. ## Found a Bug? @@ -139,11 +139,6 @@ You will need [rustc and cargo](www.rust-lang.org/tools/install). ```shell git clone https://github.com/totodore/socketioxide ``` -2. To test socketioxide don't forget to enable the flag `socketioxide_test` through the `RUSTFLAGS` environment variable: - - ```shell - export RUSTFLAGS="--cfg socketioxide_test" - ``` 2. Depending on what you want to change, clone the [socketio/engine.io-protocol](https://github.com/socketio/engine.io-protocol) repo or the [socketio/socket.io-protocol](https://github.com/socketio/socket.io-protocol) repo or both ```shell git clone https://github.com/socketio/engine.io-protocol @@ -261,4 +256,4 @@ The subject contains succinct description of the change: ### Body Just as in the **subject**, use the imperative, present tense: "change" not "changed" nor "changes". -The body should include the motivation for the change and contrast this with previous behavior. \ No newline at end of file +The body should include the motivation for the change and contrast this with previous behavior. diff --git a/engineioxide/Cargo.toml b/engineioxide/Cargo.toml index 41ba0341..dde3f270 100644 --- a/engineioxide/Cargo.toml +++ b/engineioxide/Cargo.toml @@ -58,6 +58,7 @@ hyper-util = { workspace = true, features = ["tokio", "client-legacy"] } [features] v3 = ["memchr", "unicode-segmentation", "itoa"] tracing = ["dep:tracing"] +__test_harness = [] [[bench]] name = "packet_encode" diff --git a/engineioxide/src/lib.rs b/engineioxide/src/lib.rs index 14777b7a..142435cf 100644 --- a/engineioxide/src/lib.rs +++ b/engineioxide/src/lib.rs @@ -36,7 +36,8 @@ pub use crate::str::Str; pub use service::{ProtocolVersion, TransportType}; pub use socket::{DisconnectReason, Socket}; -#[cfg(any(test, socketioxide_test))] +#[doc(hidden)] +#[cfg(feature = "__test_harness")] pub use packet::*; pub mod config; diff --git a/engineioxide/src/socket.rs b/engineioxide/src/socket.rs index 700a5571..a1dd24e9 100644 --- a/engineioxide/src/socket.rs +++ b/engineioxide/src/socket.rs @@ -126,8 +126,8 @@ pub struct Permit<'a> { impl Permit<'_> { /// Consume the permit and emit a message to the client. #[inline] - pub fn emit(self, msg: String) { - self.inner.send(smallvec![Packet::Message(msg.into())]); + pub fn emit(self, msg: Str) { + self.inner.send(smallvec![Packet::Message(msg)]); } /// Consume the permit and emit a binary message to the client. #[inline] @@ -138,9 +138,21 @@ impl Permit<'_> { /// Consume the permit and emit a message with multiple binary data to the client. /// /// It can be used to ensure atomicity when sending a string packet with adjacent binary packets. - pub fn emit_many(self, msg: String, data: Vec) { + pub fn emit_many(self, msg: Str, data: Vec) { let mut packets = SmallVec::with_capacity(data.len() + 1); - packets.push(Packet::Message(msg.into())); + packets.push(Packet::Message(msg)); + for d in data { + packets.push(Packet::Binary(d)); + } + self.inner.send(packets); + } + + /// Consume the permit and emit a message with multiple binary data to the client. + /// + /// It can be used to ensure atomicity when sending a string packet with adjacent binary packets. + pub fn emit_many_binary(self, bin: Bytes, data: Vec) { + let mut packets = SmallVec::with_capacity(data.len() + 1); + packets.push(Packet::Binary(bin)); for d in data { packets.push(Packet::Binary(d)); } @@ -469,7 +481,8 @@ impl std::fmt::Debug for Socket { } } -#[cfg(socketioxide_test)] +#[doc(hidden)] +#[cfg(feature = "__test_harness")] impl Drop for Socket where D: Default + Send + Sync + 'static, @@ -480,7 +493,8 @@ where } } -#[cfg(any(socketioxide_test, test))] +#[doc(hidden)] +#[cfg(feature = "__test_harness")] impl Socket where D: Default + Send + Sync + 'static, diff --git a/engineioxide/src/str.rs b/engineioxide/src/str.rs index 0105ce91..c96ecf0c 100644 --- a/engineioxide/src/str.rs +++ b/engineioxide/src/str.rs @@ -1,6 +1,6 @@ -use std::borrow::{Borrow, Cow}; - use bytes::Bytes; +use serde::{Deserialize, Serialize}; +use std::borrow::{Borrow, Cow}; /// A custom [`Bytes`] wrapper to efficiently store string packets #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd)] @@ -92,6 +92,43 @@ impl From for String { unsafe { String::from_utf8_unchecked(vec) } } } +impl Serialize for Str { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(self.as_str()) + } +} +impl<'de> Deserialize<'de> for Str { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct StrVisitor; + impl<'de> serde::de::Visitor<'de> for StrVisitor { + type Value = Str; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("a str") + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + Ok(Str::copy_from_slice(v)) + } + fn visit_string(self, v: String) -> Result + where + E: serde::de::Error, + { + Ok(Str::from(v)) + } + } + deserializer.deserialize_str(StrVisitor) + } +} impl std::cmp::PartialEq<&str> for Str { fn eq(&self, other: &&str) -> bool { diff --git a/socketioxide/Cargo.toml b/socketioxide/Cargo.toml index 489addbb..59a5a3e4 100644 --- a/socketioxide/Cargo.toml +++ b/socketioxide/Cargo.toml @@ -25,6 +25,7 @@ tower.workspace = true http.workspace = true http-body.workspace = true thiserror.workspace = true +smallvec.workspace = true itoa.workspace = true hyper.workspace = true matchit.workspace = true @@ -41,6 +42,7 @@ v4 = ["engineioxide/v3"] tracing = ["dep:tracing", "engineioxide/tracing"] extensions = [] state = ["dep:state"] +__test_harness = ["engineioxide/__test_harness"] [dev-dependencies] engineioxide = { path = "../engineioxide", features = ["v3", "tracing"] } diff --git a/socketioxide/benches/packet_decode.rs b/socketioxide/benches/packet_decode.rs index 1c902ffc..fad00932 100644 --- a/socketioxide/benches/packet_decode.rs +++ b/socketioxide/benches/packet_decode.rs @@ -1,139 +1,175 @@ use bytes::Bytes; -use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion}; use engineioxide::sid::Sid; +use serde_json::to_value; use socketioxide::{ packet::{Packet, PacketData}, + parser::{CommonParser, Parse, TransportPayload}, ProtocolVersion, }; + +fn encode(packet: Packet<'_>) -> String { + match CommonParser::default().encode(black_box(packet)).0 { + TransportPayload::Str(d) => d.into(), + TransportPayload::Bytes(_) => panic!("testing only returns str"), + } +} +fn decode(value: String) -> Option> { + CommonParser::default() + .decode_str(black_box(value.into())) + .ok() +} fn criterion_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("socketio_packet/decode"); group.bench_function("Decode packet connect on /", |b| { - let packet: String = - Packet::connect(black_box("/"), black_box(Sid::ZERO), ProtocolVersion::V5).into(); - b.iter(|| Packet::try_from(packet.clone()).unwrap()) + b.iter_batched( + || encode(Packet::connect("/", Sid::ZERO, ProtocolVersion::V5)), + decode, + BatchSize::SmallInput, + ) }); group.bench_function("Decode packet connect on /custom_nsp", |b| { - let packet: String = Packet::connect( - black_box("/custom_nsp"), - black_box(Sid::ZERO), - ProtocolVersion::V5, + b.iter_batched( + || { + encode(Packet::connect( + "/custom_nsp", + Sid::ZERO, + ProtocolVersion::V5, + )) + }, + decode, + BatchSize::SmallInput, ) - .into(); - b.iter(|| Packet::try_from(packet.clone()).unwrap()) }); const DATA: &str = r#"{"_placeholder":true,"num":0}"#; const BINARY: Bytes = Bytes::from_static(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); group.bench_function("Decode packet event on /", |b| { - let data = serde_json::to_value(DATA).unwrap(); - let packet: String = - Packet::event(black_box("/"), black_box("event"), black_box(data.clone())).into(); - b.iter(|| Packet::try_from(packet.clone()).unwrap()) + b.iter_batched( + || encode(Packet::event("/", "event", to_value(DATA).unwrap())), + decode, + BatchSize::SmallInput, + ) }); group.bench_function("Decode packet event on /custom_nsp", |b| { - let data = serde_json::to_value(DATA).unwrap(); - let packet: String = Packet::event( - black_box("custom_nsp"), - black_box("event"), - black_box(data.clone()), + b.iter_batched( + || { + encode(Packet::event( + "custom_nsp", + "event", + to_value(DATA).unwrap(), + )) + }, + decode, + BatchSize::SmallInput, ) - .into(); - b.iter(|| Packet::try_from(packet.clone()).unwrap()) }); group.bench_function("Decode packet event with ack on /", |b| { - let data = serde_json::to_value(DATA).unwrap(); - let packet: Packet = - Packet::event(black_box("/"), black_box("event"), black_box(data.clone())); - match packet.inner { - PacketData::Event(_, _, mut ack) => ack.insert(black_box(0)), - _ => panic!("Wrong packet type"), - }; - let packet: String = packet.into(); - b.iter(|| Packet::try_from(packet.clone()).unwrap()) + b.iter_batched( + || { + let packet = Packet::event("/", "event", to_value(DATA).unwrap()); + match packet.inner { + PacketData::Event(_, _, mut ack) => ack.insert(black_box(0)), + _ => panic!("Wrong packet type"), + }; + encode(packet) + }, + decode, + BatchSize::SmallInput, + ) }); group.bench_function("Decode packet event with ack on /custom_nsp", |b| { - let data = serde_json::to_value(DATA).unwrap(); - let packet = Packet::event( - black_box("/custom_nsp"), - black_box("event"), - black_box(data.clone()), - ); - match packet.inner { - PacketData::Event(_, _, mut ack) => ack.insert(black_box(0)), - _ => panic!("Wrong packet type"), - }; - let packet: String = packet.into(); - - b.iter(|| Packet::try_from(packet.clone()).unwrap()) + b.iter_batched( + || { + let packet = Packet::event("/custom_nsp", "event", to_value(DATA).unwrap()); + match packet.inner { + PacketData::Event(_, _, mut ack) => ack.insert(black_box(0)), + _ => panic!("Wrong packet type"), + }; + encode(packet) + }, + decode, + BatchSize::SmallInput, + ) }); group.bench_function("Decode packet ack on /", |b| { - let data = serde_json::to_value(DATA).unwrap(); - let packet: String = - Packet::ack(black_box("/"), black_box(data.clone()), black_box(0)).into(); - b.iter(|| Packet::try_from(packet.clone()).unwrap()) + b.iter_batched( + || encode(Packet::ack("/", to_value(DATA).unwrap(), black_box(0))), + decode, + BatchSize::SmallInput, + ) }); group.bench_function("Decode packet ack on /custom_nsp", |b| { - let data = serde_json::to_value(DATA).unwrap(); - let packet: String = Packet::ack( - black_box("/custom_nsp"), - black_box(data.clone()), - black_box(0), + b.iter_batched( + || encode(Packet::ack("/custom_nsp", to_value(DATA).unwrap(), 0)), + decode, + BatchSize::SmallInput, ) - .into(); - b.iter(|| Packet::try_from(packet.clone()).unwrap()) }); group.bench_function("Decode packet binary event (b64) on /", |b| { - let data = serde_json::to_value(DATA).unwrap(); - let packet: String = Packet::bin_event( - black_box("/"), - black_box("event"), - black_box(data.clone()), - black_box(vec![BINARY.clone()]), + b.iter_batched( + || { + encode(Packet::bin_event( + "/", + "event", + to_value(DATA).unwrap(), + vec![BINARY.clone()], + )) + }, + decode, + BatchSize::SmallInput, ) - .into(); - b.iter(|| Packet::try_from(packet.clone()).unwrap()) }); group.bench_function("Decode packet binary event (b64) on /custom_nsp", |b| { - let data = serde_json::to_value(DATA).unwrap(); - let packet: String = Packet::bin_event( - black_box("/custom_nsp"), - black_box("event"), - black_box(data.clone()), - black_box(vec![BINARY.clone()]), + b.iter_batched( + || { + encode(Packet::bin_event( + "/custom_nsp", + "event", + to_value(DATA).unwrap(), + vec![BINARY], + )) + }, + decode, + BatchSize::SmallInput, ) - .into(); - b.iter(|| Packet::try_from(packet.clone()).unwrap()) }); group.bench_function("Decode packet binary ack (b64) on /", |b| { - let data = serde_json::to_value(DATA).unwrap(); - let packet: String = Packet::bin_ack( - black_box("/"), - black_box(data.clone()), - black_box(vec![BINARY.clone()]), - black_box(0), + b.iter_batched( + || { + encode(Packet::bin_ack( + "/", + to_value(DATA).unwrap(), + vec![BINARY.clone()], + 0, + )) + }, + decode, + BatchSize::SmallInput, ) - .into(); - b.iter(|| Packet::try_from(packet.clone()).unwrap()) }); group.bench_function("Decode packet binary ack (b64) on /custom_nsp", |b| { - let data = serde_json::to_value(DATA).unwrap(); - let packet: String = Packet::bin_ack( - black_box("/custom_nsp"), - black_box(data.clone()), - black_box(vec![BINARY.clone()]), - black_box(0), + b.iter_batched( + || { + encode(Packet::bin_ack( + "/custom_nsp", + to_value(DATA).unwrap(), + vec![BINARY.clone()], + 0, + )) + }, + decode, + BatchSize::SmallInput, ) - .into(); - b.iter(|| Packet::try_from(packet.clone()).unwrap()) }); group.finish(); diff --git a/socketioxide/benches/packet_encode.rs b/socketioxide/benches/packet_encode.rs index e9c435a9..8952818c 100644 --- a/socketioxide/benches/packet_encode.rs +++ b/socketioxide/benches/packet_encode.rs @@ -1,150 +1,146 @@ use bytes::Bytes; -use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion}; use engineioxide::sid::Sid; +use serde_json::to_value; use socketioxide::{ packet::{Packet, PacketData}, + parser::{CommonParser, Parse, TransportPayload}, ProtocolVersion, }; + +fn encode(packet: Packet<'_>) -> String { + match CommonParser::default().encode(black_box(packet)).0 { + TransportPayload::Str(d) => d.into(), + TransportPayload::Bytes(_) => panic!("testing only returns str"), + } +} + fn criterion_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("socketio_packet/encode"); + group.bench_function("Encode packet connect on /", |b| { - let packet = Packet::connect(black_box("/"), black_box(Sid::ZERO), ProtocolVersion::V5); - b.iter(|| { - let _: String = packet.clone().into(); - }) + b.iter_batched( + || Packet::connect("/", Sid::ZERO, ProtocolVersion::V5), + encode, + BatchSize::SmallInput, + ) }); + group.bench_function("Encode packet connect on /custom_nsp", |b| { - let packet = Packet::connect( - black_box("/custom_nsp"), - black_box(Sid::ZERO), - ProtocolVersion::V5, - ); - b.iter(|| { - let _: String = packet.clone().into(); - }) + b.iter_batched( + || Packet::connect("/custom_nsp", Sid::ZERO, ProtocolVersion::V5), + encode, + BatchSize::SmallInput, + ) }); const DATA: &str = r#"{"_placeholder":true,"num":0}"#; const BINARY: Bytes = Bytes::from_static(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); group.bench_function("Encode packet event on /", |b| { - let data = serde_json::to_value(DATA).unwrap(); - let packet = Packet::event(black_box("/"), black_box("event"), black_box(data.clone())); - b.iter(|| { - let _: String = packet.clone().into(); - }) + b.iter_batched( + || Packet::event("/", "event", to_value(DATA).unwrap()), + encode, + BatchSize::SmallInput, + ) }); group.bench_function("Encode packet event on /custom_nsp", |b| { - let data = serde_json::to_value(DATA).unwrap(); - let packet = Packet::event( - black_box("custom_nsp"), - black_box("event"), - black_box(data.clone()), - ); - b.iter(|| { - let _: String = packet.clone().into(); - }) + b.iter_batched( + || Packet::event("custom_nsp", "event", to_value(DATA).unwrap()), + encode, + BatchSize::SmallInput, + ) }); group.bench_function("Encode packet event with ack on /", |b| { - let data = serde_json::to_value(DATA).unwrap(); - let packet = Packet::event(black_box("/"), black_box("event"), black_box(data.clone())); - match packet.inner { - PacketData::Event(_, _, mut ack) => ack.insert(black_box(0)), - _ => panic!("Wrong packet type"), - }; - b.iter(|| { - let _: String = packet.clone().into(); - }) + b.iter_batched( + || { + let packet = Packet::event("/", "event", to_value(DATA).unwrap()); + if let PacketData::Event(_, _, mut ack) = packet.inner { + let _ = ack.insert(black_box(0)); + } + packet + }, + encode, + BatchSize::SmallInput, + ) }); group.bench_function("Encode packet event with ack on /custom_nsp", |b| { - let data = serde_json::to_value(DATA).unwrap(); - let packet = Packet::event( - black_box("/custom_nsp"), - black_box("event"), - black_box(data.clone()), - ); - match packet.inner { - PacketData::Event(_, _, mut ack) => ack.insert(black_box(0)), - _ => panic!("Wrong packet type"), - }; - b.iter(|| { - let _: String = packet.clone().into(); - }) + b.iter_batched( + || { + let packet = Packet::event("/custom_nsp", "event", to_value(DATA).unwrap()); + if let PacketData::Event(_, _, mut ack) = packet.inner { + let _ = ack.insert(black_box(0)); + } + packet + }, + encode, + BatchSize::SmallInput, + ) }); group.bench_function("Encode packet ack on /", |b| { - let data = serde_json::to_value(DATA).unwrap(); - let packet = Packet::ack(black_box("/"), black_box(data.clone()), black_box(0)); - b.iter(|| { - let _: String = packet.clone().into(); - }) + b.iter_batched( + || Packet::ack("/", to_value(DATA).unwrap(), 0), + encode, + BatchSize::SmallInput, + ) }); group.bench_function("Encode packet ack on /custom_nsp", |b| { - let data = serde_json::to_value(DATA).unwrap(); - let packet = Packet::ack( - black_box("/custom_nsp"), - black_box(data.clone()), - black_box(0), - ); - b.iter(|| { - let _: String = packet.clone().into(); - }) + b.iter_batched( + || Packet::ack("/custom_nsp", to_value(DATA).unwrap(), 0), + encode, + BatchSize::SmallInput, + ) }); group.bench_function("Encode packet binary event (b64) on /", |b| { - let data = serde_json::to_value(DATA).unwrap(); - let packet = Packet::bin_event( - black_box("/"), - black_box("event"), - black_box(data.clone()), - black_box(vec![BINARY.clone()]), - ); - b.iter(|| { - let _: String = packet.clone().into(); - }) + b.iter_batched( + || Packet::bin_event("/", "event", to_value(DATA).unwrap(), vec![BINARY.clone()]), + encode, + BatchSize::SmallInput, + ) }); group.bench_function("Encode packet binary event (b64) on /custom_nsp", |b| { - let data = serde_json::to_value(DATA).unwrap(); - let packet = Packet::bin_event( - black_box("/custom_nsp"), - black_box("event"), - black_box(data.clone()), - black_box(vec![BINARY.clone()]), - ); - b.iter(|| { - let _: String = packet.clone().into(); - }) + b.iter_batched( + || { + Packet::bin_event( + "/custom_nsp", + "event", + to_value(DATA).unwrap(), + vec![BINARY.clone()], + ) + }, + encode, + BatchSize::SmallInput, + ) }); group.bench_function("Encode packet binary ack (b64) on /", |b| { - let data = serde_json::to_value(DATA).unwrap(); - let packet = Packet::bin_ack( - black_box("/"), - black_box(data.clone()), - black_box(vec![BINARY.clone()]), - black_box(0), - ); - b.iter(|| { - let _: String = packet.clone().into(); - }) + b.iter_batched( + || Packet::bin_ack("/", to_value(DATA).unwrap(), vec![BINARY.clone()], 0), + encode, + BatchSize::SmallInput, + ) }); group.bench_function("Encode packet binary ack (b64) on /custom_nsp", |b| { - let data = serde_json::to_value(DATA).unwrap(); - let packet = Packet::bin_ack( - black_box("/custom_nsp"), - black_box(data.clone()), - black_box(vec![BINARY.clone()]), - black_box(0), - ); - b.iter(|| { - let _: String = packet.clone().into(); - }) + b.iter_batched( + || { + Packet::bin_ack( + "/custom_nsp", + to_value(DATA).unwrap(), + vec![BINARY.clone()], + 0, + ) + }, + encode, + BatchSize::SmallInput, + ) }); group.finish(); diff --git a/socketioxide/src/client.rs b/socketioxide/src/client.rs index 9d8f8ede..355eff1c 100644 --- a/socketioxide/src/client.rs +++ b/socketioxide/src/client.rs @@ -10,11 +10,13 @@ use futures_util::{FutureExt, TryFutureExt}; use engineioxide::sid::Sid; use matchit::{Match, Router}; +use serde_json::Value; use tokio::sync::oneshot; use crate::adapter::Adapter; use crate::handler::ConnectHandler; use crate::ns::NamespaceCtr; +use crate::parser::{self, Parse, Parser, TransportPayload}; use crate::socket::DisconnectReason; use crate::{ errors::Error, @@ -55,7 +57,7 @@ impl Client { /// Called when a socket connects to a new namespace fn sock_connect( &self, - auth: Option, + auth: Option, ns_path: Str, esocket: &Arc>>, ) { @@ -88,11 +90,22 @@ impl Client { ); esocket.close(EIoDisconnectReason::TransportClose); } else { - let packet: String = Packet::connect_error(ns_path, "Invalid namespace").into(); - if let Err(_e) = esocket.emit(packet) { - #[cfg(feature = "tracing")] - tracing::error!("error while sending invalid namespace packet: {}", _e); - } + let (packet, _) = esocket + .data + .parser + .get() + .unwrap() + .encode(Packet::connect_error(ns_path, "Invalid namespace")); + let _ = match packet { + TransportPayload::Str(p) => esocket.emit(p).map_err(|_e| { + #[cfg(feature = "tracing")] + tracing::error!("error while sending invalid namespace packet: {}", _e); + }), + TransportPayload::Bytes(p) => esocket.emit_binary(p).map_err(|_e| { + #[cfg(feature = "tracing")] + tracing::error!("error while sending invalid namespace packet: {}", _e); + }), + }; } } @@ -186,9 +199,8 @@ impl Client { #[derive(Debug)] pub struct SocketData { - /// Partial binary packet that is being received - /// Stored here until all the binary payloads are received - pub partial_bin_packet: Mutex>>, + /// The parser to decode the socket.io packets + pub parser: OnceLock, /// Channel used to notify the socket that it has been connected to a namespace for v5 pub connect_recv_tx: Mutex>>, @@ -199,12 +211,17 @@ pub struct SocketData { impl Default for SocketData { fn default() -> Self { Self { - partial_bin_packet: Default::default(), + parser: Default::default(), connect_recv_tx: Default::default(), io: OnceLock::new(), } } } +impl SocketData { + pub fn parser(&self) -> &Parser { + self.parser.get().unwrap() + } +} impl EngineIoHandler for Client { type Data = SocketData; @@ -212,6 +229,7 @@ impl EngineIoHandler for Client { #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, socket), fields(sid = socket.id.to_string())))] fn on_connect(self: Arc, socket: Arc>>) { socket.data.io.set(SocketIo::from(self.clone())).ok(); + socket.data.parser.set(self.config.parser.clone()).ok(); #[cfg(feature = "tracing")] tracing::debug!("eio socket connect"); @@ -261,8 +279,9 @@ impl EngineIoHandler for Client { fn on_message(&self, msg: Str, socket: Arc>>) { #[cfg(feature = "tracing")] tracing::debug!("Received message: {:?}", msg); - let packet = match Packet::try_from(msg) { + let packet = match socket.data.parser().decode_str(msg) { Ok(packet) => packet, + Err(parser::Error::NeedsMoreBinaryData) => return, Err(_e) => { #[cfg(feature = "tracing")] tracing::debug!("socket serialization error: {}", _e); @@ -278,16 +297,6 @@ impl EngineIoHandler for Client { self.sock_connect(auth, packet.ns, &socket); Ok(()) } - PacketData::BinaryEvent(_, _, _) | PacketData::BinaryAck(_, _) => { - // Cache-in the socket data until all the binary payloads are received - socket - .data - .partial_bin_packet - .lock() - .unwrap() - .replace(packet); - Ok(()) - } _ => self.sock_propagate_packet(packet, socket.id), }; if let Err(ref err) = res { @@ -307,19 +316,25 @@ impl EngineIoHandler for Client { /// /// If the packet is complete, it is propagated to the namespace fn on_binary(&self, data: Bytes, socket: Arc>>) { - if apply_payload_on_packet(data, &socket) { - if let Some(packet) = socket.data.partial_bin_packet.lock().unwrap().take() { - if let Err(ref err) = self.sock_propagate_packet(packet, socket.id) { - #[cfg(feature = "tracing")] - tracing::debug!( - "error while propagating packet to socket {}: {}", - socket.id, - err - ); - if let Some(reason) = err.into() { - socket.close(reason); - } - } + let packet = match socket.data.parser().decode_bin(data) { + Ok(packet) => packet, + Err(parser::Error::NeedsMoreBinaryData) => return, + Err(_e) => { + #[cfg(feature = "tracing")] + tracing::debug!("socket serialization error: {}", _e); + socket.close(EIoDisconnectReason::PacketParsingError); + return; + } + }; + if let Err(ref err) = self.sock_propagate_packet(packet, socket.id) { + #[cfg(feature = "tracing")] + tracing::debug!( + "error while propagating packet to socket {}: {}", + socket.id, + err + ); + if let Some(reason) = err.into() { + socket.close(reason); } } } @@ -334,29 +349,8 @@ impl std::fmt::Debug for Client { } } -/// Utility that applies an incoming binary payload to a partial binary packet -/// waiting to be filled with all the payloads -/// -/// Returns true if the packet is complete and should be processed -fn apply_payload_on_packet(data: Bytes, socket: &EIoSocket>) -> bool { - #[cfg(feature = "tracing")] - tracing::debug!("[sid={}] applying payload on packet", socket.id); - if let Some(ref mut packet) = *socket.data.partial_bin_packet.lock().unwrap() { - match packet.inner { - PacketData::BinaryEvent(_, ref mut bin, _) | PacketData::BinaryAck(ref mut bin, _) => { - bin.add_payload(data); - bin.is_complete() - } - _ => unreachable!("partial_bin_packet should only be set for binary packets"), - } - } else { - #[cfg(feature = "tracing")] - tracing::debug!("[sid={}] socket received unexpected bin data", socket.id); - false - } -} - -#[cfg(socketioxide_test)] +#[doc(hidden)] +#[cfg(feature = "__test_harness")] impl Client { pub async fn new_dummy_sock( self: Arc, @@ -371,6 +365,7 @@ impl Client { let (esock, rx) = EIoSocket::>::new_dummy_piped(sid, Box::new(|_, _| {}), buffer_size); esock.data.io.set(SocketIo::from(self.clone())).ok(); + esock.data.parser.set(Parser::default()).ok(); let (tx1, mut rx1) = tokio::sync::mpsc::channel(buffer_size); tokio::spawn({ let esock = esock.clone(); @@ -393,12 +388,13 @@ impl Client { } } }); - let p: String = Packet { + let (p, _) = parser::CommonParser::default().encode(Packet { ns: ns.into(), - inner: PacketData::Connect(Some(serde_json::to_string(&auth).unwrap())), + inner: PacketData::Connect(Some(serde_json::to_value(&auth).unwrap())), + }); + if let TransportPayload::Str(s) = p { + self.on_message(s, esock.clone()); } - .into(); - self.on_message(p.into(), esock.clone()); // wait for the socket to be connected to the namespace tokio::time::sleep(std::time::Duration::from_millis(10)).await; diff --git a/socketioxide/src/extract/data.rs b/socketioxide/src/extract/data.rs index acc86f87..25de8457 100644 --- a/socketioxide/src/extract/data.rs +++ b/socketioxide/src/extract/data.rs @@ -27,9 +27,9 @@ where A: Adapter, { type Error = serde_json::Error; - fn from_connect_parts(_: &Arc>, auth: &Option) -> Result { + fn from_connect_parts(_: &Arc>, auth: &Option) -> Result { auth.as_ref() - .map(|a| serde_json::from_str::(a)) + .map(|a| serde_json::from_value::(a.clone())) //TODO: clone .unwrap_or(serde_json::from_str::("{}")) .map(Data) } @@ -60,10 +60,10 @@ where A: Adapter, { type Error = Infallible; - fn from_connect_parts(_: &Arc>, auth: &Option) -> Result { + fn from_connect_parts(_: &Arc>, auth: &Option) -> Result { let v: Result = auth .as_ref() - .map(|a| serde_json::from_str(a)) + .map(|a| serde_json::from_value(a.clone())) //TODO: clone .unwrap_or(serde_json::from_str("{}")); Ok(TryData(v)) } diff --git a/socketioxide/src/extract/extensions.rs b/socketioxide/src/extract/extensions.rs index 5d5624a1..3298896b 100644 --- a/socketioxide/src/extract/extensions.rs +++ b/socketioxide/src/extract/extensions.rs @@ -9,6 +9,7 @@ use bytes::Bytes; #[cfg(feature = "extensions")] #[cfg_attr(docsrs, doc(cfg(feature = "extensions")))] pub use extensions_extract::*; +use serde_json::Value; /// It was impossible to find the given extension. pub struct ExtensionNotFound(std::marker::PhantomData); @@ -48,7 +49,7 @@ impl FromConnectParts for HttpE type Error = ExtensionNotFound; fn from_connect_parts( s: &Arc>, - _: &Option, + _: &Option, ) -> Result> { extract_http_extension(s).map(HttpExtension) } @@ -56,7 +57,7 @@ impl FromConnectParts for HttpE impl FromConnectParts for MaybeHttpExtension { type Error = Infallible; - fn from_connect_parts(s: &Arc>, _: &Option) -> Result { + fn from_connect_parts(s: &Arc>, _: &Option) -> Result { Ok(MaybeHttpExtension(extract_http_extension(s).ok())) } } @@ -131,14 +132,14 @@ mod extensions_extract { type Error = ExtensionNotFound; fn from_connect_parts( s: &Arc>, - _: &Option, + _: &Option, ) -> Result> { extract_extension(s).map(Extension) } } impl FromConnectParts for MaybeExtension { type Error = Infallible; - fn from_connect_parts(s: &Arc>, _: &Option) -> Result { + fn from_connect_parts(s: &Arc>, _: &Option) -> Result { Ok(MaybeExtension(extract_extension(s).ok())) } } diff --git a/socketioxide/src/extract/mod.rs b/socketioxide/src/extract/mod.rs index abbad861..c7b22023 100644 --- a/socketioxide/src/extract/mod.rs +++ b/socketioxide/src/extract/mod.rs @@ -45,6 +45,7 @@ //! # use std::sync::Arc; //! # use std::convert::Infallible; //! # use socketioxide::SocketIo; +//! # use serde_json::Value; //! //! struct UserId(String); //! @@ -59,7 +60,7 @@ //! //! impl FromConnectParts for UserId { //! type Error = Infallible; -//! fn from_connect_parts(s: &Arc>, _: &Option) -> Result { +//! fn from_connect_parts(s: &Arc>, _: &Option) -> Result { //! // In a real app it would be better to parse the query params with a crate like `url` //! let uri = &s.req_parts().uri; //! let uid = uri diff --git a/socketioxide/src/extract/socket.rs b/socketioxide/src/extract/socket.rs index 914cd361..ed65a942 100644 --- a/socketioxide/src/extract/socket.rs +++ b/socketioxide/src/extract/socket.rs @@ -11,6 +11,7 @@ use crate::{ }; use bytes::Bytes; use serde::Serialize; +use serde_json::Value; /// An Extractor that returns a reference to a [`Socket`]. #[derive(Debug)] @@ -18,7 +19,7 @@ pub struct SocketRef(Arc>); impl FromConnectParts for SocketRef { type Error = Infallible; - fn from_connect_parts(s: &Arc>, _: &Option) -> Result { + fn from_connect_parts(s: &Arc>, _: &Option) -> Result { Ok(SocketRef(s.clone())) } } @@ -129,7 +130,7 @@ impl AckSender { } else { Packet::bin_ack(ns, data, self.binary, ack_id) }; - permit.send(packet); + permit.send(packet, self.socket.parser()); Ok(()) } else { Ok(()) @@ -139,7 +140,7 @@ impl AckSender { impl FromConnectParts for crate::ProtocolVersion { type Error = Infallible; - fn from_connect_parts(s: &Arc>, _: &Option) -> Result { + fn from_connect_parts(s: &Arc>, _: &Option) -> Result { Ok(s.protocol()) } } @@ -163,7 +164,7 @@ impl FromDisconnectParts for crate::ProtocolVersion { impl FromConnectParts for crate::TransportType { type Error = Infallible; - fn from_connect_parts(s: &Arc>, _: &Option) -> Result { + fn from_connect_parts(s: &Arc>, _: &Option) -> Result { Ok(s.transport_type()) } } @@ -198,7 +199,7 @@ impl FromDisconnectParts for DisconnectReason { impl FromConnectParts for SocketIo { type Error = Infallible; - fn from_connect_parts(s: &Arc>, _: &Option) -> Result { + fn from_connect_parts(s: &Arc>, _: &Option) -> Result { Ok(s.get_io().clone()) } } diff --git a/socketioxide/src/extract/state.rs b/socketioxide/src/extract/state.rs index 50886aec..99d7e770 100644 --- a/socketioxide/src/extract/state.rs +++ b/socketioxide/src/extract/state.rs @@ -1,4 +1,5 @@ use bytes::Bytes; +use serde_json::Value; use std::sync::Arc; @@ -59,10 +60,7 @@ impl std::error::Error for StateNotFound {} impl FromConnectParts for State { type Error = StateNotFound; - fn from_connect_parts( - s: &Arc>, - _: &Option, - ) -> Result> { + fn from_connect_parts(s: &Arc>, _: &Option) -> Result> { s.get_io() .get_state::() .map(State) diff --git a/socketioxide/src/handler/connect.rs b/socketioxide/src/handler/connect.rs index 634ac239..018a2dcc 100644 --- a/socketioxide/src/handler/connect.rs +++ b/socketioxide/src/handler/connect.rs @@ -117,6 +117,7 @@ use std::sync::Arc; use crate::{adapter::Adapter, socket::Socket}; use futures_core::Future; +use serde_json::Value; use super::MakeErasedHandler; @@ -127,11 +128,11 @@ type MiddlewareRes = Result<(), Box>; type MiddlewareResFut<'a> = Pin + Send + 'a>>; pub(crate) trait ErasedConnectHandler: Send + Sync + 'static { - fn call(&self, s: Arc>, auth: Option); + fn call(&self, s: Arc>, auth: Option); fn call_middleware<'a>( &'a self, s: Arc>, - auth: &'a Option, + auth: &'a Option, ) -> MiddlewareResFut<'a>; fn boxed_clone(&self) -> BoxedConnectHandler; @@ -149,7 +150,7 @@ pub trait FromConnectParts: Sized { /// Extract the arguments from the connect event. /// If it fails, the handler is not called - fn from_connect_parts(s: &Arc>, auth: &Option) -> Result; + fn from_connect_parts(s: &Arc>, auth: &Option) -> Result; } /// Define a middleware for the connect event. @@ -163,7 +164,7 @@ pub trait ConnectMiddleware: Sized + Clone + Send + Sync + 'stati fn call<'a>( &'a self, s: Arc>, - auth: &'a Option, + auth: &'a Option, ) -> impl Future + Send; #[doc(hidden)] @@ -179,13 +180,13 @@ pub trait ConnectMiddleware: Sized + Clone + Send + Sync + 'stati /// * See the [`extract`](crate::extract) module doc for more details on available extractors. pub trait ConnectHandler: Sized + Clone + Send + Sync + 'static { /// Call the handler with the given arguments. - fn call(&self, s: Arc>, auth: Option); + fn call(&self, s: Arc>, auth: Option); /// Call the middleware with the given arguments. fn call_middleware<'a>( &'a self, _: Arc>, - _: &'a Option, + _: &'a Option, ) -> MiddlewareResFut<'a> { Box::pin(async move { Ok(()) }) } @@ -278,14 +279,14 @@ where H: ConnectHandler + Send + Sync + 'static, T: Send + Sync + 'static, { - fn call(&self, s: Arc>, auth: Option) { + fn call(&self, s: Arc>, auth: Option) { self.handler.call(s, auth); } fn call_middleware<'a>( &'a self, s: Arc>, - auth: &'a Option, + auth: &'a Option, ) -> MiddlewareResFut<'a> { self.handler.call_middleware(s, auth) } @@ -303,14 +304,14 @@ where T: Send + Sync + 'static, T1: Send + Sync + 'static, { - fn call(&self, s: Arc>, auth: Option) { + fn call(&self, s: Arc>, auth: Option) { self.handler.call(s, auth); } fn call_middleware<'a>( &'a self, s: Arc>, - auth: &'a Option, + auth: &'a Option, ) -> MiddlewareResFut<'a> { Box::pin(async move { self.middleware.call(s, auth).await }) } @@ -339,7 +340,7 @@ where T: Send + Sync + 'static, T1: Send + Sync + 'static, { - async fn call<'a>(&'a self, s: Arc>, auth: &'a Option) -> MiddlewareRes { + async fn call<'a>(&'a self, s: Arc>, auth: &'a Option) -> MiddlewareRes { self.middleware.call(s, auth).await } } @@ -378,7 +379,7 @@ where T: Send + Sync + 'static, T1: Send + Sync + 'static, { - async fn call<'a>(&'a self, s: Arc>, auth: &'a Option) -> MiddlewareRes { + async fn call<'a>(&'a self, s: Arc>, auth: &'a Option) -> MiddlewareRes { self.middleware.call(s.clone(), auth).await?; self.next.call(s, auth).await } @@ -403,7 +404,7 @@ macro_rules! impl_handler_async { A: Adapter, $( $ty: FromConnectParts + Send, )* { - fn call(&self, s: Arc>, auth: Option) { + fn call(&self, s: Arc>, auth: Option) { $( let $ty = match $ty::from_connect_parts(&s, &auth) { Ok(v) => v, @@ -434,7 +435,7 @@ macro_rules! impl_handler { A: Adapter, $( $ty: FromConnectParts + Send, )* { - fn call(&self, s: Arc>, auth: Option) { + fn call(&self, s: Arc>, auth: Option) { $( let $ty = match $ty::from_connect_parts(&s, &auth) { Ok(v) => v, @@ -468,7 +469,7 @@ macro_rules! impl_middleware_async { async fn call<'a>( &'a self, s: Arc>, - auth: &'a Option, + auth: &'a Option, ) -> MiddlewareRes { $( let $ty = match $ty::from_connect_parts(&s, auth) { @@ -509,7 +510,7 @@ macro_rules! impl_middleware { async fn call<'a>( &'a self, s: Arc>, - auth: &'a Option, + auth: &'a Option, ) -> MiddlewareRes { $( let $ty = match $ty::from_connect_parts(&s, auth) { diff --git a/socketioxide/src/io.rs b/socketioxide/src/io.rs index 8120c505..3250d17c 100644 --- a/socketioxide/src/io.rs +++ b/socketioxide/src/io.rs @@ -16,6 +16,7 @@ use crate::{ handler::ConnectHandler, layer::SocketIoLayer, operators::{BroadcastOperators, RoomParam}, + parser::Parser, service::SocketIoService, BroadcastError, DisconnectError, }; @@ -35,6 +36,9 @@ pub struct SocketIoConfig { /// /// Defaults to 45 seconds. pub connect_timeout: Duration, + + /// The parser to use to encode and decode socket.io packets + pub parser: Parser, } impl Default for SocketIoConfig { @@ -46,6 +50,7 @@ impl Default for SocketIoConfig { }, ack_timeout: Duration::from_secs(5), connect_timeout: Duration::from_secs(45), + parser: Parser::default(), } } } @@ -900,7 +905,8 @@ impl From>> for SocketIo { } } -#[cfg(any(test, socketioxide_test))] +#[doc(hidden)] +#[cfg(feature = "__test_harness")] impl SocketIo { /// Create a dummy socket for testing purpose with a /// receiver to get the packets sent to the client @@ -919,6 +925,8 @@ impl SocketIo { #[cfg(test)] mod tests { + use crate::client::SocketData; + use super::*; #[test] @@ -949,7 +957,8 @@ mod tests { let sid = Sid::new(); let (_, io) = SocketIo::builder().build_svc(); io.ns("/", || {}); - let socket = Socket::new_dummy(sid, Box::new(|_, _| {})); + let socket = Socket::>::new_dummy(sid, Box::new(|_, _| {})); + socket.data.parser.set(Parser::default()).unwrap(); io.0.get_ns("/") .unwrap() .connect(sid, socket, None) diff --git a/socketioxide/src/lib.rs b/socketioxide/src/lib.rs index ef612260..c0c0105c 100644 --- a/socketioxide/src/lib.rs +++ b/socketioxide/src/lib.rs @@ -294,6 +294,7 @@ pub mod handler; pub mod layer; pub mod operators; pub mod packet; +pub mod parser; pub mod service; pub mod socket; diff --git a/socketioxide/src/ns.rs b/socketioxide/src/ns.rs index dde8d81c..c4c4fec8 100644 --- a/socketioxide/src/ns.rs +++ b/socketioxide/src/ns.rs @@ -12,6 +12,7 @@ use crate::{ }; use crate::{client::SocketData, errors::AdapterError}; use engineioxide::{sid::Sid, Str}; +use serde_json::Value; /// A [`Namespace`] constructor used for dynamic namespaces /// A namespace constructor only hold a common handler that will be cloned @@ -71,7 +72,7 @@ impl Namespace { self: Arc, sid: Sid, esocket: Arc>>, - auth: Option, + auth: Option, ) -> Result<(), ConnectFail> { let socket: Arc> = Socket::new(sid, self.clone(), esocket.clone()).into(); @@ -182,7 +183,8 @@ impl Namespace { } } -#[cfg(any(test, socketioxide_test))] +#[doc(hidden)] +#[cfg(feature = "__test_harness")] impl Namespace { pub fn new_dummy(sockets: [Sid; S]) -> Arc { let ns = Namespace::new("/".into(), || {}); diff --git a/socketioxide/src/operators.rs b/socketioxide/src/operators.rs index 22a4aeb2..4ae8a0ee 100644 --- a/socketioxide/src/operators.rs +++ b/socketioxide/src/operators.rs @@ -360,7 +360,7 @@ impl ConfOperators<'_, A> { } }; let packet = self.get_packet(event, data)?; - permit.send(packet); + permit.send(packet, self.socket.parser()); Ok(()) } diff --git a/socketioxide/src/packet.rs b/socketioxide/src/packet.rs index 261414a9..c9ef6402 100644 --- a/socketioxide/src/packet.rs +++ b/socketioxide/src/packet.rs @@ -5,10 +5,9 @@ use std::borrow::Cow; use crate::ProtocolVersion; use bytes::Bytes; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; -use crate::errors::Error; use engineioxide::{sid::Sid, Str}; /// The socket.io packet type. @@ -53,7 +52,7 @@ impl<'a> Packet<'a> { /// Sends a connect packet with payload. fn connect_v5(ns: Str, sid: Sid) -> Self { - let val = serde_json::to_string(&ConnectPacket { sid }).unwrap(); + let val: Value = serde_json::to_value(ConnectPacket { sid }).unwrap(); Self { inner: PacketData::Connect(Some(val)), ns, @@ -121,54 +120,6 @@ impl<'a> Packet<'a> { ns: ns.into(), } } - - /// Get the max size the packet could have when serialized - /// This is used to pre-allocate a buffer for the packet - /// - /// #### Disclaimer: The size does not include serialized `Value` size - fn get_size_hint(&self) -> usize { - use PacketData::*; - const PACKET_INDEX_SIZE: usize = 1; - const BINARY_PUNCTUATION_SIZE: usize = 2; - const ACK_PUNCTUATION_SIZE: usize = 1; - const NS_PUNCTUATION_SIZE: usize = 1; - - let data_size = match &self.inner { - Connect(Some(data)) => data.len(), - Connect(None) => 0, - Disconnect => 0, - Event(_, _, Some(ack)) => { - ack.checked_ilog10().unwrap_or(0) as usize + ACK_PUNCTUATION_SIZE - } - Event(_, _, None) => 0, - BinaryEvent(_, bin, None) => { - bin.payload_count.checked_ilog10().unwrap_or(0) as usize + BINARY_PUNCTUATION_SIZE - } - BinaryEvent(_, bin, Some(ack)) => { - ack.checked_ilog10().unwrap_or(0) as usize - + bin.payload_count.checked_ilog10().unwrap_or(0) as usize - + ACK_PUNCTUATION_SIZE - + BINARY_PUNCTUATION_SIZE - } - EventAck(_, ack) => ack.checked_ilog10().unwrap_or(0) as usize + ACK_PUNCTUATION_SIZE, - BinaryAck(bin, ack) => { - ack.checked_ilog10().unwrap_or(0) as usize - + bin.payload_count.checked_ilog10().unwrap_or(0) as usize - + ACK_PUNCTUATION_SIZE - + BINARY_PUNCTUATION_SIZE - } - ConnectError(data) => data.len(), - }; - - let nsp_size = if self.ns == "/" { - 0 - } else if self.ns.starts_with('/') { - self.ns.len() + NS_PUNCTUATION_SIZE - } else { - self.ns.len() + NS_PUNCTUATION_SIZE + 1 // (1 for the leading slash) - }; - data_size + nsp_size + PACKET_INDEX_SIZE - } } /// | Type | ID | Usage | @@ -183,7 +134,7 @@ impl<'a> Packet<'a> { #[derive(Debug, Clone, PartialEq, Eq)] pub enum PacketData<'a> { /// Connect packet with optional payload (only used with v5 for response) - Connect(Option), + Connect(Option), /// Disconnect packet, used to disconnect from a namespace Disconnect, /// Event packet with optional ack id, to request an ack from the other side @@ -206,19 +157,19 @@ pub struct BinaryPacket { /// Binary payload pub bin: Vec, /// The number of expected payloads (used when receiving data) - payload_count: usize, + pub payload_count: usize, } impl<'a> PacketData<'a> { - fn index(&self) -> char { + pub(crate) fn index(&self) -> usize { match self { - PacketData::Connect(_) => '0', - PacketData::Disconnect => '1', - PacketData::Event(_, _, _) => '2', - PacketData::EventAck(_, _) => '3', - PacketData::ConnectError(_) => '4', - PacketData::BinaryEvent(_, _, _) => '5', - PacketData::BinaryAck(_, _) => '6', + PacketData::Connect(_) => 0, + PacketData::Disconnect => 1, + PacketData::Event(_, _, _) => 2, + PacketData::EventAck(_, _) => 3, + PacketData::ConnectError(_) => 4, + PacketData::BinaryEvent(_, _, _) => 5, + PacketData::BinaryAck(_, _) => 6, } } @@ -240,6 +191,16 @@ impl<'a> PacketData<'a> { PacketData::BinaryEvent(_, _, _) | PacketData::BinaryAck(_, _) ) } + + /// Check if the binary packet is complete, it means that all payloads have been received + pub(crate) fn is_complete(&self) -> bool { + match self { + PacketData::BinaryEvent(_, bin, _) | PacketData::BinaryAck(bin, _) => { + bin.payload_count == bin.bin.len() + } + _ => true, + } + } } impl BinaryPacket { @@ -303,630 +264,8 @@ impl BinaryPacket { } } -impl<'a> From> for String { - fn from(mut packet: Packet<'a>) -> String { - use PacketData::*; - - // Serialize the data if there is any - // pre-serializing allows to preallocate the buffer - let data = match &mut packet.inner { - Event(e, data, _) | BinaryEvent(e, BinaryPacket { data, .. }, _) => { - // Expand the packet if it is an array with data -> ["event", ...data] - let packet = match data { - Value::Array(ref mut v) if !v.is_empty() => { - v.insert(0, Value::String((*e).to_string())); - serde_json::to_string(&v) - } - Value::Array(_) => serde_json::to_string::<(_, [(); 0])>(&(e, [])), - _ => serde_json::to_string(&(e, data)), - } - .unwrap(); - Some(packet) - } - EventAck(data, _) | BinaryAck(BinaryPacket { data, .. }, _) => { - // Enforce that the packet is an array -> [data] - let packet = match data { - Value::Array(_) => serde_json::to_string(&data), - Value::Null => Ok("[]".to_string()), - _ => serde_json::to_string(&[data]), - } - .unwrap(); - Some(packet) - } - _ => None, - }; - - let capacity = packet.get_size_hint() + data.as_ref().map(|d| d.len()).unwrap_or(0); - let mut res = String::with_capacity(capacity); - res.push(packet.inner.index()); - - // Add the ns if it is not the default one and the packet is not binary - // In case of bin packet, we should first add the payload count before ns - let push_nsp = |res: &mut String| { - if !packet.ns.is_empty() && packet.ns != "/" { - if !packet.ns.starts_with('/') { - res.push('/'); - } - res.push_str(&packet.ns); - res.push(','); - } - }; - - if !packet.inner.is_binary() { - push_nsp(&mut res); - } - - let mut itoa_buf = itoa::Buffer::new(); - - match packet.inner { - PacketData::Connect(Some(data)) => res.push_str(&data), - PacketData::Disconnect | PacketData::Connect(None) => (), - PacketData::Event(_, _, ack) => { - if let Some(ack) = ack { - res.push_str(itoa_buf.format(ack)); - } - - res.push_str(&data.unwrap()) - } - PacketData::EventAck(_, ack) => { - res.push_str(itoa_buf.format(ack)); - res.push_str(&data.unwrap()) - } - PacketData::ConnectError(data) => res.push_str(&data), - PacketData::BinaryEvent(_, bin, ack) => { - res.push_str(itoa_buf.format(bin.payload_count)); - res.push('-'); - - push_nsp(&mut res); - - if let Some(ack) = ack { - res.push_str(itoa_buf.format(ack)); - } - - res.push_str(&data.unwrap()) - } - PacketData::BinaryAck(packet, ack) => { - res.push_str(itoa_buf.format(packet.payload_count)); - res.push('-'); - - push_nsp(&mut res); - - res.push_str(itoa_buf.format(ack)); - res.push_str(&data.unwrap()) - } - }; - - res - } -} - -/// Deserialize an event packet from a string, formated as: -/// ```text -/// ["", ...] -/// ``` -fn deserialize_event_packet(data: &str) -> Result<(String, Value), Error> { - #[cfg(feature = "tracing")] - tracing::debug!("Deserializing event packet: {:?}", data); - let packet = match serde_json::from_str::(data)? { - Value::Array(packet) => packet, - _ => return Err(Error::InvalidEventName), - }; - - let event = packet - .first() - .ok_or(Error::InvalidEventName)? - .as_str() - .ok_or(Error::InvalidEventName)? - .to_string(); - let payload = Value::from_iter(packet.into_iter().skip(1)); - Ok((event, payload)) -} - -fn deserialize_packet(data: &str) -> Result, serde_json::Error> { - #[cfg(feature = "tracing")] - tracing::debug!("Deserializing packet: {:?}", data); - let packet = if data.is_empty() { - None - } else { - Some(serde_json::from_str(data)?) - }; - Ok(packet) -} - -/// Deserialize a packet from a string -/// The string should be in the format of: -/// ```text -/// [<# of binary attachments>-][,][][JSON-stringified payload without binary] -/// + binary attachments extracted -/// ``` -impl<'a> TryFrom for Packet<'a> { - type Error = Error; - - fn try_from(value: Str) -> Result { - let chars = value.as_bytes(); - // It is possible to parse the packet from a byte slice because separators are only ASCII - let index = (b'0'..=b'6') - .contains(&chars[0]) - .then_some(chars[0]) - .ok_or(Error::InvalidPacketType)?; - - // Move the cursor to skip the payload count if it is a binary packet - let mut i = if index == b'5' || index == b'6' { - chars - .iter() - .position(|x| *x == b'-') - .ok_or(Error::InvalidPacketType)? - + 1 - } else { - 1 - }; - - let start_index = i; - // Custom nsps will start with a slash - let ns = if chars.get(i) == Some(&b'/') { - loop { - match chars.get(i) { - Some(b',') => { - i += 1; - break value.slice(start_index..i - 1); - } - // It maybe possible depending on clients that ns does not end with a comma - // if it is the end of the packet - // e.g `1/custom` - None => { - break value.slice(start_index..i); - } - Some(_) => i += 1, - } - } - } else { - Str::from("/") - }; - - let start_index = i; - let ack: Option = loop { - match chars.get(i) { - Some(c) if c.is_ascii_digit() => i += 1, - Some(b'[' | b'{') if i > start_index => break value[start_index..i].parse().ok(), - _ => break None, - } - }; - - let data = &value[i..]; - let inner = match index { - b'0' => PacketData::Connect((!data.is_empty()).then(|| data.to_string())), - b'1' => PacketData::Disconnect, - b'2' => { - let (event, payload) = deserialize_event_packet(data)?; - PacketData::Event(event.into(), payload, ack) - } - b'3' => { - let packet = deserialize_packet(data)?.ok_or(Error::InvalidPacketType)?; - PacketData::EventAck(packet, ack.ok_or(Error::InvalidPacketType)?) - } - b'5' => { - let (event, payload) = deserialize_event_packet(data)?; - PacketData::BinaryEvent(event.into(), BinaryPacket::incoming(payload), ack) - } - b'6' => { - let packet = deserialize_packet(data)?.ok_or(Error::InvalidPacketType)?; - PacketData::BinaryAck( - BinaryPacket::incoming(packet), - ack.ok_or(Error::InvalidPacketType)?, - ) - } - _ => return Err(Error::InvalidPacketType), - }; - - Ok(Self { inner, ns }) - } -} - -#[cfg(any(test, socketioxide_test))] -impl<'a> TryFrom for Packet<'a> { - type Error = Error; - - fn try_from(value: String) -> Result { - Packet::try_from(Str::from(value)) - } -} - /// Connect packet sent by the client #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ConnectPacket { - sid: Sid, -} - -#[cfg(test)] -mod test { - use serde_json::json; - - use super::*; - - #[test] - fn packet_decode_connect() { - let sid = Sid::new(); - let payload = format!("0{}", json!({ "sid": sid })); - let packet = Packet::try_from(payload).unwrap(); - - assert_eq!(Packet::connect("/", sid, ProtocolVersion::V5), packet); - - let payload = format!("0/admin™,{}", json!({ "sid": sid })); - let packet = Packet::try_from(payload).unwrap(); - - assert_eq!(Packet::connect("/admin™", sid, ProtocolVersion::V5), packet); - } - - #[test] - fn packet_encode_connect() { - let sid = Sid::new(); - let payload = format!("0{}", json!({ "sid": sid })); - let packet: String = Packet::connect("/", sid, ProtocolVersion::V5).into(); - assert_eq!(packet, payload); - - let payload = format!("0/admin™,{}", json!({ "sid": sid })); - let packet: String = Packet::connect("/admin™", sid, ProtocolVersion::V5).into(); - assert_eq!(packet, payload); - } - - // Disconnect, - - #[test] - fn packet_decode_disconnect() { - let payload = "1".to_string(); - let packet = Packet::try_from(payload).unwrap(); - assert_eq!(Packet::disconnect("/"), packet); - - let payload = "1/admin™,".to_string(); - let packet = Packet::try_from(payload).unwrap(); - assert_eq!(Packet::disconnect("/admin™"), packet); - } - - #[test] - fn packet_encode_disconnect() { - let payload = "1".to_string(); - let packet: String = Packet::disconnect("/").into(); - assert_eq!(packet, payload); - - let payload = "1/admin™,".to_string(); - let packet: String = Packet::disconnect("/admin™").into(); - assert_eq!(packet, payload); - } - - // Event(String, Value, Option), - #[test] - fn packet_decode_event() { - let payload = format!("2{}", json!(["event", { "data": "value" }])); - let packet = Packet::try_from(payload).unwrap(); - - assert_eq!( - Packet::event("/", "event", json!([{"data": "value"}])), - packet - ); - - // Check with ack ID - let payload = format!("21{}", json!(["event", { "data": "value" }])); - let packet = Packet::try_from(payload).unwrap(); - - let mut comparison_packet = Packet::event("/", "event", json!([{"data": "value"}])); - comparison_packet.inner.set_ack_id(1); - assert_eq!(packet, comparison_packet); - - // Check with NS - let payload = format!("2/admin™,{}", json!(["event", { "data": "value™" }])); - let packet = Packet::try_from(payload).unwrap(); - - assert_eq!( - Packet::event("/admin™", "event", json!([{"data": "value™"}])), - packet - ); - - // Check with ack ID and NS - let payload = format!("2/admin™,1{}", json!(["event", { "data": "value™" }])); - let mut packet = Packet::try_from(payload).unwrap(); - packet.inner.set_ack_id(1); - - let mut comparison_packet = Packet::event("/admin™", "event", json!([{"data": "value™"}])); - comparison_packet.inner.set_ack_id(1); - - assert_eq!(packet, comparison_packet); - } - - #[test] - fn packet_encode_event() { - let payload = format!("2{}", json!(["event", { "data": "value™" }])); - let packet: String = Packet::event("/", "event", json!({ "data": "value™" })).into(); - - assert_eq!(packet, payload); - - // Encode empty data - let payload = format!("2{}", json!(["event", []])); - let packet: String = Packet::event("/", "event", json!([])).into(); - - assert_eq!(packet, payload); - - // Encode with ack ID - let payload = format!("21{}", json!(["event", { "data": "value™" }])); - let mut packet = Packet::event("/", "event", json!({ "data": "value™" })); - packet.inner.set_ack_id(1); - let packet: String = packet.into(); - - assert_eq!(packet, payload); - - // Encode with NS - let payload = format!("2/admin™,{}", json!(["event", { "data": "value™" }])); - let packet: String = Packet::event("/admin™", "event", json!({"data": "value™"})).into(); - - assert_eq!(packet, payload); - - // Encode with NS and ack ID - let payload = format!("2/admin™,1{}", json!(["event", { "data": "value™" }])); - let mut packet = Packet::event("/admin™", "event", json!([{"data": "value™"}])); - packet.inner.set_ack_id(1); - let packet: String = packet.into(); - assert_eq!(packet, payload); - } - - // EventAck(Value, i64), - #[test] - fn packet_decode_event_ack() { - let payload = "354[\"data\"]".to_string(); - let packet = Packet::try_from(payload).unwrap(); - - assert_eq!(Packet::ack("/", json!(["data"]), 54), packet); - - let payload = "3/admin™,54[\"data\"]".to_string(); - let packet = Packet::try_from(payload).unwrap(); - - assert_eq!(Packet::ack("/admin™", json!(["data"]), 54), packet); - } - - #[test] - fn packet_encode_event_ack() { - let payload = "354[\"data\"]".to_string(); - let packet: String = Packet::ack("/", json!("data"), 54).into(); - assert_eq!(packet, payload); - - let payload = "3/admin™,54[\"data\"]".to_string(); - let packet: String = Packet::ack("/admin™", json!("data"), 54).into(); - assert_eq!(packet, payload); - } - - #[test] - fn packet_encode_connect_error() { - let payload = format!("4{}", json!({ "message": "Invalid namespace" })); - let packet: String = Packet::connect_error("/", "Invalid namespace").into(); - assert_eq!(packet, payload); - - let payload = format!("4/admin™,{}", json!({ "message": "Invalid namespace" })); - let packet: String = Packet::connect_error("/admin™", "Invalid namespace").into(); - assert_eq!(packet, payload); - } - - // BinaryEvent(String, BinaryPacket, Option), - #[test] - fn packet_encode_binary_event() { - let json = json!(["event", { "data": "value™" }, { "_placeholder": true, "num": 0}]); - - let payload = format!("51-{}", json); - let packet: String = Packet::bin_event( - "/", - "event", - json!({ "data": "value™" }), - vec![Bytes::from_static(&[1])], - ) - .into(); - - assert_eq!(packet, payload); - - // Encode with ack ID - let payload = format!("51-254{}", json); - let mut packet = Packet::bin_event( - "/", - "event", - json!({ "data": "value™" }), - vec![Bytes::from_static(&[1])], - ); - packet.inner.set_ack_id(254); - let packet: String = packet.into(); - - assert_eq!(packet, payload); - - // Encode with NS - let payload = format!("51-/admin™,{}", json); - let packet: String = Packet::bin_event( - "/admin™", - "event", - json!([{"data": "value™"}]), - vec![Bytes::from_static(&[1])], - ) - .into(); - - assert_eq!(packet, payload); - - // Encode with NS and ack ID - let payload = format!("51-/admin™,254{}", json); - let mut packet = Packet::bin_event( - "/admin™", - "event", - json!([{"data": "value™"}]), - vec![Bytes::from_static(&[1])], - ); - packet.inner.set_ack_id(254); - let packet: String = packet.into(); - assert_eq!(packet, payload); - } - - #[test] - fn packet_decode_binary_event() { - let json = json!(["event", { "data": "value™" }, { "_placeholder": true, "num": 0}]); - let comparison_packet = |ack, ns: &'static str| Packet { - inner: PacketData::BinaryEvent( - "event".into(), - BinaryPacket { - bin: vec![Bytes::from_static(&[1])], - data: json!([{"data": "value™"}]), - payload_count: 1, - }, - ack, - ), - ns: ns.into(), - }; - - let payload = format!("51-{}", json); - let mut packet = Packet::try_from(payload).unwrap(); - if let PacketData::BinaryEvent(_, ref mut bin, _) = packet.inner { - bin.add_payload(Bytes::from_static(&[1])); - } - - assert_eq!(packet, comparison_packet(None, "/")); - - // Check with ack ID - let payload = format!("51-254{}", json); - let mut packet = Packet::try_from(payload).unwrap(); - if let PacketData::BinaryEvent(_, ref mut bin, _) = packet.inner { - bin.add_payload(Bytes::from_static(&[1])); - } - - assert_eq!(packet, comparison_packet(Some(254), "/")); - - // Check with NS - let payload = format!("51-/admin™,{}", json); - let mut packet = Packet::try_from(payload).unwrap(); - if let PacketData::BinaryEvent(_, ref mut bin, _) = packet.inner { - bin.add_payload(Bytes::from_static(&[1])); - } - - assert_eq!(packet, comparison_packet(None, "/admin™")); - - // Check with ack ID and NS - let payload = format!("51-/admin™,254{}", json); - let mut packet = Packet::try_from(payload).unwrap(); - if let PacketData::BinaryEvent(_, ref mut bin, _) = packet.inner { - bin.add_payload(Bytes::from_static(&[1])); - } - - assert_eq!(packet, comparison_packet(Some(254), "/admin™")); - } - - // BinaryAck(BinaryPacket, i64), - #[test] - fn packet_encode_binary_ack() { - let json = json!([{ "data": "value™" }, { "_placeholder": true, "num": 0}]); - - let payload = format!("61-54{}", json); - let packet: String = Packet::bin_ack( - "/", - json!({ "data": "value™" }), - vec![Bytes::from_static(&[1])], - 54, - ) - .into(); - - assert_eq!(packet, payload); - - // Encode with NS - let payload = format!("61-/admin™,54{}", json); - let packet: String = Packet::bin_ack( - "/admin™", - json!({ "data": "value™" }), - vec![Bytes::from_static(&[1])], - 54, - ) - .into(); - - assert_eq!(packet, payload); - } - - #[test] - fn packet_decode_binary_ack() { - let json = json!([{ "data": "value™" }, { "_placeholder": true, "num": 0}]); - let comparison_packet = |ack, ns: &'static str| Packet { - inner: PacketData::BinaryAck( - BinaryPacket { - bin: vec![Bytes::from_static(&[1])], - data: json!([{"data": "value™"}]), - payload_count: 1, - }, - ack, - ), - ns: ns.into(), - }; - - let payload = format!("61-54{}", json); - let mut packet = Packet::try_from(payload).unwrap(); - if let PacketData::BinaryAck(ref mut bin, _) = packet.inner { - bin.add_payload(Bytes::from_static(&[1])); - } - - assert_eq!(packet, comparison_packet(54, "/")); - - // Check with NS - let payload = format!("61-/admin™,54{}", json); - let mut packet = Packet::try_from(payload).unwrap(); - if let PacketData::BinaryAck(ref mut bin, _) = packet.inner { - bin.add_payload(Bytes::from_static(&[1])); - } - - assert_eq!(packet, comparison_packet(54, "/admin™")); - } - - #[test] - fn packet_size_hint() { - let sid = Sid::new(); - let len = serde_json::to_string(&ConnectPacket { sid }).unwrap().len(); - let packet = Packet::connect("/", sid, ProtocolVersion::V5); - assert_eq!(packet.get_size_hint(), len + 1); - - let packet = Packet::connect("/admin", sid, ProtocolVersion::V5); - assert_eq!(packet.get_size_hint(), len + 8); - - let packet = Packet::connect("admin", sid, ProtocolVersion::V4); - assert_eq!(packet.get_size_hint(), 8); - - let packet = Packet::disconnect("/"); - assert_eq!(packet.get_size_hint(), 1); - - let packet = Packet::disconnect("/admin"); - assert_eq!(packet.get_size_hint(), 8); - - let packet = Packet::event("/", "event", json!({ "data": "value™" })); - assert_eq!(packet.get_size_hint(), 1); - - let packet = Packet::event("/admin", "event", json!({ "data": "value™" })); - assert_eq!(packet.get_size_hint(), 8); - - let packet = Packet::ack("/", json!("data"), 54); - assert_eq!(packet.get_size_hint(), 3); - - let packet = Packet::ack("/admin", json!("data"), 54); - assert_eq!(packet.get_size_hint(), 10); - - let packet = Packet::bin_event( - "/", - "event", - json!({ "data": "value™" }), - vec![Bytes::from_static(&[1])], - ); - assert_eq!(packet.get_size_hint(), 3); - - let packet = Packet::bin_event( - "/admin", - "event", - json!({ "data": "value™" }), - vec![Bytes::from_static(&[1])], - ); - assert_eq!(packet.get_size_hint(), 10); - - let packet = Packet::bin_ack("/", json!("data"), vec![Bytes::from_static(&[1])], 54); - assert_eq!(packet.get_size_hint(), 5); - } - - #[test] - fn packet_reject_invalid_binary_event() { - let payload = "5invalid".to_owned(); - let err = Packet::try_from(payload).unwrap_err(); - - assert!(matches!(err, Error::InvalidPacketType)); - } + pub(crate) sid: Sid, } diff --git a/socketioxide/src/parser/common.rs b/socketioxide/src/parser/common.rs new file mode 100644 index 00000000..d14abc9d --- /dev/null +++ b/socketioxide/src/parser/common.rs @@ -0,0 +1,741 @@ +use std::sync::Mutex; + +use bytes::Bytes; +use engineioxide::Str; +use serde_json::Value; + +use crate::{ + packet::{BinaryPacket, Packet, PacketData}, + parser::{Error, TransportPayload}, +}; + +/// Parse and serialize from and into the socket.io common packet format. +/// +/// The resulting string should be in the format of: +/// ```text +/// [<# of binary attachments>-][,][][JSON-stringified payload without binary] +/// + binary attachments extracted +/// ``` +#[derive(Debug, Default)] +pub struct CommonParser { + /// Partial binary packet that is being received + /// Stored here until all the binary payloads are received + pub partial_bin_packet: Mutex>>, +} + +impl super::Parse for CommonParser { + fn encode(&self, mut packet: Packet<'_>) -> (TransportPayload, Vec) { + use PacketData::*; + + // Serialize the data if there is any + // pre-serializing allows to preallocate the buffer + let data = match &mut packet.inner { + Connect(Some(data)) => Some(serde_json::to_string(data).unwrap()), + Event(e, data, _) | BinaryEvent(e, BinaryPacket { data, .. }, _) => { + // Expand the packet if it is an array with data -> ["event", ...data] + let packet = match data { + Value::Array(ref mut v) if !v.is_empty() => { + v.insert(0, Value::String((*e).to_string())); + serde_json::to_string(&v) + } + Value::Array(_) => serde_json::to_string::<(_, [(); 0])>(&(e, [])), + _ => serde_json::to_string(&(e, data)), + } + .unwrap(); + Some(packet) + } + EventAck(data, _) | BinaryAck(BinaryPacket { data, .. }, _) => { + // Enforce that the packet is an array -> [data] + let packet = match data { + Value::Array(_) => serde_json::to_string(&data), + Value::Null => Ok("[]".to_string()), + _ => serde_json::to_string(&[data]), + } + .unwrap(); + Some(packet) + } + _ => None, + }; + + let capacity = get_size_hint(&packet) + data.as_ref().map(|d| d.len()).unwrap_or(0); + let mut res = String::with_capacity(capacity); + res.push(char::from_digit(packet.inner.index() as u32, 10).unwrap()); + + // Add the ns if it is not the default one and the packet is not binary + // In case of bin packet, we should first add the payload count before ns + let push_nsp = |res: &mut String| { + if !packet.ns.is_empty() && packet.ns != "/" { + if !packet.ns.starts_with('/') { + res.push('/'); + } + res.push_str(&packet.ns); + res.push(','); + } + }; + + if !packet.inner.is_binary() { + push_nsp(&mut res); + } + + let mut itoa_buf = itoa::Buffer::new(); + + match &packet.inner { + PacketData::Connect(Some(_)) => res.push_str(&data.unwrap()), + PacketData::Disconnect | PacketData::Connect(None) => (), + PacketData::Event(_, _, ack) => { + if let Some(ack) = *ack { + res.push_str(itoa_buf.format(ack)); + } + + res.push_str(&data.unwrap()) + } + PacketData::EventAck(_, ack) => { + res.push_str(itoa_buf.format(*ack)); + res.push_str(&data.unwrap()) + } + PacketData::ConnectError(data) => res.push_str(data), + PacketData::BinaryEvent(_, ref bin, ack) => { + res.push_str(itoa_buf.format(bin.payload_count)); + res.push('-'); + + push_nsp(&mut res); + + if let Some(ack) = *ack { + res.push_str(itoa_buf.format(ack)); + } + + res.push_str(&data.unwrap()) + } + PacketData::BinaryAck(ref packet, ack) => { + res.push_str(itoa_buf.format(packet.payload_count)); + res.push('-'); + + push_nsp(&mut res); + + res.push_str(itoa_buf.format(*ack)); + res.push_str(&data.unwrap()) + } + }; + + let bins = match packet.inner { + PacketData::BinaryEvent(_, bin, _) | PacketData::BinaryAck(bin, _) => { + Vec::from_iter(bin.bin.into_iter().map(Bytes::from)) + } + _ => Vec::new(), + }; + (TransportPayload::Str(res.into()), bins) + } + + fn decode_str(&self, value: Str) -> Result, Error> { + let chars = value.as_bytes(); + // It is possible to parse the packet from a byte slice because separators are only ASCII + let index = (b'0'..=b'6') + .contains(&chars[0]) + .then_some(chars[0]) + .ok_or(Error::InvalidPacketType)?; + + // Move the cursor to skip the payload count if it is a binary packet + let mut i = if index == b'5' || index == b'6' { + chars + .iter() + .position(|x| *x == b'-') + .ok_or(Error::InvalidPacketType)? + + 1 + } else { + 1 + }; + + let start_index = i; + // Custom nsps will start with a slash + let ns = if chars.get(i) == Some(&b'/') { + loop { + match chars.get(i) { + Some(b',') => { + i += 1; + break value.slice(start_index..i - 1); + } + // It maybe possible depending on clients that ns does not end with a comma + // if it is the end of the packet + // e.g `1/custom` + None => { + break value.slice(start_index..i); + } + Some(_) => i += 1, + } + } + } else { + Str::from("/") + }; + + let start_index = i; + let ack: Option = loop { + match chars.get(i) { + Some(c) if c.is_ascii_digit() => i += 1, + Some(b'[' | b'{') if i > start_index => break value[start_index..i].parse().ok(), + _ => break None, + } + }; + + let data = &value[i..]; + let inner = match index { + b'0' => { + if data.is_empty() { + PacketData::Connect(None) + } else { + PacketData::Connect(serde_json::from_str(data)?) + } + } + b'1' => PacketData::Disconnect, + b'2' => { + let (event, payload) = deserialize_event_packet(data)?; + PacketData::Event(event.into(), payload, ack) + } + b'3' => { + let packet = deserialize_packet(data)?.ok_or(Error::InvalidPacketType)?; + PacketData::EventAck(packet, ack.ok_or(Error::InvalidPacketType)?) + } + b'5' => { + let (event, payload) = deserialize_event_packet(data)?; + PacketData::BinaryEvent(event.into(), BinaryPacket::incoming(payload), ack) + } + b'6' => { + let packet = deserialize_packet(data)?.ok_or(Error::InvalidPacketType)?; + PacketData::BinaryAck( + BinaryPacket::incoming(packet), + ack.ok_or(Error::InvalidPacketType)?, + ) + } + _ => return Err(Error::InvalidPacketType), + }; + + if inner.is_binary() && !inner.is_complete() { + *self.partial_bin_packet.lock().unwrap() = Some(Packet { inner, ns }); + Err(Error::NeedsMoreBinaryData) + } else { + Ok(Packet { inner, ns }) + } + } + + fn decode_bin(&self, data: Bytes) -> Result, Error> { + #[cfg(feature = "tracing")] + tracing::debug!("[sid=] applying payload on packet"); // TODO: log sid + let packet = &mut *self.partial_bin_packet.lock().unwrap(); + match packet { + Some(Packet { + inner: + PacketData::BinaryEvent(_, ref mut bin, _) | PacketData::BinaryAck(ref mut bin, _), + .. + }) => { + bin.add_payload(data); + if !bin.is_complete() { + Err(Error::NeedsMoreBinaryData) + } else { + Ok(packet.take().unwrap()) + } + } + _ => Err(Error::UnexpectedBinaryPacket), + } + } +} + +impl CommonParser { + /// Create a new [`CommonParser`]. This is the default socket.io packet parser. + pub fn new() -> Self { + Self::default() + } +} + +/// Get the max size the packet could have when serialized +/// This is used to pre-allocate a buffer for the packet +/// +/// #### Disclaimer: The size does not include serialized `Value` size +fn get_size_hint(packet: &Packet<'_>) -> usize { + use PacketData::*; + const PACKET_INDEX_SIZE: usize = 1; + const BINARY_PUNCTUATION_SIZE: usize = 2; + const ACK_PUNCTUATION_SIZE: usize = 1; + const NS_PUNCTUATION_SIZE: usize = 1; + + let data_size = match &packet.inner { + Connect(Some(_)) => 0, + Connect(None) => 0, + Disconnect => 0, + Event(_, _, Some(ack)) => ack.checked_ilog10().unwrap_or(0) as usize + ACK_PUNCTUATION_SIZE, + Event(_, _, None) => 0, + BinaryEvent(_, bin, None) => { + bin.payload_count.checked_ilog10().unwrap_or(0) as usize + BINARY_PUNCTUATION_SIZE + } + BinaryEvent(_, bin, Some(ack)) => { + ack.checked_ilog10().unwrap_or(0) as usize + + bin.payload_count.checked_ilog10().unwrap_or(0) as usize + + ACK_PUNCTUATION_SIZE + + BINARY_PUNCTUATION_SIZE + } + EventAck(_, ack) => ack.checked_ilog10().unwrap_or(0) as usize + ACK_PUNCTUATION_SIZE, + BinaryAck(bin, ack) => { + ack.checked_ilog10().unwrap_or(0) as usize + + bin.payload_count.checked_ilog10().unwrap_or(0) as usize + + ACK_PUNCTUATION_SIZE + + BINARY_PUNCTUATION_SIZE + } + ConnectError(data) => data.len(), + }; + + let nsp_size = if packet.ns == "/" { + 0 + } else if packet.ns.starts_with('/') { + packet.ns.len() + NS_PUNCTUATION_SIZE + } else { + packet.ns.len() + NS_PUNCTUATION_SIZE + 1 // (1 for the leading slash) + }; + data_size + nsp_size + PACKET_INDEX_SIZE +} + +/// Deserialize an event packet from a string, formated as: +/// ```text +/// ["", ...] +/// ``` +fn deserialize_event_packet(data: &str) -> Result<(String, Value), Error> { + #[cfg(feature = "tracing")] + tracing::debug!("Deserializing event packet: {:?}", data); + let packet = match serde_json::from_str::(data)? { + Value::Array(packet) => packet, + _ => return Err(Error::InvalidEventName), + }; + + let event = packet + .first() + .ok_or(Error::InvalidEventName)? + .as_str() + .ok_or(Error::InvalidEventName)? + .to_string(); + let payload = Value::from_iter(packet.into_iter().skip(1)); + Ok((event, payload)) +} + +fn deserialize_packet( + data: &str, +) -> Result, serde_json::Error> { + #[cfg(feature = "tracing")] + tracing::debug!("Deserializing packet: {:?}", data); + let packet = if data.is_empty() { + None + } else { + Some(serde_json::from_str(data)?) + }; + Ok(packet) +} + +#[cfg(test)] +mod test { + use engineioxide::sid::Sid; + use serde_json::json; + + use crate::{parser::Parse, ProtocolVersion}; + + use super::*; + + fn encode(packet: Packet<'_>) -> String { + match CommonParser::default().encode(packet).0 { + TransportPayload::Str(d) => d.into(), + TransportPayload::Bytes(_) => panic!("testing only returns str"), + } + } + fn decode(value: String) -> Packet<'static> { + CommonParser::default().decode_str(value.into()).unwrap() + } + + #[test] + fn packet_decode_connect() { + let sid = Sid::new(); + let payload = format!("0{}", json!({ "sid": sid })); + let packet = decode(payload); + + assert_eq!(Packet::connect("/", sid, ProtocolVersion::V5), packet); + + let payload = format!("0/admin™,{}", json!({ "sid": sid })); + let packet = decode(payload); + + assert_eq!(Packet::connect("/admin™", sid, ProtocolVersion::V5), packet); + } + + #[test] + fn packet_encode_connect() { + let sid = Sid::new(); + let payload = format!("0{}", json!({ "sid": sid })); + let packet = encode(Packet::connect("/", sid, ProtocolVersion::V5)); + assert_eq!(packet, payload); + + let payload = format!("0/admin™,{}", json!({ "sid": sid })); + let packet: String = encode(Packet::connect("/admin™", sid, ProtocolVersion::V5)); + assert_eq!(packet, payload); + } + + // Disconnect, + + #[test] + fn packet_decode_disconnect() { + let payload = "1".to_string(); + let packet = decode(payload); + assert_eq!(Packet::disconnect("/"), packet); + + let payload = "1/admin™,".to_string(); + let packet = decode(payload); + assert_eq!(Packet::disconnect("/admin™"), packet); + } + + #[test] + fn packet_encode_disconnect() { + let payload = "1".to_string(); + let packet = encode(Packet::disconnect("/")); + assert_eq!(packet, payload); + + let payload = "1/admin™,".to_string(); + let packet = encode(Packet::disconnect("/admin™")); + assert_eq!(packet, payload); + } + + // Event(String, Value, Option), + #[test] + fn packet_decode_event() { + let payload = format!("2{}", json!(["event", { "data": "value" }])); + let packet = decode(payload); + + assert_eq!( + Packet::event("/", "event", json!([{"data": "value"}])), + packet + ); + + // Check with ack ID + let payload = format!("21{}", json!(["event", { "data": "value" }])); + let packet = decode(payload); + + let mut comparison_packet = Packet::event("/", "event", json!([{"data": "value"}])); + comparison_packet.inner.set_ack_id(1); + assert_eq!(packet, comparison_packet); + + // Check with NS + let payload = format!("2/admin™,{}", json!(["event", { "data": "value™" }])); + let packet = decode(payload); + + assert_eq!( + Packet::event("/admin™", "event", json!([{"data": "value™"}])), + packet + ); + + // Check with ack ID and NS + let payload = format!("2/admin™,1{}", json!(["event", { "data": "value™" }])); + let mut packet = decode(payload); + packet.inner.set_ack_id(1); + + let mut comparison_packet = Packet::event("/admin™", "event", json!([{"data": "value™"}])); + comparison_packet.inner.set_ack_id(1); + + assert_eq!(packet, comparison_packet); + } + + #[test] + fn packet_encode_event() { + let payload = format!("2{}", json!(["event", { "data": "value™" }])); + let packet = encode(Packet::event("/", "event", json!({ "data": "value™" }))); + + assert_eq!(packet, payload); + + // Encode empty data + let payload = format!("2{}", json!(["event", []])); + let packet = encode(Packet::event("/", "event", json!([]))); + + assert_eq!(packet, payload); + + // Encode with ack ID + let payload = format!("21{}", json!(["event", { "data": "value™" }])); + let mut packet = Packet::event("/", "event", json!({ "data": "value™" })); + packet.inner.set_ack_id(1); + let packet = encode(packet); + + assert_eq!(packet, payload); + + // Encode with NS + let payload = format!("2/admin™,{}", json!(["event", { "data": "value™" }])); + let packet = encode(Packet::event("/admin™", "event", json!({"data": "value™"}))); + + assert_eq!(packet, payload); + + // Encode with NS and ack ID + let payload = format!("2/admin™,1{}", json!(["event", { "data": "value™" }])); + let mut packet = Packet::event("/admin™", "event", json!([{"data": "value™"}])); + packet.inner.set_ack_id(1); + let packet = encode(packet); + assert_eq!(packet, payload); + } + + // EventAck(Value, i64), + #[test] + fn packet_decode_event_ack() { + let payload = "354[\"data\"]".to_string(); + let packet = decode(payload); + + assert_eq!(Packet::ack("/", json!(["data"]), 54), packet); + + let payload = "3/admin™,54[\"data\"]".to_string(); + let packet = decode(payload); + + assert_eq!(Packet::ack("/admin™", json!(["data"]), 54), packet); + } + + #[test] + fn packet_encode_event_ack() { + let payload = "354[\"data\"]".to_string(); + let packet = encode(Packet::ack("/", json!("data"), 54)); + assert_eq!(packet, payload); + + let payload = "3/admin™,54[\"data\"]".to_string(); + let packet = encode(Packet::ack("/admin™", json!("data"), 54)); + assert_eq!(packet, payload); + } + + #[test] + fn packet_encode_connect_error() { + let payload = format!("4{}", json!({ "message": "Invalid namespace" })); + let packet = encode(Packet::connect_error("/", "Invalid namespace")); + assert_eq!(packet, payload); + + let payload = format!("4/admin™,{}", json!({ "message": "Invalid namespace" })); + let packet = encode(Packet::connect_error("/admin™", "Invalid namespace")); + assert_eq!(packet, payload); + } + + // BinaryEvent(String, BinaryPacket, Option), + #[test] + fn packet_encode_binary_event() { + let json = json!(["event", { "data": "value™" }, { "_placeholder": true, "num": 0}]); + + let payload = format!("51-{}", json); + let packet = encode(Packet::bin_event( + "/", + "event", + json!({ "data": "value™" }), + vec![Bytes::from_static(&[1])], + )); + + assert_eq!(packet, payload); + + // Encode with ack ID + let payload = format!("51-254{}", json); + let mut packet = Packet::bin_event( + "/", + "event", + json!({ "data": "value™" }), + vec![Bytes::from_static(&[1])], + ); + packet.inner.set_ack_id(254); + let packet = encode(packet); + + assert_eq!(packet, payload); + + // Encode with NS + let payload = format!("51-/admin™,{}", json); + let packet = encode(Packet::bin_event( + "/admin™", + "event", + json!([{"data": "value™"}]), + vec![Bytes::from_static(&[1])], + )); + + assert_eq!(packet, payload); + + // Encode with NS and ack ID + let payload = format!("51-/admin™,254{}", json); + let mut packet = Packet::bin_event( + "/admin™", + "event", + json!([{"data": "value™"}]), + vec![Bytes::from_static(&[1])], + ); + packet.inner.set_ack_id(254); + let packet = encode(packet); + assert_eq!(packet, payload); + } + + #[test] + fn packet_decode_binary_event() { + let json = json!(["event", { "data": "value™" }, { "_placeholder": true, "num": 0}]); + let comparison_packet = |ack, ns: &'static str| Packet { + inner: PacketData::BinaryEvent( + "event".into(), + BinaryPacket { + bin: vec![Bytes::from_static(&[1])], + data: json!([{"data": "value™"}]), + payload_count: 1, + }, + ack, + ), + ns: ns.into(), + }; + let parser = CommonParser::default(); + let payload = format!("51-{}", json); + assert!(matches!( + parser.decode_str(payload.into()), + Err(Error::NeedsMoreBinaryData) + )); + let packet = parser.decode_bin(Bytes::from_static(&[1])).unwrap(); + + assert_eq!(packet, comparison_packet(None, "/")); + + // Check with ack ID + let parser = CommonParser::default(); + let payload = format!("51-254{}", json); + assert!(matches!( + parser.decode_str(payload.into()), + Err(Error::NeedsMoreBinaryData) + )); + let packet = parser.decode_bin(Bytes::from_static(&[1])).unwrap(); + + assert_eq!(packet, comparison_packet(Some(254), "/")); + + // Check with NS + let parser = CommonParser::default(); + let payload = format!("51-/admin™,{}", json); + assert!(matches!( + parser.decode_str(payload.into()), + Err(Error::NeedsMoreBinaryData) + )); + let packet = parser.decode_bin(Bytes::from_static(&[1])).unwrap(); + + assert_eq!(packet, comparison_packet(None, "/admin™")); + + // Check with ack ID and NS + let parser = CommonParser::default(); + let payload = format!("51-/admin™,254{}", json); + assert!(matches!( + parser.decode_str(payload.into()), + Err(Error::NeedsMoreBinaryData) + )); + let packet = parser.decode_bin(Bytes::from_static(&[1])).unwrap(); + + assert_eq!(packet, comparison_packet(Some(254), "/admin™")); + } + + // BinaryAck(BinaryPacket, i64), + #[test] + fn packet_encode_binary_ack() { + let json = json!([{ "data": "value™" }, { "_placeholder": true, "num": 0}]); + + let payload = format!("61-54{}", json); + let packet = encode(Packet::bin_ack( + "/", + json!({ "data": "value™" }), + vec![Bytes::from_static(&[1])], + 54, + )); + + assert_eq!(packet, payload); + + // Encode with NS + let payload = format!("61-/admin™,54{}", json); + let packet = encode(Packet::bin_ack( + "/admin™", + json!({ "data": "value™" }), + vec![Bytes::from_static(&[1])], + 54, + )); + + assert_eq!(packet, payload); + } + + #[test] + fn packet_decode_binary_ack() { + let json = json!([{ "data": "value™" }, { "_placeholder": true, "num": 0}]); + let comparison_packet = |ack, ns: &'static str| Packet { + inner: PacketData::BinaryAck( + BinaryPacket { + bin: vec![Bytes::from_static(&[1])], + data: json!([{"data": "value™"}]), + payload_count: 1, + }, + ack, + ), + ns: ns.into(), + }; + + let payload = format!("61-54{}", json); + let parser = CommonParser::default(); + assert!(matches!( + parser.decode_str(payload.into()), + Err(Error::NeedsMoreBinaryData) + )); + let packet = parser.decode_bin(Bytes::from_static(&[1])).unwrap(); + + assert_eq!(packet, comparison_packet(54, "/")); + + // Check with NS + let parser = CommonParser::default(); + let payload = format!("61-/admin™,54{}", json); + assert!(matches!( + parser.decode_str(payload.into()), + Err(Error::NeedsMoreBinaryData) + )); + let packet = parser.decode_bin(Bytes::from_static(&[1])).unwrap(); + assert_eq!(packet, comparison_packet(54, "/admin™")); + } + + #[test] + fn packet_size_hint() { + let sid = Sid::new(); + let packet = Packet::connect("/", sid, ProtocolVersion::V5); + assert_eq!(get_size_hint(&packet), 1); + + let packet = Packet::connect("/admin", sid, ProtocolVersion::V5); + assert_eq!(get_size_hint(&packet), 8); + + let packet = Packet::connect("admin", sid, ProtocolVersion::V4); + assert_eq!(get_size_hint(&packet), 8); + + let packet = Packet::disconnect("/"); + assert_eq!(get_size_hint(&packet), 1); + + let packet = Packet::disconnect("/admin"); + assert_eq!(get_size_hint(&packet), 8); + + let packet = Packet::event("/", "event", json!({ "data": "value™" })); + assert_eq!(get_size_hint(&packet), 1); + + let packet = Packet::event("/admin", "event", json!({ "data": "value™" })); + assert_eq!(get_size_hint(&packet), 8); + + let packet = Packet::ack("/", json!("data"), 54); + assert_eq!(get_size_hint(&packet), 3); + + let packet = Packet::ack("/admin", json!("data"), 54); + assert_eq!(get_size_hint(&packet), 10); + + let packet = Packet::bin_event( + "/", + "event", + json!({ "data": "value™" }), + vec![Bytes::from_static(&[1])], + ); + assert_eq!(get_size_hint(&packet), 3); + + let packet = Packet::bin_event( + "/admin", + "event", + json!({ "data": "value™" }), + vec![Bytes::from_static(&[1])], + ); + assert_eq!(get_size_hint(&packet), 10); + + let packet = Packet::bin_ack("/", json!("data"), vec![Bytes::from_static(&[1])], 54); + assert_eq!(get_size_hint(&packet), 5); + } + + #[test] + fn packet_reject_invalid_binary_event() { + let payload = "5invalid".to_owned(); + let err = CommonParser::default() + .decode_str(payload.into()) + .unwrap_err(); + + assert!(matches!(err, Error::InvalidPacketType)); + } +} diff --git a/socketioxide/src/parser/mod.rs b/socketioxide/src/parser/mod.rs new file mode 100644 index 00000000..c4ae3300 --- /dev/null +++ b/socketioxide/src/parser/mod.rs @@ -0,0 +1,132 @@ +//! Contains all the parser implementations for the socket.io protocol. +//! +//! The default parser is the [`CommonParser`] +use bytes::Bytes; + +mod common; + +pub use common::CommonParser; +use engineioxide::Str; + +use crate::packet::Packet; + +/// Represent a socket.io payload that can be sent over an engine.io connection +pub enum TransportPayload { + /// A string payload that will be sent as a string engine.io packet + Str(engineioxide::Str), + /// A binary payload that will be sent as a binary engine.io packet + Bytes(bytes::Bytes), +} +impl TransportPayload { + /// If the payload is a [`TransportPayload::Str`] or returns it + /// or None otherwise. + pub fn into_str(self) -> Option { + match self { + TransportPayload::Str(str) => Some(str), + TransportPayload::Bytes(_) => None, + } + } + + /// If the payload is a [`TransportPayload::Bytes`] or returns it + /// or None otherwise. + pub fn into_bytes(self) -> Option { + match self { + TransportPayload::Str(_) => None, + TransportPayload::Bytes(bytes) => Some(bytes), + } + } +} + +/// All socket.io parser should implement this trait +pub trait Parse: Default { + /// Convert a packet into multiple payloads to be sent + fn encode(&self, packet: Packet<'_>) -> (TransportPayload, Vec); + + /// Parse a given input string. If the payload needs more adjacent binary packet, + /// the partial packet will be kept and a [`Error::NeedsMoreBinaryData`] will be returned + fn decode_str(&self, data: Str) -> Result, Error>; + + /// Parse a given input binary. + fn decode_bin(&self, bin: Bytes) -> Result, Error>; +} + +/// All the parser available. +/// It also implements the [`Parse`] trait and therefore the +/// parser implementation is done over enum dispatch. +#[non_exhaustive] +#[derive(Debug)] +pub enum Parser { + /// The default parser + Common(CommonParser), +} +impl Default for Parser { + fn default() -> Self { + Parser::Common(CommonParser::default()) + } +} + +/// Recreate a new parser of the same type. +impl Clone for Parser { + fn clone(&self) -> Self { + match self { + Parser::Common(_) => Parser::Common(CommonParser::default()), + } + } +} + +impl Parse for Parser { + fn encode(&self, packet: Packet<'_>) -> (TransportPayload, Vec) { + match self { + Parser::Common(p) => p.encode(packet), + } + } + + fn decode_bin(&self, bin: Bytes) -> Result, Error> { + match self { + Parser::Common(p) => p.decode_bin(bin), + } + } + fn decode_str(&self, data: Str) -> Result, Error> { + match self { + Parser::Common(p) => p.decode_str(data), + } + } +} + +/// Errors when parsing/serializing socket.io packets +#[derive(thiserror::Error, Debug)] +pub enum Error { + /// Invalid packet type + #[error("invalid packet type")] + InvalidPacketType, + + /// Invalid event name + #[error("invalid event name")] + InvalidEventName, + + /// Invalid namespace + #[error("invalid namespace")] + InvalidNamespace, + + /// Received unexpected binary data + #[error( + "received unexpected binary data. Make sure you are using the same parser on both ends." + )] + UnexpectedBinaryPacket, + + /// Received unexpected string data + #[error( + "received unexpected string data. Make sure you are using the same parser on both ends." + )] + UnexpectedStringPacket, + + /// Needs more binary data before deserialization. It is not exactly an error, it is used for control flow, + /// e.g the common parser needs adjacent binary packets and therefore will returns [`NeedsMoreBinaryData`] n times for n adjacent binary packet expected. + /// In this case the user should call again the parser with the next binary payload. + #[error("needs more binary data before deserialization")] + NeedsMoreBinaryData, + + /// Error serializing json packet + #[error("error serializing json packet: {0:?}")] + Serialize(#[from] serde_json::Error), +} diff --git a/socketioxide/src/socket.rs b/socketioxide/src/socket.rs index a4167b00..9e168665 100644 --- a/socketioxide/src/socket.rs +++ b/socketioxide/src/socket.rs @@ -32,6 +32,7 @@ use crate::{ ns::Namespace, operators::{BroadcastOperators, ConfOperators, RoomParam}, packet::{BinaryPacket, Packet, PacketData}, + parser::{Parse, Parser, TransportPayload}, AckError, SocketIo, }; use crate::{ @@ -103,23 +104,16 @@ impl From for DisconnectReason { } pub(crate) trait PermitExt<'a> { - fn send(self, packet: Packet<'_>); + fn send(self, packet: Packet<'_>, parser: &Parser); } impl<'a> PermitExt<'a> for Permit<'a> { - fn send(self, mut packet: Packet<'_>) { - let bin_payloads = match packet.inner { - PacketData::BinaryEvent(_, ref mut bin, _) | PacketData::BinaryAck(ref mut bin, _) => { - Some(std::mem::take(&mut bin.bin)) - } - _ => None, - }; - - let msg = packet.into(); - - if let Some(bin_payloads) = bin_payloads { - self.emit_many(msg, bin_payloads); - } else { - self.emit(msg); + fn send(self, packet: Packet<'_>, parser: &Parser) { + let (msg, bin_payloads) = parser.encode(packet); + match msg { + TransportPayload::Str(msg) if bin_payloads.is_empty() => self.emit(msg), + TransportPayload::Str(msg) => self.emit_many(msg, bin_payloads), + TransportPayload::Bytes(bin) if bin_payloads.is_empty() => self.emit_binary(bin), + TransportPayload::Bytes(bin) => self.emit_many_binary(bin, bin_payloads), } } } @@ -321,7 +315,7 @@ impl Socket { let ns = self.ns.path.clone(); let data = serde_json::to_value(data)?; - permit.send(Packet::event(ns, event.into(), data)); + permit.send(Packet::event(ns, event.into(), data), self.parser()); Ok(()) } @@ -655,13 +649,18 @@ impl Socket { &self.ns.path } + #[inline] + pub(crate) fn parser(&self) -> &Parser { + self.esocket.data.parser() + } + pub(crate) fn reserve(&self) -> Result, SocketError<()>> { Ok(self.esocket.reserve()?) } pub(crate) fn send(&self, packet: Packet<'_>) -> Result<(), SocketError<()>> { let permit = self.reserve()?; - permit.send(packet); + permit.send(packet, self.parser()); Ok(()) } @@ -674,7 +673,7 @@ impl Socket { let ack = self.ack_counter.fetch_add(1, Ordering::SeqCst) + 1; packet.inner.set_ack_id(ack); - permit.send(packet); + permit.send(packet, self.parser()); self.ack_message.lock().unwrap().insert(ack, tx); rx } @@ -821,7 +820,8 @@ impl PartialEq for Socket { } } -#[cfg(any(test, socketioxide_test))] +#[doc(hidden)] +#[cfg(feature = "__test_harness")] impl Socket { /// Creates a dummy socket for testing purposes pub fn new_dummy(sid: Sid, ns: Arc>) -> Socket { @@ -836,6 +836,7 @@ impl Socket { ))); let s = Socket::new(sid, ns, engineioxide::Socket::new_dummy(sid, close_fn)); s.esocket.data.io.set(io).unwrap(); + s.esocket.data.parser.set(Parser::default()).unwrap(); s.set_connected(true); s } diff --git a/socketioxide/tests/acknowledgement.rs b/socketioxide/tests/acknowledgement.rs index 18d01f2e..71ea718d 100644 --- a/socketioxide/tests/acknowledgement.rs +++ b/socketioxide/tests/acknowledgement.rs @@ -4,7 +4,8 @@ mod utils; use engineioxide::Packet::*; use futures_util::StreamExt; use socketioxide::extract::SocketRef; -use socketioxide::packet::{Packet, PacketData}; +use socketioxide::packet::PacketData; +use socketioxide::parser::{CommonParser, Parse}; use socketioxide::SocketIo; use tokio::sync::mpsc; use tokio::time::Duration; @@ -95,13 +96,13 @@ pub async fn broadcast_with_ack() { let (stx, mut srx) = io.new_dummy_sock("/", ()).await; assert_some!(srx.recv().await); assert_some!(srx.recv().await); - + let parser = CommonParser::default(); while let Some(msg) = srx.recv().await { let msg = match msg { Message(msg) => msg, msg => panic!("Unexpected message: {:?}", msg), }; - let ack = match assert_ok!(Packet::try_from(msg)).inner { + let ack = match assert_ok!(parser.decode_str(msg)).inner { PacketData::Event(_, _, Some(ack)) => ack, _ => panic!("Unexpected packet"), }; diff --git a/socketioxide/tests/connect.rs b/socketioxide/tests/connect.rs index 7fb5050b..b9709560 100644 --- a/socketioxide/tests/connect.rs +++ b/socketioxide/tests/connect.rs @@ -3,7 +3,11 @@ mod utils; use bytes::Bytes; use engineioxide::Packet::*; use socketioxide::{ - extract::SocketRef, handler::ConnectHandler, packet::Packet, SendError, SocketError, SocketIo, + extract::SocketRef, + handler::ConnectHandler, + packet::Packet, + parser::{CommonParser, Parse, TransportPayload}, + SendError, SocketError, SocketIo, }; use tokio::sync::mpsc; @@ -12,8 +16,11 @@ fn create_msg( event: &str, data: impl Into, ) -> engineioxide::Packet { - let packet: String = Packet::event(ns, event, data.into()).into(); - Message(packet.into()) + let packet = Packet::event(ns, event, data.into()); + match CommonParser::default().encode(packet).0 { + TransportPayload::Str(data) => Message(data), + TransportPayload::Bytes(bin) => Binary(bin), + } } async fn timeout_rcv(srx: &mut tokio::sync::mpsc::Receiver) -> T { tokio::time::timeout(std::time::Duration::from_millis(10), srx.recv()) diff --git a/socketioxide/tests/extractors.rs b/socketioxide/tests/extractors.rs index 91a26b41..cf02342b 100644 --- a/socketioxide/tests/extractors.rs +++ b/socketioxide/tests/extractors.rs @@ -5,6 +5,7 @@ use std::time::Duration; use serde_json::{json, Value}; use socketioxide::extract::{Data, Extension, MaybeExtension, SocketRef, State, TryData}; use socketioxide::handler::ConnectHandler; +use socketioxide::parser::{CommonParser, Parse}; use tokio::sync::mpsc; use engineioxide::Packet as EioPacket; @@ -26,8 +27,9 @@ async fn timeout_rcv_err(srx: &mut tokio::sync::mpsc::Receiv } fn create_msg(ns: &'static str, event: &str, data: impl Into) -> EioPacket { - let packet: String = Packet::event(ns, event, data.into()).into(); - EioPacket::Message(packet.into()) + let parser = CommonParser::default(); + let (payload, _) = parser.encode(Packet::event(ns, event, data.into())); + EioPacket::Message(payload.into_str().unwrap()) } #[tokio::test]