From 9553267c9afe86930ef92c098bf8bc0c232beed6 Mon Sep 17 00:00:00 2001 From: totodore Date: Fri, 19 Apr 2024 17:32:46 -0300 Subject: [PATCH 01/19] feat(socketio/extensions): use `RwLock` rather than `DashMap` --- socketioxide/src/extensions.rs | 165 ++++++++------------------------- 1 file changed, 41 insertions(+), 124 deletions(-) diff --git a/socketioxide/src/extensions.rs b/socketioxide/src/extensions.rs index ac5a9e76..a7a59259 100644 --- a/socketioxide/src/extensions.rs +++ b/socketioxide/src/extensions.rs @@ -2,13 +2,14 @@ //! //! It is heavily inspired by the [`http::Extensions`] type from the `http` crate. //! -//! The main difference is that it uses a [`DashMap`] instead of a [`HashMap`](std::collections::HashMap) to allow concurrent access. +//! The main difference is that the inner [`HashMap`](std::collections::HashMap) is wrapped with an [`RwLock`] +//! to allow concurrent access. Moreover, any value extracted from the map is cloned before being returned. //! //! This is necessary because [`Extensions`] are shared between all the threads that handle the same socket. -use dashmap::DashMap; +use std::collections::HashMap; use std::fmt; -use std::ops::{Deref, DerefMut}; +use std::sync::RwLock; use std::{ any::{Any, TypeId}, hash::{BuildHasherDefault, Hasher}, @@ -17,78 +18,8 @@ use std::{ /// TypeMap value type AnyVal = Box; -/// The `AnyDashMap` is a `DashMap` that uses `TypeId` as keys and `Any` as values. -type AnyDashMap = DashMap>; - -/// A wrapper for a `MappedRef` that implements `Deref` and `DerefMut` to allow -/// easy access to the value. -pub struct Ref<'a, T>( - dashmap::mapref::one::MappedRef<'a, TypeId, AnyVal, T, BuildHasherDefault>, -); - -impl<'a, T> Ref<'a, T> { - /// Get a reference to the value. - pub fn value(&self) -> &T { - self.0.value() - } -} -impl<'a, T> Deref for Ref<'a, T> { - type Target = T; - - fn deref(&self) -> &T { - self.value() - } -} - -/// A wrapper for a `MappedRefMut` that implements `Deref` and `DerefMut` to allow -/// easy access to the value. -pub struct RefMut<'a, T>( - dashmap::mapref::one::MappedRefMut<'a, TypeId, AnyVal, T, BuildHasherDefault>, -); - -impl<'a, T> RefMut<'a, T> { - /// Get a reference to the value. - pub fn value(&self) -> &T { - self.0.value() - } - /// Get a mutable reference to the value. - pub fn value_mut(&mut self) -> &mut T { - self.0.value_mut() - } -} -impl<'a, T> Deref for RefMut<'a, T> { - type Target = T; - - fn deref(&self) -> &T { - self.value() - } -} -impl<'a, T> DerefMut for RefMut<'a, T> { - fn deref_mut(&mut self) -> &mut T { - self.value_mut() - } -} - -impl<'a, T: fmt::Debug> fmt::Debug for Ref<'a, T> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.value().fmt(f) - } -} -impl<'a, T: fmt::Debug> fmt::Debug for RefMut<'a, T> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.value().fmt(f) - } -} -impl<'a, T: fmt::Display> fmt::Display for Ref<'a, T> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.value().fmt(f) - } -} -impl<'a, T: fmt::Display> fmt::Display for RefMut<'a, T> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.value().fmt(f) - } -} +/// The `AnyDashMap` is a `HashMap` that uses `TypeId` as keys and `Any` as values. +type AnyHashMap = RwLock>>; // With TypeIds as keys, there's no need to hash them. They are already hashes // themselves, coming from the compiler. The IdHasher just holds the u64 of @@ -116,14 +47,13 @@ impl Hasher for IdHasher { /// /// It is heavily inspired by the `Extensions` type from the `http` crate. /// -/// The main difference is that it uses a `DashMap` instead of a `HashMap` to allow concurrent access. +/// The main difference is that the inner Map is wrapped with an `RwLock` to allow concurrent access. /// /// This is necessary because `Extensions` are shared between all the threads that handle the same socket. #[derive(Default)] pub struct Extensions { - /// The underlying map. It is not wrapped with an option because it would require insert calls to take a mutable reference. - /// Therefore an anydashmap will be allocated for every socket, even if it is not used. - map: AnyDashMap, + /// The underlying map + map: AnyHashMap, } impl Extensions { @@ -131,12 +61,14 @@ impl Extensions { #[inline] pub fn new() -> Extensions { Extensions { - map: AnyDashMap::default(), + map: AnyHashMap::default(), } } /// Insert a type into this `Extensions`. /// + /// The type must be cloneable and thread safe to be stored. + /// /// If a extension of this type already existed, it will /// be returned. /// @@ -149,18 +81,15 @@ impl Extensions { /// assert!(ext.insert(4u8).is_none()); /// assert_eq!(ext.insert(9i32), Some(5i32)); /// ``` - pub fn insert(&self, val: T) -> Option { + pub fn insert(&self, val: T) -> Option { self.map + .write() + .unwrap() .insert(TypeId::of::(), Box::new(val)) - .and_then(|boxed| { - (boxed as Box) - .downcast() - .ok() - .map(|boxed| *boxed) - }) + .and_then(|v| v.downcast().ok().map(|boxed| *boxed)) } - /// Get a reference to a type previously inserted on this `Extensions`. + /// Get a cloned value of a type previously inserted on this `Extensions`. /// /// # Example /// @@ -170,32 +99,15 @@ impl Extensions { /// assert!(ext.get::().is_none()); /// ext.insert(5i32); /// - /// assert_eq!(*ext.get::().unwrap(), 5i32); + /// assert_eq!(ext.get::().unwrap(), 5i32); /// ``` - pub fn get(&self) -> Option> { + pub fn get(&self) -> Option { self.map + .read() + .unwrap() .get(&TypeId::of::()) - .and_then(|entry| entry.try_map(|r| r.downcast_ref::()).ok()) - .map(|r| Ref(r)) - } - - /// Get a mutable reference to a type previously inserted on this `Extensions`. - /// - /// # Example - /// - /// ``` - /// # use socketioxide::extensions::Extensions; - /// let ext = Extensions::new(); - /// ext.insert(String::from("Hello")); - /// ext.get_mut::().unwrap().push_str(" World"); - /// - /// assert_eq!(*ext.get::().unwrap(), "Hello World"); - /// ``` - pub fn get_mut(&self) -> Option> { - self.map - .get_mut(&TypeId::of::()) - .and_then(|entry| entry.try_map(|r| r.downcast_mut::()).ok()) - .map(|r| RefMut(r)) + .and_then(|v| v.downcast_ref::()) + .cloned() } /// Remove a type from this `Extensions`. @@ -212,12 +124,11 @@ impl Extensions { /// assert!(ext.get::().is_none()); /// ``` pub fn remove(&self) -> Option { - self.map.remove(&TypeId::of::()).and_then(|(_, boxed)| { - (boxed as Box) - .downcast() - .ok() - .map(|boxed| *boxed) - }) + self.map + .write() + .unwrap() + .remove(&TypeId::of::()) + .and_then(|v| v.downcast().ok().map(|boxed| *boxed)) } /// Clear the `Extensions` of all inserted extensions. @@ -234,7 +145,7 @@ impl Extensions { /// ``` #[inline] pub fn clear(&self) { - self.map.clear(); + self.map.write().unwrap().clear(); } /// Check whether the extension set is empty or not. @@ -250,7 +161,7 @@ impl Extensions { /// ``` #[inline] pub fn is_empty(&self) -> bool { - self.map.is_empty() + self.map.read().unwrap().is_empty() } /// Get the number of extensions available. @@ -266,7 +177,7 @@ impl Extensions { /// ``` #[inline] pub fn len(&self) -> usize { - self.map.len() + self.map.read().unwrap().len() } } @@ -278,20 +189,26 @@ impl fmt::Debug for Extensions { #[test] fn test_extensions() { - #[derive(Debug, PartialEq)] + use std::sync::Arc; + #[derive(Debug, Clone, PartialEq)] struct MyType(i32); + #[derive(Debug, PartialEq)] + struct ComplexSharedType(u64); + let shared = Arc::new(ComplexSharedType(20)); + let extensions = Extensions::new(); extensions.insert(5i32); extensions.insert(MyType(10)); + extensions.insert(shared.clone()); - assert_eq!(extensions.get().as_deref(), Some(&5i32)); - assert_eq!(extensions.get_mut().as_deref_mut(), Some(&mut 5i32)); + assert_eq!(extensions.get(), Some(5i32)); + assert_eq!(extensions.get::>(), Some(shared)); assert_eq!(extensions.remove::(), Some(5i32)); assert!(extensions.get::().is_none()); assert!(extensions.get::().is_none()); - assert_eq!(extensions.get().as_deref(), Some(&MyType(10))); + assert_eq!(extensions.get(), Some(MyType(10))); } From 46c7a1b20803bff2ac49d17426eac3f378662d20 Mon Sep 17 00:00:00 2001 From: totodore Date: Fri, 19 Apr 2024 19:17:43 -0300 Subject: [PATCH 02/19] chore(bench): add bencher ci --- .github/workflows/bencher.yml | 71 ++++++++++++++++++++++++++++++ Cargo.toml | 4 +- socketioxide/Cargo.toml | 5 +++ socketioxide/benches/extensions.rs | 30 +++++++++++++ 4 files changed, 109 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/bencher.yml create mode 100644 socketioxide/benches/extensions.rs diff --git a/.github/workflows/bencher.yml b/.github/workflows/bencher.yml new file mode 100644 index 00000000..a9f1e556 --- /dev/null +++ b/.github/workflows/bencher.yml @@ -0,0 +1,71 @@ +name: CI + +on: + push: + branches: main + pull_request: + types: [opened, reopened, edited, synchronize] + +jobs: + benchmark_base: + if: github.event_name == 'push' && github.ref == 'refs/heads/main' + runs-on: ubuntu-latest + strategy: + matrix: + package: [socketioxide, engineioxide] + steps: + - uses: actions/checkout@v4 + - uses: bencherdev/bencher@main + - uses: actions/cache@v4 + with: + path: | + ~/.cargo/bin/ + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + target/ + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + - name: Base branch benchmarks + run: | + bencher run \ + --branch main \ + --testbed ubuntu-latest \ + --err + env: + BENCHER_API_TOKEN: ${{ secrets.BENCHER_API_TOKEN }} + BENCHER_CMD: cargo bench --all-features -p ${{ matrix.package }} + BENCHER_PROJECT: socketioxide + + benchmark_pr: + if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository + permissions: + pull-requests: write + strategy: + matrix: + package: [socketioxide, engineioxide] + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/cache@v4 + with: + path: | + ~/.cargo/bin/ + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + target/ + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + - uses: bencherdev/bencher@main + - name: PR benchmarks + run: | + bencher run \ + --branch '${{ github.head_ref }}' \ + --branch-start-point '${{ github.base_ref }}' \ + --branch-start-point-hash '${{ github.event.pull_request.base.sha }}' \ + --testbed ubuntu-latest \ + --err \ + --github-actions '${{ secrets.GITHUB_TOKEN }}' + env: + BENCHER_API_TOKEN: ${{ secrets.BENCHER_API_TOKEN }} + BENCHER_CMD: cargo bench --all-features -p ${{ matrix.package }} + BENCHER_PROJECT: socketioxide \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 257632d1..a54e2e80 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,5 +38,7 @@ pin-project-lite = "0.2.13" # Dev deps tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } -criterion = { version = "0.5.1", features = ["html_reports"] } +criterion = { version = "0.5.1", features = [ + "rayon", +], default-features = false } axum = "0.7.2" diff --git a/socketioxide/Cargo.toml b/socketioxide/Cargo.toml index d5d89337..d5958756 100644 --- a/socketioxide/Cargo.toml +++ b/socketioxide/Cargo.toml @@ -78,3 +78,8 @@ harness = false name = "itoa_bench" path = "benches/itoa_bench.rs" harness = false + +[[bench]] +name = "extensions" +path = "benches/extensions.rs" +harness = false diff --git a/socketioxide/benches/extensions.rs b/socketioxide/benches/extensions.rs new file mode 100644 index 00000000..4897fa5e --- /dev/null +++ b/socketioxide/benches/extensions.rs @@ -0,0 +1,30 @@ +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use socketioxide::extensions::Extensions; + +fn bench_extensions(c: &mut Criterion) { + let mut group = c.benchmark_group("extensions"); + group.bench_function("concurrent_inserts", |b| { + let ext = Extensions::new(); + b.iter(|| { + ext.insert(5i32); + }); + }); + // group.bench_function("concurrent_get", |b| { + // b.iter(|| { + // let mut ext = Extensions::new(); + // ext.insert(5i32); + // ext.clear(); + // }) + // }); + // group.bench_function("concurrent_get_inserts", |b| { + // b.iter(|| { + // let mut ext = Extensions::new(); + // ext.insert(5i32); + // ext.clear(); + // }) + // }); + group.finish(); +} + +criterion_group!(benches, bench_extensions); +criterion_main!(benches); From d7d6a3f115394fb74bd16d017f766649be93fdac Mon Sep 17 00:00:00 2001 From: totodore Date: Fri, 19 Apr 2024 19:18:10 -0300 Subject: [PATCH 03/19] fix: socketioxide benches with `Bytes` --- socketioxide/benches/packet_decode.rs | 11 ++++++----- socketioxide/benches/packet_encode.rs | 12 +++++++----- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/socketioxide/benches/packet_decode.rs b/socketioxide/benches/packet_decode.rs index ac190e09..a22783c8 100644 --- a/socketioxide/benches/packet_decode.rs +++ b/socketioxide/benches/packet_decode.rs @@ -1,3 +1,4 @@ +use bytes::Bytes; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use engineioxide::sid::Sid; use socketioxide::{ @@ -24,7 +25,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); const DATA: &str = r#"{"_placeholder":true,"num":0}"#; - const BINARY: [u8; 10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; + const BINARY: Bytes = Bytes::from_static(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); c.bench_function("Decode packet event on /", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet: String = @@ -100,7 +101,7 @@ fn criterion_benchmark(c: &mut Criterion) { black_box("/"), black_box("event"), black_box(data.clone()), - black_box(vec![BINARY.to_vec().clone()]), + black_box(vec![BINARY.clone()]), ) .try_into() .unwrap(); @@ -113,7 +114,7 @@ fn criterion_benchmark(c: &mut Criterion) { black_box("/custom_nsp"), black_box("event"), black_box(data.clone()), - black_box(vec![BINARY.to_vec().clone()]), + black_box(vec![BINARY.clone()]), ) .try_into() .unwrap(); @@ -125,7 +126,7 @@ fn criterion_benchmark(c: &mut Criterion) { let packet: String = Packet::bin_ack( black_box("/"), black_box(data.clone()), - black_box(vec![BINARY.to_vec().clone()]), + black_box(vec![BINARY.clone()]), black_box(0), ) .try_into() @@ -138,7 +139,7 @@ fn criterion_benchmark(c: &mut Criterion) { let packet: String = Packet::bin_ack( black_box("/custom_nsp"), black_box(data.clone()), - black_box(vec![BINARY.to_vec().clone()]), + black_box(vec![BINARY.clone()]), black_box(0), ) .try_into() diff --git a/socketioxide/benches/packet_encode.rs b/socketioxide/benches/packet_encode.rs index b55f912a..7901c14d 100644 --- a/socketioxide/benches/packet_encode.rs +++ b/socketioxide/benches/packet_encode.rs @@ -1,3 +1,4 @@ +use bytes::Bytes; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use engineioxide::sid::Sid; use socketioxide::{ @@ -23,7 +24,8 @@ fn criterion_benchmark(c: &mut Criterion) { }); const DATA: &str = r#"{"_placeholder":true,"num":0}"#; - const BINARY: [u8; 10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; + const BINARY: Bytes = Bytes::from_static(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); + c.bench_function("Encode packet event on /", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet = Packet::event(black_box("/"), black_box("event"), black_box(data.clone())); @@ -98,7 +100,7 @@ fn criterion_benchmark(c: &mut Criterion) { black_box("/"), black_box("event"), black_box(data.clone()), - black_box(vec![BINARY.to_vec().clone()]), + black_box(vec![BINARY.clone()]), ); b.iter(|| { let _: String = packet.clone().try_into().unwrap(); @@ -111,7 +113,7 @@ fn criterion_benchmark(c: &mut Criterion) { black_box("/custom_nsp"), black_box("event"), black_box(data.clone()), - black_box(vec![BINARY.to_vec().clone()]), + black_box(vec![BINARY.clone()]), ); b.iter(|| { let _: String = packet.clone().try_into().unwrap(); @@ -123,7 +125,7 @@ fn criterion_benchmark(c: &mut Criterion) { let packet = Packet::bin_ack( black_box("/"), black_box(data.clone()), - black_box(vec![BINARY.to_vec().clone()]), + black_box(vec![BINARY.clone()]), black_box(0), ); b.iter(|| { @@ -136,7 +138,7 @@ fn criterion_benchmark(c: &mut Criterion) { let packet = Packet::bin_ack( black_box("/custom_nsp"), black_box(data.clone()), - black_box(vec![BINARY.to_vec().clone()]), + black_box(vec![BINARY.clone()]), black_box(0), ); b.iter(|| { From c307a73bb182704809771186f71489aab1a2bcea Mon Sep 17 00:00:00 2001 From: totodore Date: Fri, 19 Apr 2024 19:21:21 -0300 Subject: [PATCH 04/19] chore(bench): fix ci name --- .github/workflows/bencher.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/bencher.yml b/.github/workflows/bencher.yml index a9f1e556..6b0121d7 100644 --- a/.github/workflows/bencher.yml +++ b/.github/workflows/bencher.yml @@ -1,4 +1,4 @@ -name: CI +name: Bencher on: push: From 28943175ce720fb7e4e07abfc56b52ced3e8a198 Mon Sep 17 00:00:00 2001 From: totodore Date: Fri, 19 Apr 2024 19:22:24 -0300 Subject: [PATCH 05/19] chore(bench): add RUSTFLAG for testing --- .github/workflows/bencher.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/bencher.yml b/.github/workflows/bencher.yml index 6b0121d7..57b94ac1 100644 --- a/.github/workflows/bencher.yml +++ b/.github/workflows/bencher.yml @@ -35,6 +35,7 @@ jobs: BENCHER_API_TOKEN: ${{ secrets.BENCHER_API_TOKEN }} BENCHER_CMD: cargo bench --all-features -p ${{ matrix.package }} BENCHER_PROJECT: socketioxide + RUSTFLAGS: --cfg=socketioxide_test benchmark_pr: if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository @@ -68,4 +69,5 @@ jobs: env: BENCHER_API_TOKEN: ${{ secrets.BENCHER_API_TOKEN }} BENCHER_CMD: cargo bench --all-features -p ${{ matrix.package }} - BENCHER_PROJECT: socketioxide \ No newline at end of file + BENCHER_PROJECT: socketioxide + RUSTFLAGS: --cfg=socketioxide_test \ No newline at end of file From 44d73ba4a851d5f64f185b1fea0cc380c41a68fa Mon Sep 17 00:00:00 2001 From: totodore Date: Fri, 19 Apr 2024 19:27:30 -0300 Subject: [PATCH 06/19] fix: engineioxide benches --- engineioxide/benches/packet_decode.rs | 6 +++--- engineioxide/benches/packet_encode.rs | 4 +++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/engineioxide/benches/packet_decode.rs b/engineioxide/benches/packet_decode.rs index ea223048..a17bc47b 100644 --- a/engineioxide/benches/packet_decode.rs +++ b/engineioxide/benches/packet_decode.rs @@ -1,3 +1,4 @@ +use bytes::Bytes; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use engineioxide::Packet; @@ -21,9 +22,8 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| Packet::try_from(packet.as_str()).unwrap()) }); c.bench_function("Decode packet binary b64", |b| { - let packet: String = Packet::Binary(black_box(vec![0x00, 0x01, 0x02, 0x03, 0x04, 0x05])) - .try_into() - .unwrap(); + const BYTES: Bytes = Bytes::from_static(&[0x00, 0x01, 0x02, 0x03, 0x04, 0x05]); + let packet: String = Packet::Binary(BYTES).try_into().unwrap(); b.iter(|| Packet::try_from(packet.clone()).unwrap()) }); } diff --git a/engineioxide/benches/packet_encode.rs b/engineioxide/benches/packet_encode.rs index 499e9e2b..a8b44a62 100644 --- a/engineioxide/benches/packet_encode.rs +++ b/engineioxide/benches/packet_encode.rs @@ -1,3 +1,4 @@ +use bytes::Bytes; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use engineioxide::{config::EngineIoConfig, sid::Sid, OpenPacket, Packet, TransportType}; @@ -27,7 +28,8 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| TryInto::::try_into(packet.clone())) }); c.bench_function("Encode packet binary b64", |b| { - let packet = Packet::Binary(black_box(vec![0x00, 0x01, 0x02, 0x03, 0x04, 0x05])); + const BYTES: Bytes = Bytes::from_static(&[0x00, 0x01, 0x02, 0x03, 0x04, 0x05]); + let packet = Packet::Binary(BYTES); b.iter(|| TryInto::::try_into(packet.clone())) }); } From 191e3fa7e69656c97b1b260690fd0a4a34d76f7f Mon Sep 17 00:00:00 2001 From: totodore Date: Fri, 19 Apr 2024 19:43:04 -0300 Subject: [PATCH 07/19] chore(bench): remove matrix test --- .github/workflows/bencher.yml | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/.github/workflows/bencher.yml b/.github/workflows/bencher.yml index 57b94ac1..e7fa898d 100644 --- a/.github/workflows/bencher.yml +++ b/.github/workflows/bencher.yml @@ -10,9 +10,6 @@ jobs: benchmark_base: if: github.event_name == 'push' && github.ref == 'refs/heads/main' runs-on: ubuntu-latest - strategy: - matrix: - package: [socketioxide, engineioxide] steps: - uses: actions/checkout@v4 - uses: bencherdev/bencher@main @@ -33,7 +30,7 @@ jobs: --err env: BENCHER_API_TOKEN: ${{ secrets.BENCHER_API_TOKEN }} - BENCHER_CMD: cargo bench --all-features -p ${{ matrix.package }} + BENCHER_CMD: cargo bench --all-features BENCHER_PROJECT: socketioxide RUSTFLAGS: --cfg=socketioxide_test @@ -41,9 +38,6 @@ jobs: if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository permissions: pull-requests: write - strategy: - matrix: - package: [socketioxide, engineioxide] runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -68,6 +62,6 @@ jobs: --github-actions '${{ secrets.GITHUB_TOKEN }}' env: BENCHER_API_TOKEN: ${{ secrets.BENCHER_API_TOKEN }} - BENCHER_CMD: cargo bench --all-features -p ${{ matrix.package }} + BENCHER_CMD: cargo bench --all-features BENCHER_PROJECT: socketioxide RUSTFLAGS: --cfg=socketioxide_test \ No newline at end of file From 66aeef9710037c7e62d4c4fbaeecae9f332fe128 Mon Sep 17 00:00:00 2001 From: totodore Date: Fri, 19 Apr 2024 19:50:51 -0300 Subject: [PATCH 08/19] chore(bench): add groups --- engineioxide/benches/packet_decode.rs | 13 +++++---- engineioxide/benches/packet_encode.rs | 15 ++++++----- socketioxide/benches/extensions.rs | 26 +++++++++--------- socketioxide/benches/itoa_bench.rs | 38 --------------------------- socketioxide/benches/packet_decode.rs | 27 ++++++++++--------- socketioxide/benches/packet_encode.rs | 27 ++++++++++--------- 6 files changed, 59 insertions(+), 87 deletions(-) delete mode 100644 socketioxide/benches/itoa_bench.rs diff --git a/engineioxide/benches/packet_decode.rs b/engineioxide/benches/packet_decode.rs index a17bc47b..2984e42f 100644 --- a/engineioxide/benches/packet_decode.rs +++ b/engineioxide/benches/packet_decode.rs @@ -3,29 +3,32 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion}; use engineioxide::Packet; fn criterion_benchmark(c: &mut Criterion) { - c.bench_function("Decode packet ping/pong", |b| { + let mut group = c.benchmark_group("engineio_packet/decode"); + group.bench_function("Decode packet ping/pong", |b| { let packet: String = Packet::Ping.try_into().unwrap(); b.iter(|| Packet::try_from(packet.as_str()).unwrap()) }); - c.bench_function("Decode packet ping/pong upgrade", |b| { + group.bench_function("Decode packet ping/pong upgrade", |b| { let packet: String = Packet::PingUpgrade.try_into().unwrap(); b.iter(|| Packet::try_from(packet.as_str()).unwrap()) }); - c.bench_function("Decode packet message", |b| { + group.bench_function("Decode packet message", |b| { let packet: String = Packet::Message(black_box("Hello").to_string()) .try_into() .unwrap(); b.iter(|| Packet::try_from(packet.as_str()).unwrap()) }); - c.bench_function("Decode packet noop", |b| { + group.bench_function("Decode packet noop", |b| { let packet: String = Packet::Noop.try_into().unwrap(); b.iter(|| Packet::try_from(packet.as_str()).unwrap()) }); - c.bench_function("Decode packet binary b64", |b| { + group.bench_function("Decode packet binary b64", |b| { const BYTES: Bytes = Bytes::from_static(&[0x00, 0x01, 0x02, 0x03, 0x04, 0x05]); let packet: String = Packet::Binary(BYTES).try_into().unwrap(); b.iter(|| Packet::try_from(packet.clone()).unwrap()) }); + + group.finish(); } criterion_group!(benches, criterion_benchmark); diff --git a/engineioxide/benches/packet_encode.rs b/engineioxide/benches/packet_encode.rs index a8b44a62..97492aef 100644 --- a/engineioxide/benches/packet_encode.rs +++ b/engineioxide/benches/packet_encode.rs @@ -3,7 +3,8 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion}; use engineioxide::{config::EngineIoConfig, sid::Sid, OpenPacket, Packet, TransportType}; fn criterion_benchmark(c: &mut Criterion) { - c.bench_function("Encode packet open", |b| { + let mut group = c.benchmark_group("engineio_packet/encode"); + group.bench_function("Encode packet open", |b| { let packet = Packet::Open(OpenPacket::new( black_box(TransportType::Polling), black_box(Sid::ZERO), @@ -11,27 +12,29 @@ fn criterion_benchmark(c: &mut Criterion) { )); b.iter(|| TryInto::::try_into(packet.clone())) }); - c.bench_function("Encode packet ping/pong", |b| { + group.bench_function("Encode packet ping/pong", |b| { let packet = Packet::Ping; b.iter(|| TryInto::::try_into(packet.clone())) }); - c.bench_function("Encode packet ping/pong upgrade", |b| { + group.bench_function("Encode packet ping/pong upgrade", |b| { let packet = Packet::PingUpgrade; b.iter(|| TryInto::::try_into(packet.clone())) }); - c.bench_function("Encode packet message", |b| { + group.bench_function("Encode packet message", |b| { let packet = Packet::Message(black_box("Hello").to_string()); b.iter(|| TryInto::::try_into(packet.clone())) }); - c.bench_function("Encode packet noop", |b| { + group.bench_function("Encode packet noop", |b| { let packet = Packet::Noop; b.iter(|| TryInto::::try_into(packet.clone())) }); - c.bench_function("Encode packet binary b64", |b| { + group.bench_function("Encode packet binary b64", |b| { const BYTES: Bytes = Bytes::from_static(&[0x00, 0x01, 0x02, 0x03, 0x04, 0x05]); let packet = Packet::Binary(BYTES); b.iter(|| TryInto::::try_into(packet.clone())) }); + + group.finish(); } criterion_group!(benches, criterion_benchmark); diff --git a/socketioxide/benches/extensions.rs b/socketioxide/benches/extensions.rs index 4897fa5e..39c172b1 100644 --- a/socketioxide/benches/extensions.rs +++ b/socketioxide/benches/extensions.rs @@ -9,20 +9,18 @@ fn bench_extensions(c: &mut Criterion) { ext.insert(5i32); }); }); - // group.bench_function("concurrent_get", |b| { - // b.iter(|| { - // let mut ext = Extensions::new(); - // ext.insert(5i32); - // ext.clear(); - // }) - // }); - // group.bench_function("concurrent_get_inserts", |b| { - // b.iter(|| { - // let mut ext = Extensions::new(); - // ext.insert(5i32); - // ext.clear(); - // }) - // }); + group.bench_function("concurrent_get", |b| { + let ext = Extensions::new(); + b.iter(|| { + ext.insert(5i32); + }) + }); + group.bench_function("concurrent_get_inserts", |b| { + b.iter(|| { + let mut ext = Extensions::new(); + ext.insert(5i32); + }) + }); group.finish(); } diff --git a/socketioxide/benches/itoa_bench.rs b/socketioxide/benches/itoa_bench.rs deleted file mode 100644 index dd84ba0b..00000000 --- a/socketioxide/benches/itoa_bench.rs +++ /dev/null @@ -1,38 +0,0 @@ -//! # itoa_bench, used to compare best solutions related to this issue: https://github.com/Totodore/socketioxide/issues/143 - -use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; - -/// This solution doesn't imply additional buffer or dependency call -fn insert_number_reverse(mut v: u32, res: &mut String) { - let mut buf = [0u8; 10]; - let mut i = 1; - while v > 0 { - let n = (v % 10) as u8; - buf[10 - i] = n + 0x30; - v /= 10; - i += 1; - } - res.push_str(unsafe { std::str::from_utf8_unchecked(&buf[10 - (i - 1)..]) }); -} - -/// This solution uses the itoa crate -fn insert_number_itoa(v: u32, res: &mut String) { - let mut buffer = itoa::Buffer::new(); - res.push_str(buffer.format(v)); -} - -fn bench_itoa(c: &mut Criterion) { - let mut group = c.benchmark_group("itoa"); - for i in [u32::MAX / 200000, u32::MAX / 2, u32::MAX].iter() { - group.bench_with_input(BenchmarkId::new("number_reverse", i), i, |b, i| { - b.iter(|| insert_number_reverse(*i, &mut String::new())) - }); - group.bench_with_input(BenchmarkId::new("number_itoa", i), i, |b, i| { - b.iter(|| insert_number_itoa(*i, &mut String::new())) - }); - } - group.finish(); -} - -criterion_group!(benches, bench_itoa); -criterion_main!(benches); diff --git a/socketioxide/benches/packet_decode.rs b/socketioxide/benches/packet_decode.rs index a22783c8..db50c769 100644 --- a/socketioxide/benches/packet_decode.rs +++ b/socketioxide/benches/packet_decode.rs @@ -6,14 +6,15 @@ use socketioxide::{ ProtocolVersion, }; fn criterion_benchmark(c: &mut Criterion) { - c.bench_function("Decode packet connect on /", |b| { + let mut group = c.benchmark_group("socketio_packet/decode"); + group.bench_function("Decode packet connect on /", |b| { let packet: String = Packet::connect(black_box("/"), black_box(Sid::ZERO), ProtocolVersion::V5) .try_into() .unwrap(); b.iter(|| Packet::try_from(packet.clone()).unwrap()) }); - c.bench_function("Decode packet connect on /custom_nsp", |b| { + group.bench_function("Decode packet connect on /custom_nsp", |b| { let packet: String = Packet::connect( black_box("/custom_nsp"), black_box(Sid::ZERO), @@ -26,7 +27,7 @@ fn criterion_benchmark(c: &mut Criterion) { const DATA: &str = r#"{"_placeholder":true,"num":0}"#; const BINARY: Bytes = Bytes::from_static(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); - c.bench_function("Decode packet event on /", |b| { + group.bench_function("Decode packet event on /", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet: String = Packet::event(black_box("/"), black_box("event"), black_box(data.clone())) @@ -35,7 +36,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| Packet::try_from(packet.clone()).unwrap()) }); - c.bench_function("Decode packet event on /custom_nsp", |b| { + group.bench_function("Decode packet event on /custom_nsp", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet: String = Packet::event( black_box("custom_nsp"), @@ -47,7 +48,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| Packet::try_from(packet.clone()).unwrap()) }); - c.bench_function("Decode packet event with ack on /", |b| { + group.bench_function("Decode packet event with ack on /", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet: Packet = Packet::event(black_box("/"), black_box("event"), black_box(data.clone())); @@ -59,7 +60,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| Packet::try_from(packet.clone()).unwrap()) }); - c.bench_function("Decode packet event with ack on /custom_nsp", |b| { + group.bench_function("Decode packet event with ack on /custom_nsp", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet = Packet::event( black_box("/custom_nsp"), @@ -75,7 +76,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| Packet::try_from(packet.clone()).unwrap()) }); - c.bench_function("Decode packet ack on /", |b| { + group.bench_function("Decode packet ack on /", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet: String = Packet::ack(black_box("/"), black_box(data.clone()), black_box(0)) .try_into() @@ -83,7 +84,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| Packet::try_from(packet.clone()).unwrap()) }); - c.bench_function("Decode packet ack on /custom_nsp", |b| { + group.bench_function("Decode packet ack on /custom_nsp", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet: String = Packet::ack( black_box("/custom_nsp"), @@ -95,7 +96,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| Packet::try_from(packet.clone()).unwrap()) }); - c.bench_function("Decode packet binary event (b64) on /", |b| { + group.bench_function("Decode packet binary event (b64) on /", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet: String = Packet::bin_event( black_box("/"), @@ -108,7 +109,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| Packet::try_from(packet.clone()).unwrap()) }); - c.bench_function("Decode packet binary event (b64) on /custom_nsp", |b| { + group.bench_function("Decode packet binary event (b64) on /custom_nsp", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet: String = Packet::bin_event( black_box("/custom_nsp"), @@ -121,7 +122,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| Packet::try_from(packet.clone()).unwrap()) }); - c.bench_function("Decode packet binary ack (b64) on /", |b| { + group.bench_function("Decode packet binary ack (b64) on /", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet: String = Packet::bin_ack( black_box("/"), @@ -134,7 +135,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| Packet::try_from(packet.clone()).unwrap()) }); - c.bench_function("Decode packet binary ack (b64) on /custom_nsp", |b| { + group.bench_function("Decode packet binary ack (b64) on /custom_nsp", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet: String = Packet::bin_ack( black_box("/custom_nsp"), @@ -146,6 +147,8 @@ fn criterion_benchmark(c: &mut Criterion) { .unwrap(); b.iter(|| Packet::try_from(packet.clone()).unwrap()) }); + + group.finish(); } criterion_group!(benches, criterion_benchmark); diff --git a/socketioxide/benches/packet_encode.rs b/socketioxide/benches/packet_encode.rs index 7901c14d..d55ee92f 100644 --- a/socketioxide/benches/packet_encode.rs +++ b/socketioxide/benches/packet_encode.rs @@ -6,13 +6,14 @@ use socketioxide::{ ProtocolVersion, }; fn criterion_benchmark(c: &mut Criterion) { - c.bench_function("Encode packet connect on /", |b| { + let mut group = c.benchmark_group("socketio_packet/encode"); + group.bench_function("Encode packet connect on /", |b| { let packet = Packet::connect(black_box("/"), black_box(Sid::ZERO), ProtocolVersion::V5); b.iter(|| { let _: String = packet.clone().try_into().unwrap(); }) }); - c.bench_function("Encode packet connect on /custom_nsp", |b| { + group.bench_function("Encode packet connect on /custom_nsp", |b| { let packet = Packet::connect( black_box("/custom_nsp"), black_box(Sid::ZERO), @@ -26,7 +27,7 @@ fn criterion_benchmark(c: &mut Criterion) { const DATA: &str = r#"{"_placeholder":true,"num":0}"#; const BINARY: Bytes = Bytes::from_static(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); - c.bench_function("Encode packet event on /", |b| { + group.bench_function("Encode packet event on /", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet = Packet::event(black_box("/"), black_box("event"), black_box(data.clone())); b.iter(|| { @@ -34,7 +35,7 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); - c.bench_function("Encode packet event on /custom_nsp", |b| { + group.bench_function("Encode packet event on /custom_nsp", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet = Packet::event( black_box("custom_nsp"), @@ -46,7 +47,7 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); - c.bench_function("Encode packet event with ack on /", |b| { + group.bench_function("Encode packet event with ack on /", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet = Packet::event(black_box("/"), black_box("event"), black_box(data.clone())); match packet.inner { @@ -58,7 +59,7 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); - c.bench_function("Encode packet event with ack on /custom_nsp", |b| { + group.bench_function("Encode packet event with ack on /custom_nsp", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet = Packet::event( black_box("/custom_nsp"), @@ -74,7 +75,7 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); - c.bench_function("Encode packet ack on /", |b| { + group.bench_function("Encode packet ack on /", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet = Packet::ack(black_box("/"), black_box(data.clone()), black_box(0)); b.iter(|| { @@ -82,7 +83,7 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); - c.bench_function("Encode packet ack on /custom_nsp", |b| { + group.bench_function("Encode packet ack on /custom_nsp", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet = Packet::ack( black_box("/custom_nsp"), @@ -94,7 +95,7 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); - c.bench_function("Encode packet binary event (b64) on /", |b| { + group.bench_function("Encode packet binary event (b64) on /", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet = Packet::bin_event( black_box("/"), @@ -107,7 +108,7 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); - c.bench_function("Encode packet binary event (b64) on /custom_nsp", |b| { + group.bench_function("Encode packet binary event (b64) on /custom_nsp", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet = Packet::bin_event( black_box("/custom_nsp"), @@ -120,7 +121,7 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); - c.bench_function("Encode packet binary ack (b64) on /", |b| { + group.bench_function("Encode packet binary ack (b64) on /", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet = Packet::bin_ack( black_box("/"), @@ -133,7 +134,7 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); - c.bench_function("Encode packet binary ack (b64) on /custom_nsp", |b| { + group.bench_function("Encode packet binary ack (b64) on /custom_nsp", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet = Packet::bin_ack( black_box("/custom_nsp"), @@ -145,6 +146,8 @@ fn criterion_benchmark(c: &mut Criterion) { let _: String = packet.clone().try_into().unwrap(); }) }); + + group.finish(); } criterion_group!(benches, criterion_benchmark); From 82fefb593312ca2da350b0a516181f040ddf8c1d Mon Sep 17 00:00:00 2001 From: totodore Date: Fri, 19 Apr 2024 20:15:51 -0300 Subject: [PATCH 09/19] chore(bench): improve extensions bench --- socketioxide/Cargo.toml | 7 +------ socketioxide/benches/extensions.rs | 25 ++++++++++++++++++------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/socketioxide/Cargo.toml b/socketioxide/Cargo.toml index d5958756..02f4961b 100644 --- a/socketioxide/Cargo.toml +++ b/socketioxide/Cargo.toml @@ -57,7 +57,7 @@ criterion.workspace = true hyper = { workspace = true, features = ["server", "http1"] } hyper-util = { workspace = true, features = ["tokio", "client-legacy"] } http-body-util.workspace = true - +rand = { version = "0.8", default-features = false } # docs.rs-specific configuration [package.metadata.docs.rs] features = ["v4", "extensions", "tracing", "state"] @@ -74,11 +74,6 @@ name = "packet_decode" path = "benches/packet_decode.rs" harness = false -[[bench]] -name = "itoa_bench" -path = "benches/itoa_bench.rs" -harness = false - [[bench]] name = "extensions" path = "benches/extensions.rs" diff --git a/socketioxide/benches/extensions.rs b/socketioxide/benches/extensions.rs index 39c172b1..8bd9c05d 100644 --- a/socketioxide/benches/extensions.rs +++ b/socketioxide/benches/extensions.rs @@ -1,25 +1,36 @@ -use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion}; +use rand::Rng; use socketioxide::extensions::Extensions; fn bench_extensions(c: &mut Criterion) { + let i = black_box(5i32); let mut group = c.benchmark_group("extensions"); group.bench_function("concurrent_inserts", |b| { let ext = Extensions::new(); b.iter(|| { - ext.insert(5i32); + ext.insert(i); }); }); group.bench_function("concurrent_get", |b| { let ext = Extensions::new(); + ext.insert(i); b.iter(|| { - ext.insert(5i32); + ext.get::(); }) }); group.bench_function("concurrent_get_inserts", |b| { - b.iter(|| { - let mut ext = Extensions::new(); - ext.insert(5i32); - }) + let ext = Extensions::new(); + b.iter_batched( + || rand::thread_rng().gen_range(0..3), + |i| { + if i == 0 { + ext.insert(i); + } else { + ext.get::(); + } + }, + BatchSize::SmallInput, + ) }); group.finish(); } From 164a7ae7e65e9d9b0ee40755f5abe81c8209de9d Mon Sep 17 00:00:00 2001 From: totodore Date: Sat, 20 Apr 2024 00:58:57 -0300 Subject: [PATCH 10/19] feat(socketio/extract): refactor extract mod --- socketioxide/src/extract/data.rs | 105 +++++++ socketioxide/src/extract/mod.rs | 185 +++++++++++ socketioxide/src/extract/socket.rs | 196 ++++++++++++ socketioxide/src/extract/state.rs | 105 +++++++ socketioxide/src/handler/extract.rs | 464 ---------------------------- socketioxide/src/handler/mod.rs | 1 - socketioxide/src/lib.rs | 2 +- 7 files changed, 592 insertions(+), 466 deletions(-) create mode 100644 socketioxide/src/extract/data.rs create mode 100644 socketioxide/src/extract/mod.rs create mode 100644 socketioxide/src/extract/socket.rs create mode 100644 socketioxide/src/extract/state.rs delete mode 100644 socketioxide/src/handler/extract.rs diff --git a/socketioxide/src/extract/data.rs b/socketioxide/src/extract/data.rs new file mode 100644 index 00000000..5ccb4b12 --- /dev/null +++ b/socketioxide/src/extract/data.rs @@ -0,0 +1,105 @@ +use std::convert::Infallible; +use std::sync::Arc; + +use crate::handler::{FromConnectParts, FromMessage, FromMessageParts}; +use crate::{adapter::Adapter, socket::Socket}; +use bytes::Bytes; +use serde::de::DeserializeOwned; +use serde_json::Value; + +/// Utility function to unwrap an array with a single element +fn upwrap_array(v: &mut Value) { + match v { + Value::Array(vec) if vec.len() == 1 => { + *v = vec.pop().unwrap(); + } + _ => (), + } +} + +/// An Extractor that returns the serialized auth data without checking errors. +/// If a deserialization error occurs, the [`ConnectHandler`](super::ConnectHandler) won't be called +/// and an error log will be print if the `tracing` feature is enabled. +pub struct Data(pub T); +impl FromConnectParts for Data +where + T: DeserializeOwned, + A: Adapter, +{ + type Error = serde_json::Error; + fn from_connect_parts(_: &Arc>, auth: &Option) -> Result { + auth.as_ref() + .map(|a| serde_json::from_str::(a)) + .unwrap_or(serde_json::from_str::("{}")) + .map(Data) + } +} +impl FromMessageParts for Data +where + T: DeserializeOwned, + A: Adapter, +{ + type Error = serde_json::Error; + fn from_message_parts( + _: &Arc>, + v: &mut serde_json::Value, + _: &mut Vec, + _: &Option, + ) -> Result { + upwrap_array(v); + serde_json::from_value(v.clone()).map(Data) + } +} + +/// An Extractor that returns the deserialized data related to the event. +pub struct TryData(pub Result); + +impl FromConnectParts for TryData +where + T: DeserializeOwned, + A: Adapter, +{ + type Error = Infallible; + fn from_connect_parts(_: &Arc>, auth: &Option) -> Result { + let v: Result = auth + .as_ref() + .map(|a| serde_json::from_str(a)) + .unwrap_or(serde_json::from_str("{}")); + Ok(TryData(v)) + } +} +impl FromMessageParts for TryData +where + T: DeserializeOwned, + A: Adapter, +{ + type Error = Infallible; + fn from_message_parts( + _: &Arc>, + v: &mut serde_json::Value, + _: &mut Vec, + _: &Option, + ) -> Result { + upwrap_array(v); + Ok(TryData(serde_json::from_value(v.clone()))) + } +} + +/// An Extractor that returns the binary data of the message. +/// If there is no binary data, it will contain an empty vec. +pub struct Bin(pub Vec); +impl FromMessage for Bin { + type Error = Infallible; + fn from_message( + _: Arc>, + _: serde_json::Value, + bin: Vec, + _: Option, + ) -> Result { + Ok(Bin(bin)) + } +} + +super::__impl_deref!(Bin: Vec); +super::__impl_deref!(TryData: Result); +super::__impl_deref!(Data); diff --git a/socketioxide/src/extract/mod.rs b/socketioxide/src/extract/mod.rs new file mode 100644 index 00000000..b84e92ca --- /dev/null +++ b/socketioxide/src/extract/mod.rs @@ -0,0 +1,185 @@ +//! ### Extractors for [`ConnectHandler`](super::ConnectHandler), [`ConnectMiddleware`](super::ConnectMiddleware), +//! [`MessageHandler`](super::MessageHandler) +//! and [`DisconnectHandler`](super::DisconnectHandler). +//! +//! They can be used to extract data from the context of the handler and get specific params. Here are some examples of extractors: +//! * [`Data`]: extracts and deserialize to json any data, if a deserialization error occurs the handler won't be called: +//! - for [`ConnectHandler`](super::ConnectHandler): extracts and deserialize to json the auth data +//! - for [`ConnectMiddleware`](super::ConnectMiddleware): extract and deserialize to json the auth data. +//! In case of error, the middleware chain stops and a `connect_error` event is sent. +//! - for [`MessageHandler`](super::MessageHandler): extracts and deserialize to json the message data +//! * [`TryData`]: extracts and deserialize to json any data but with a `Result` type in case of error: +//! - for [`ConnectHandler`](super::ConnectHandler) and [`ConnectMiddleware`](super::ConnectMiddleware): +//! extracts and deserialize to json the auth data +//! - for [`MessageHandler`](super::MessageHandler): extracts and deserialize to json the message data +//! * [`SocketRef`]: extracts a reference to the [`Socket`] +//! * [`Bin`]: extract a binary payload for a given message. Because it consumes the event it should be the last argument +//! * [`AckSender`]: Can be used to send an ack response to the current message event +//! * [`ProtocolVersion`](crate::ProtocolVersion): extracts the protocol version +//! * [`TransportType`](crate::TransportType): extracts the transport type +//! * [`DisconnectReason`](crate::socket::DisconnectReason): extracts the reason of the disconnection +//! * [`State`]: extracts a reference to a state previously set with [`SocketIoBuilder::with_state`](crate::io::SocketIoBuilder). +//! * [`Extension`]: extracts an extension of the given type +//! * [`MaybeExtension`]: extracts an extension of the given type if it exists or `None` otherwise +//! * [`HttpExtension`]: extracts an http extension of the given type coming from the request. +//! (Similar to axum's [`extract::Extension`](https://docs.rs/axum/latest/axum/struct.Extension.html) +//! * [`MaybeHttpExtension`]: extracts an http extension of the given type if it exists or `None` otherwise. +//! +//! ### You can also implement your own Extractor with the [`FromConnectParts`], [`FromMessageParts`] and [`FromDisconnectParts`] traits +//! When implementing these traits, if you clone the [`Arc`] make sure that it is dropped at least when the socket is disconnected. +//! Otherwise it will create a memory leak. It is why the [`SocketRef`] extractor is used instead of cloning the socket for common usage. +//! +//! #### Example that extracts a user id from the query params +//! ```rust +//! # use bytes::Bytes; +//! # use socketioxide::handler::{FromConnectParts, FromMessageParts}; +//! # use socketioxide::adapter::Adapter; +//! # use socketioxide::socket::Socket; +//! # use std::sync::Arc; +//! # use std::convert::Infallible; +//! # use socketioxide::SocketIo; +//! +//! struct UserId(String); +//! +//! #[derive(Debug)] +//! struct UserIdNotFound; +//! impl std::fmt::Display for UserIdNotFound { +//! fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +//! write!(f, "User id not found") +//! } +//! } +//! impl std::error::Error for UserIdNotFound {} +//! +//! impl FromConnectParts for UserId { +//! type Error = Infallible; +//! fn from_connect_parts(s: &Arc>, _: &Option) -> Result { +//! // In a real app it would be better to parse the query params with a crate like `url` +//! let uri = &s.req_parts().uri; +//! let uid = uri +//! .query() +//! .and_then(|s| s.split('&').find(|s| s.starts_with("id=")).map(|s| &s[3..])) +//! .unwrap_or_default(); +//! // Currently, it is not possible to have lifetime on the extracted data +//! Ok(UserId(uid.to_string())) +//! } +//! } +//! +//! // Here, if the user id is not found, the handler won't be called +//! // and a tracing `error` log will be printed (if the `tracing` feature is enabled) +//! impl FromMessageParts for UserId { +//! type Error = UserIdNotFound; +//! +//! fn from_message_parts( +//! s: &Arc>, +//! _: &mut serde_json::Value, +//! _: &mut Vec, +//! _: &Option, +//! ) -> Result { +//! // In a real app it would be better to parse the query params with a crate like `url` +//! let uri = &s.req_parts().uri; +//! let uid = uri +//! .query() +//! .and_then(|s| s.split('&').find(|s| s.starts_with("id=")).map(|s| &s[3..])) +//! .ok_or(UserIdNotFound)?; +//! // Currently, it is not possible to have lifetime on the extracted data +//! Ok(UserId(uid.to_string())) +//! } +//! } +//! +//! fn handler(user_id: UserId) { +//! println!("User id: {}", user_id.0); +//! } +//! let (svc, io) = SocketIo::new_svc(); +//! io.ns("/", handler); +//! // Use the service with your favorite http server + +mod data; +mod extensions; +mod socket; + +#[cfg(feature = "state")] +#[cfg_attr(docsrs, doc(cfg(feature = "state")))] +mod state; + +pub use data::*; +pub use socket::*; +#[cfg(feature = "state")] +#[cfg_attr(docsrs, doc(cfg(feature = "state")))] +pub use state::*; + +/// Private API. +#[doc(hidden)] +macro_rules! __impl_deref { + ($ident:ident) => { + impl std::ops::Deref for $ident { + type Target = T; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + impl std::ops::DerefMut for $ident { + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } + } + }; + + ($ident:ident<$($gen:ident),+>) => { + impl<$($gen),+> std::ops::Deref for $ident<$($gen),+> { + type Target = $($gen),+; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + impl<$($gen),+> std::ops::DerefMut for $ident<$($gen),+> { + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } + } + }; + + ($ident:ident<$($gen:ident),+>: $ty:ty) => { + impl<$($gen),+> std::ops::Deref for $ident<$($gen),+> { + type Target = $ty; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + impl<$($gen),+> std::ops::DerefMut for $ident<$($gen),+> { + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } + } + }; + + ($ident:ident: $ty:ty) => { + impl std::ops::Deref for $ident { + type Target = $ty; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + impl std::ops::DerefMut for $ident { + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } + } + }; +} +pub(crate) use __impl_deref; diff --git a/socketioxide/src/extract/socket.rs b/socketioxide/src/extract/socket.rs new file mode 100644 index 00000000..0a08ba27 --- /dev/null +++ b/socketioxide/src/extract/socket.rs @@ -0,0 +1,196 @@ +use std::convert::Infallible; +use std::sync::Arc; + +use crate::errors::{DisconnectError, SendError}; +use crate::handler::{FromConnectParts, FromDisconnectParts, FromMessageParts}; +use crate::socket::DisconnectReason; +use crate::{ + adapter::{Adapter, LocalAdapter}, + packet::Packet, + socket::Socket, +}; +use bytes::Bytes; +use serde::Serialize; + +/// An Extractor that returns a reference to a [`Socket`]. +#[derive(Debug)] +pub struct SocketRef(Arc>); + +impl FromConnectParts for SocketRef { + type Error = Infallible; + fn from_connect_parts(s: &Arc>, _: &Option) -> Result { + Ok(SocketRef(s.clone())) + } +} +impl FromMessageParts for SocketRef { + type Error = Infallible; + fn from_message_parts( + s: &Arc>, + _: &mut serde_json::Value, + _: &mut Vec, + _: &Option, + ) -> Result { + Ok(SocketRef(s.clone())) + } +} +impl FromDisconnectParts for SocketRef { + type Error = Infallible; + fn from_disconnect_parts(s: &Arc>, _: DisconnectReason) -> Result { + Ok(SocketRef(s.clone())) + } +} + +impl std::ops::Deref for SocketRef { + type Target = Socket; + #[inline(always)] + fn deref(&self) -> &Self::Target { + &self.0 + } +} +impl PartialEq for SocketRef { + #[inline(always)] + fn eq(&self, other: &Self) -> bool { + self.0.id == other.0.id + } +} +impl From>> for SocketRef { + #[inline(always)] + fn from(socket: Arc>) -> Self { + Self(socket) + } +} + +impl Clone for SocketRef { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl SocketRef { + /// Disconnect the socket from the current namespace, + /// + /// It will also call the disconnect handler if it is set. + #[inline(always)] + pub fn disconnect(self) -> Result<(), DisconnectError> { + self.0.disconnect() + } +} + +/// An Extractor to send an ack response corresponding to the current event. +/// If the client sent a normal message without expecting an ack, the ack callback will do nothing. +#[derive(Debug)] +pub struct AckSender { + binary: Vec, + socket: Arc>, + ack_id: Option, +} +impl FromMessageParts for AckSender { + type Error = Infallible; + fn from_message_parts( + s: &Arc>, + _: &mut serde_json::Value, + _: &mut Vec, + ack_id: &Option, + ) -> Result { + Ok(Self::new(s.clone(), *ack_id)) + } +} +impl AckSender { + pub(crate) fn new(socket: Arc>, ack_id: Option) -> Self { + Self { + binary: vec![], + socket, + ack_id, + } + } + + /// Add binary data to the ack response. + pub fn bin(mut self, bin: impl IntoIterator>) -> Self { + self.binary = bin.into_iter().map(Into::into).collect(); + self + } + + /// Send the ack response to the client. + pub fn send(self, data: T) -> Result<(), SendError> { + use crate::socket::PermitExt; + if let Some(ack_id) = self.ack_id { + let permit = match self.socket.reserve() { + Ok(permit) => permit, + Err(e) => { + #[cfg(feature = "tracing")] + tracing::debug!("sending error during emit message: {e:?}"); + return Err(e.with_value(data).into()); + } + }; + let ns = self.socket.ns(); + let data = serde_json::to_value(data)?; + let packet = if self.binary.is_empty() { + Packet::ack(ns, data, ack_id) + } else { + Packet::bin_ack(ns, data, self.binary, ack_id) + }; + permit.send(packet); + Ok(()) + } else { + Ok(()) + } + } +} + +impl FromConnectParts for crate::ProtocolVersion { + type Error = Infallible; + fn from_connect_parts(s: &Arc>, _: &Option) -> Result { + Ok(s.protocol()) + } +} +impl FromMessageParts for crate::ProtocolVersion { + type Error = Infallible; + fn from_message_parts( + s: &Arc>, + _: &mut serde_json::Value, + _: &mut Vec, + _: &Option, + ) -> Result { + Ok(s.protocol()) + } +} +impl FromDisconnectParts for crate::ProtocolVersion { + type Error = Infallible; + fn from_disconnect_parts(s: &Arc>, _: DisconnectReason) -> Result { + Ok(s.protocol()) + } +} + +impl FromConnectParts for crate::TransportType { + type Error = Infallible; + fn from_connect_parts(s: &Arc>, _: &Option) -> Result { + Ok(s.transport_type()) + } +} +impl FromMessageParts for crate::TransportType { + type Error = Infallible; + fn from_message_parts( + s: &Arc>, + _: &mut serde_json::Value, + _: &mut Vec, + _: &Option, + ) -> Result { + Ok(s.transport_type()) + } +} +impl FromDisconnectParts for crate::TransportType { + type Error = Infallible; + fn from_disconnect_parts(s: &Arc>, _: DisconnectReason) -> Result { + Ok(s.transport_type()) + } +} + +impl FromDisconnectParts for DisconnectReason { + type Error = Infallible; + fn from_disconnect_parts( + _: &Arc>, + reason: DisconnectReason, + ) -> Result { + Ok(reason) + } +} diff --git a/socketioxide/src/extract/state.rs b/socketioxide/src/extract/state.rs new file mode 100644 index 00000000..4c7f67ad --- /dev/null +++ b/socketioxide/src/extract/state.rs @@ -0,0 +1,105 @@ +use bytes::Bytes; + +use crate::state::get_state; +use std::ops::Deref; +use std::sync::Arc; + +use crate::adapter::Adapter; +use crate::handler::{FromConnectParts, FromDisconnectParts, FromMessageParts}; +use crate::socket::{DisconnectReason, Socket}; + +/// An Extractor that contains a reference to a state previously set with [`SocketIoBuilder::with_state`](crate::io::SocketIoBuilder). +/// It implements [`std::ops::Deref`] to access the inner type so you can use it as a normal reference. +/// +/// The specified state type must be the same as the one set with [`SocketIoBuilder::with_state`](crate::io::SocketIoBuilder). +/// If it is not the case, the handler won't be called and an error log will be print if the `tracing` feature is enabled. +/// +/// The state is shared between the entire socket.io app context. +/// +/// ### Example +/// ``` +/// # use socketioxide::{SocketIo, extract::{SocketRef, State}}; +/// # use serde::{Serialize, Deserialize}; +/// # use std::sync::atomic::AtomicUsize; +/// # use std::sync::atomic::Ordering; +/// #[derive(Default)] +/// struct MyAppData { +/// user_cnt: AtomicUsize, +/// } +/// impl MyAppData { +/// fn add_user(&self) { +/// self.user_cnt.fetch_add(1, Ordering::SeqCst); +/// } +/// fn rm_user(&self) { +/// self.user_cnt.fetch_sub(1, Ordering::SeqCst); +/// } +/// } +/// let (_, io) = SocketIo::builder().with_state(MyAppData::default()).build_svc(); +/// io.ns("/", |socket: SocketRef, state: State| { +/// state.add_user(); +/// println!("User count: {}", state.user_cnt.load(Ordering::SeqCst)); +/// }); +pub struct State(pub &'static T); + +/// It was impossible to find the given state and therefore the handler won't be called. +pub struct StateNotFound(std::marker::PhantomData); + +impl std::fmt::Display for StateNotFound { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "State of type {} not found, maybe you forgot to insert it in the extensions map?", + std::any::type_name::() + ) + } +} +impl std::fmt::Debug for StateNotFound { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "StateNotFound {}", std::any::type_name::()) + } +} +impl std::error::Error for StateNotFound {} + +impl FromConnectParts for State { + type Error = StateNotFound; + fn from_connect_parts( + _: &Arc>, + _: &Option, + ) -> Result> { + get_state::() + .map(State) + .ok_or(StateNotFound(std::marker::PhantomData)) + } +} +impl FromDisconnectParts for State { + type Error = StateNotFound; + fn from_disconnect_parts( + _: &Arc>, + _: DisconnectReason, + ) -> Result> { + get_state::() + .map(State) + .ok_or(StateNotFound(std::marker::PhantomData)) + } +} +impl FromMessageParts for State { + type Error = StateNotFound; + fn from_message_parts( + _: &Arc>, + _: &mut serde_json::Value, + _: &mut Vec, + _: &Option, + ) -> Result> { + get_state::() + .map(State) + .ok_or(StateNotFound(std::marker::PhantomData)) + } +} + +impl Deref for State { + type Target = &'static T; + #[inline(always)] + fn deref(&self) -> &Self::Target { + &self.0 + } +} diff --git a/socketioxide/src/handler/extract.rs b/socketioxide/src/handler/extract.rs deleted file mode 100644 index 568cf8b3..00000000 --- a/socketioxide/src/handler/extract.rs +++ /dev/null @@ -1,464 +0,0 @@ -//! ### Extractors for [`ConnectHandler`](super::ConnectHandler), [`ConnectMiddleware`](super::ConnectMiddleware), -//! [`MessageHandler`](super::MessageHandler) -//! and [`DisconnectHandler`](super::DisconnectHandler). -//! -//! They can be used to extract data from the context of the handler and get specific params. Here are some examples of extractors: -//! * [`Data`]: extracts and deserialize to json any data, if a deserialization error occurs the handler won't be called: -//! - for [`ConnectHandler`](super::ConnectHandler): extracts and deserialize to json the auth data -//! - for [`ConnectMiddleware`](super::ConnectMiddleware): extract and deserialize to json the auth data. -//! In case of error, the middleware chain stops and a `connect_error` event is sent. -//! - for [`MessageHandler`](super::MessageHandler): extracts and deserialize to json the message data -//! * [`TryData`]: extracts and deserialize to json any data but with a `Result` type in case of error: -//! - for [`ConnectHandler`](super::ConnectHandler) and [`ConnectMiddleware`](super::ConnectMiddleware): -//! extracts and deserialize to json the auth data -//! - for [`MessageHandler`](super::MessageHandler): extracts and deserialize to json the message data -//! * [`SocketRef`]: extracts a reference to the [`Socket`] -//! * [`Bin`]: extract a binary payload for a given message. Because it consumes the event it should be the last argument -//! * [`AckSender`]: Can be used to send an ack response to the current message event -//! * [`ProtocolVersion`](crate::ProtocolVersion): extracts the protocol version -//! * [`TransportType`](crate::TransportType): extracts the transport type -//! * [`DisconnectReason`]: extracts the reason of the disconnection -//! * [`State`]: extracts a reference to a state previously set with [`SocketIoBuilder::with_state`](crate::io::SocketIoBuilder). -//! -//! ### You can also implement your own Extractor with the [`FromConnectParts`], [`FromMessageParts`] and [`FromDisconnectParts`] traits -//! When implementing these traits, if you clone the [`Arc`] make sure that it is dropped at least when the socket is disconnected. -//! Otherwise it will create a memory leak. It is why the [`SocketRef`] extractor is used instead of cloning the socket for common usage. -//! -//! #### Example that extracts a user id from the query params -//! ```rust -//! # use bytes::Bytes; -//! # use socketioxide::handler::{FromConnectParts, FromMessageParts}; -//! # use socketioxide::adapter::Adapter; -//! # use socketioxide::socket::Socket; -//! # use std::sync::Arc; -//! # use std::convert::Infallible; -//! # use socketioxide::SocketIo; -//! -//! struct UserId(String); -//! -//! #[derive(Debug)] -//! struct UserIdNotFound; -//! impl std::fmt::Display for UserIdNotFound { -//! fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { -//! write!(f, "User id not found") -//! } -//! } -//! impl std::error::Error for UserIdNotFound {} -//! -//! impl FromConnectParts for UserId { -//! type Error = Infallible; -//! fn from_connect_parts(s: &Arc>, _: &Option) -> Result { -//! // In a real app it would be better to parse the query params with a crate like `url` -//! let uri = &s.req_parts().uri; -//! let uid = uri -//! .query() -//! .and_then(|s| s.split('&').find(|s| s.starts_with("id=")).map(|s| &s[3..])) -//! .unwrap_or_default(); -//! // Currently, it is not possible to have lifetime on the extracted data -//! Ok(UserId(uid.to_string())) -//! } -//! } -//! -//! // Here, if the user id is not found, the handler won't be called -//! // and a tracing `error` log will be printed (if the `tracing` feature is enabled) -//! impl FromMessageParts for UserId { -//! type Error = UserIdNotFound; -//! -//! fn from_message_parts( -//! s: &Arc>, -//! _: &mut serde_json::Value, -//! _: &mut Vec, -//! _: &Option, -//! ) -> Result { -//! // In a real app it would be better to parse the query params with a crate like `url` -//! let uri = &s.req_parts().uri; -//! let uid = uri -//! .query() -//! .and_then(|s| s.split('&').find(|s| s.starts_with("id=")).map(|s| &s[3..])) -//! .ok_or(UserIdNotFound)?; -//! // Currently, it is not possible to have lifetime on the extracted data -//! Ok(UserId(uid.to_string())) -//! } -//! } -//! -//! fn handler(user_id: UserId) { -//! println!("User id: {}", user_id.0); -//! } -//! let (svc, io) = SocketIo::new_svc(); -//! io.ns("/", handler); -//! // Use the service with your favorite http server -use std::convert::Infallible; -use std::sync::Arc; - -use super::message::FromMessageParts; -use super::FromDisconnectParts; -use super::{connect::FromConnectParts, message::FromMessage}; -use crate::errors::{DisconnectError, SendError}; -use crate::socket::DisconnectReason; -use crate::{ - adapter::{Adapter, LocalAdapter}, - packet::Packet, - socket::Socket, -}; -use bytes::Bytes; -use serde::{de::DeserializeOwned, Serialize}; -use serde_json::Value; - -#[cfg(feature = "state")] -#[cfg_attr(docsrs, doc(cfg(feature = "state")))] -pub use state_extract::*; - -/// Utility function to unwrap an array with a single element -fn upwrap_array(v: &mut Value) { - match v { - Value::Array(vec) if vec.len() == 1 => { - *v = vec.pop().unwrap(); - } - _ => (), - } -} - -/// An Extractor that returns the serialized auth data without checking errors. -/// If a deserialization error occurs, the [`ConnectHandler`](super::ConnectHandler) won't be called -/// and an error log will be print if the `tracing` feature is enabled. -pub struct Data(pub T); -impl FromConnectParts for Data -where - T: DeserializeOwned, - A: Adapter, -{ - type Error = serde_json::Error; - fn from_connect_parts(_: &Arc>, auth: &Option) -> Result { - auth.as_ref() - .map(|a| serde_json::from_str::(a)) - .unwrap_or(serde_json::from_str::("{}")) - .map(Data) - } -} -impl FromMessageParts for Data -where - T: DeserializeOwned, - A: Adapter, -{ - type Error = serde_json::Error; - fn from_message_parts( - _: &Arc>, - v: &mut serde_json::Value, - _: &mut Vec, - _: &Option, - ) -> Result { - upwrap_array(v); - serde_json::from_value(v.clone()).map(Data) - } -} - -/// An Extractor that returns the deserialized data related to the event. -pub struct TryData(pub Result); - -impl FromConnectParts for TryData -where - T: DeserializeOwned, - A: Adapter, -{ - type Error = Infallible; - fn from_connect_parts(_: &Arc>, auth: &Option) -> Result { - let v: Result = auth - .as_ref() - .map(|a| serde_json::from_str(a)) - .unwrap_or(serde_json::from_str("{}")); - Ok(TryData(v)) - } -} -impl FromMessageParts for TryData -where - T: DeserializeOwned, - A: Adapter, -{ - type Error = Infallible; - fn from_message_parts( - _: &Arc>, - v: &mut serde_json::Value, - _: &mut Vec, - _: &Option, - ) -> Result { - upwrap_array(v); - Ok(TryData(serde_json::from_value(v.clone()))) - } -} -/// An Extractor that returns a reference to a [`Socket`]. -#[derive(Debug)] -pub struct SocketRef(Arc>); - -impl FromConnectParts for SocketRef { - type Error = Infallible; - fn from_connect_parts(s: &Arc>, _: &Option) -> Result { - Ok(SocketRef(s.clone())) - } -} -impl FromMessageParts for SocketRef { - type Error = Infallible; - fn from_message_parts( - s: &Arc>, - _: &mut serde_json::Value, - _: &mut Vec, - _: &Option, - ) -> Result { - Ok(SocketRef(s.clone())) - } -} -impl FromDisconnectParts for SocketRef { - type Error = Infallible; - fn from_disconnect_parts(s: &Arc>, _: DisconnectReason) -> Result { - Ok(SocketRef(s.clone())) - } -} - -impl std::ops::Deref for SocketRef { - type Target = Socket; - #[inline(always)] - fn deref(&self) -> &Self::Target { - &self.0 - } -} -impl PartialEq for SocketRef { - #[inline(always)] - fn eq(&self, other: &Self) -> bool { - self.0.id == other.0.id - } -} -impl From>> for SocketRef { - #[inline(always)] - fn from(socket: Arc>) -> Self { - Self(socket) - } -} - -impl Clone for SocketRef { - fn clone(&self) -> Self { - Self(self.0.clone()) - } -} - -impl SocketRef { - /// Disconnect the socket from the current namespace, - /// - /// It will also call the disconnect handler if it is set. - #[inline(always)] - pub fn disconnect(self) -> Result<(), DisconnectError> { - self.0.disconnect() - } -} - -/// An Extractor that returns the binary data of the message. -/// If there is no binary data, it will contain an empty vec. -pub struct Bin(pub Vec); -impl FromMessage for Bin { - type Error = Infallible; - fn from_message( - _: Arc>, - _: serde_json::Value, - bin: Vec, - _: Option, - ) -> Result { - Ok(Bin(bin)) - } -} - -/// An Extractor to send an ack response corresponding to the current event. -/// If the client sent a normal message without expecting an ack, the ack callback will do nothing. -#[derive(Debug)] -pub struct AckSender { - binary: Vec, - socket: Arc>, - ack_id: Option, -} -impl FromMessageParts for AckSender { - type Error = Infallible; - fn from_message_parts( - s: &Arc>, - _: &mut serde_json::Value, - _: &mut Vec, - ack_id: &Option, - ) -> Result { - Ok(Self::new(s.clone(), *ack_id)) - } -} -impl AckSender { - pub(crate) fn new(socket: Arc>, ack_id: Option) -> Self { - Self { - binary: vec![], - socket, - ack_id, - } - } - - /// Add binary data to the ack response. - pub fn bin(mut self, bin: impl IntoIterator>) -> Self { - self.binary = bin.into_iter().map(Into::into).collect(); - self - } - - /// Send the ack response to the client. - pub fn send(self, data: T) -> Result<(), SendError> { - use crate::socket::PermitExt; - if let Some(ack_id) = self.ack_id { - let permit = match self.socket.reserve() { - Ok(permit) => permit, - Err(e) => { - #[cfg(feature = "tracing")] - tracing::debug!("sending error during emit message: {e:?}"); - return Err(e.with_value(data).into()); - } - }; - let ns = self.socket.ns(); - let data = serde_json::to_value(data)?; - let packet = if self.binary.is_empty() { - Packet::ack(ns, data, ack_id) - } else { - Packet::bin_ack(ns, data, self.binary, ack_id) - }; - permit.send(packet); - Ok(()) - } else { - Ok(()) - } - } -} - -impl FromConnectParts for crate::ProtocolVersion { - type Error = Infallible; - fn from_connect_parts(s: &Arc>, _: &Option) -> Result { - Ok(s.protocol()) - } -} -impl FromMessageParts for crate::ProtocolVersion { - type Error = Infallible; - fn from_message_parts( - s: &Arc>, - _: &mut serde_json::Value, - _: &mut Vec, - _: &Option, - ) -> Result { - Ok(s.protocol()) - } -} -impl FromDisconnectParts for crate::ProtocolVersion { - type Error = Infallible; - fn from_disconnect_parts(s: &Arc>, _: DisconnectReason) -> Result { - Ok(s.protocol()) - } -} - -impl FromConnectParts for crate::TransportType { - type Error = Infallible; - fn from_connect_parts(s: &Arc>, _: &Option) -> Result { - Ok(s.transport_type()) - } -} -impl FromMessageParts for crate::TransportType { - type Error = Infallible; - fn from_message_parts( - s: &Arc>, - _: &mut serde_json::Value, - _: &mut Vec, - _: &Option, - ) -> Result { - Ok(s.transport_type()) - } -} -impl FromDisconnectParts for crate::TransportType { - type Error = Infallible; - fn from_disconnect_parts(s: &Arc>, _: DisconnectReason) -> Result { - Ok(s.transport_type()) - } -} - -impl FromDisconnectParts for DisconnectReason { - type Error = Infallible; - fn from_disconnect_parts( - _: &Arc>, - reason: DisconnectReason, - ) -> Result { - Ok(reason) - } -} - -#[cfg(feature = "state")] -mod state_extract { - use super::*; - use crate::state::get_state; - - /// An Extractor that contains a reference to a state previously set with [`SocketIoBuilder::with_state`](crate::io::SocketIoBuilder). - /// It implements [`std::ops::Deref`] to access the inner type so you can use it as a normal reference. - /// - /// The specified state type must be the same as the one set with [`SocketIoBuilder::with_state`](crate::io::SocketIoBuilder). - /// If it is not the case, the handler won't be called and an error log will be print if the `tracing` feature is enabled. - /// - /// The state is shared between the entire socket.io app context. - /// - /// ### Example - /// ``` - /// # use socketioxide::{SocketIo, extract::{SocketRef, State}}; - /// # use serde::{Serialize, Deserialize}; - /// # use std::sync::atomic::AtomicUsize; - /// # use std::sync::atomic::Ordering; - /// #[derive(Default)] - /// struct MyAppData { - /// user_cnt: AtomicUsize, - /// } - /// impl MyAppData { - /// fn add_user(&self) { - /// self.user_cnt.fetch_add(1, Ordering::SeqCst); - /// } - /// fn rm_user(&self) { - /// self.user_cnt.fetch_sub(1, Ordering::SeqCst); - /// } - /// } - /// let (_, io) = SocketIo::builder().with_state(MyAppData::default()).build_svc(); - /// io.ns("/", |socket: SocketRef, state: State| { - /// state.add_user(); - /// println!("User count: {}", state.user_cnt.load(Ordering::SeqCst)); - /// }); - pub struct State(pub &'static T); - /// It was impossible to find the given state and therefore the handler won't be called. - #[derive(Debug, thiserror::Error)] - #[error("State not found")] - pub struct StateNotFound; - - impl std::ops::Deref for State { - type Target = T; - fn deref(&self) -> &Self::Target { - self.0 - } - } - - impl FromConnectParts for State { - type Error = StateNotFound; - fn from_connect_parts( - _: &Arc>, - _: &Option, - ) -> Result { - get_state::().map(State).ok_or(StateNotFound) - } - } - impl FromDisconnectParts for State { - type Error = StateNotFound; - fn from_disconnect_parts( - _: &Arc>, - _: DisconnectReason, - ) -> Result { - get_state::().map(State).ok_or(StateNotFound) - } - } - impl FromMessageParts for State { - type Error = StateNotFound; - fn from_message_parts( - _: &Arc>, - _: &mut serde_json::Value, - _: &mut Vec, - _: &Option, - ) -> Result { - get_state::().map(State).ok_or(StateNotFound) - } - } -} diff --git a/socketioxide/src/handler/mod.rs b/socketioxide/src/handler/mod.rs index 3e48f1f7..9f51a6dd 100644 --- a/socketioxide/src/handler/mod.rs +++ b/socketioxide/src/handler/mod.rs @@ -3,7 +3,6 @@ //! All handlers can be async or not. pub mod connect; pub mod disconnect; -pub mod extract; pub mod message; pub(crate) use connect::BoxedConnectHandler; diff --git a/socketioxide/src/lib.rs b/socketioxide/src/lib.rs index cc9ae9f6..5f34544d 100644 --- a/socketioxide/src/lib.rs +++ b/socketioxide/src/lib.rs @@ -281,6 +281,7 @@ pub mod extensions; mod state; pub mod ack; +pub mod extract; pub mod handler; pub mod layer; pub mod operators; @@ -290,7 +291,6 @@ pub mod socket; pub use engineioxide::TransportType; pub use errors::{AckError, AdapterError, BroadcastError, DisconnectError, SendError, SocketError}; -pub use handler::extract; pub use io::{SocketIo, SocketIoBuilder, SocketIoConfig}; mod client; From f6008a559d6fa23f4e9a61467cca10e9ccba9ac2 Mon Sep 17 00:00:00 2001 From: totodore Date: Sat, 20 Apr 2024 00:59:30 -0300 Subject: [PATCH 11/19] feat(socketio/extract): add `(Maybe)(Http)Extension` extractors --- socketioxide/src/extract/extensions.rs | 183 +++++++++++++++++++++++++ socketioxide/src/extract/mod.rs | 1 + 2 files changed, 184 insertions(+) create mode 100644 socketioxide/src/extract/extensions.rs diff --git a/socketioxide/src/extract/extensions.rs b/socketioxide/src/extract/extensions.rs new file mode 100644 index 00000000..9a898503 --- /dev/null +++ b/socketioxide/src/extract/extensions.rs @@ -0,0 +1,183 @@ +use std::convert::Infallible; +use std::sync::Arc; + +use crate::adapter::Adapter; +use crate::handler::{FromConnectParts, FromDisconnectParts, FromMessageParts}; +use crate::socket::{DisconnectReason, Socket}; +use bytes::Bytes; + +#[cfg(feature = "extensions")] +#[cfg_attr(docsrs, doc(cfg(feature = "extensions")))] +pub use extensions_extract::*; + +/// It was impossible to find the given extension +pub struct ExtensionNotFound(std::marker::PhantomData); + +impl std::fmt::Display for ExtensionNotFound { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Extension of type {} not found, maybe you forgot to insert it in the extensions map?", + std::any::type_name::() + ) + } +} +impl std::fmt::Debug for ExtensionNotFound { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "ExtensionNotFound {}", std::any::type_name::()) + } +} +impl std::error::Error for ExtensionNotFound {} + +fn extract_http_extension( + s: &Arc>, +) -> Result> { + s.req_parts() + .extensions + .get::() + .cloned() + .ok_or(ExtensionNotFound(std::marker::PhantomData)) +} + +/// An Extractor that returns a clone extension from the request parts. +pub struct HttpExtension(pub T); +/// An Extractor that returns a clone extension from the request parts if it exists. +pub struct MaybeHttpExtension(pub Option); + +impl FromConnectParts for HttpExtension { + type Error = ExtensionNotFound; + fn from_connect_parts( + s: &Arc>, + _: &Option, + ) -> Result> { + extract_http_extension(s).map(HttpExtension) + } +} + +impl FromConnectParts for MaybeHttpExtension { + type Error = Infallible; + fn from_connect_parts(s: &Arc>, _: &Option) -> Result { + Ok(MaybeHttpExtension(extract_http_extension(s).ok())) + } +} + +impl FromDisconnectParts for HttpExtension { + type Error = ExtensionNotFound; + fn from_disconnect_parts( + s: &Arc>, + _: DisconnectReason, + ) -> Result> { + extract_http_extension(s).map(HttpExtension) + } +} +impl FromDisconnectParts + for MaybeHttpExtension +{ + type Error = Infallible; + fn from_disconnect_parts(s: &Arc>, _: DisconnectReason) -> Result { + Ok(MaybeHttpExtension(extract_http_extension(s).ok())) + } +} + +impl FromMessageParts for HttpExtension { + type Error = ExtensionNotFound; + fn from_message_parts( + s: &Arc>, + _: &mut serde_json::Value, + _: &mut Vec, + _: &Option, + ) -> Result> { + extract_http_extension(s).map(HttpExtension) + } +} +impl FromMessageParts for MaybeHttpExtension { + type Error = Infallible; + fn from_message_parts( + s: &Arc>, + _: &mut serde_json::Value, + _: &mut Vec, + _: &Option, + ) -> Result { + Ok(MaybeHttpExtension(extract_http_extension(s).ok())) + } +} + +super::__impl_deref!(HttpExtension); +super::__impl_deref!(MaybeHttpExtension: Option); + +#[cfg(feature = "extensions")] +mod extensions_extract { + use super::*; + + fn extract_extension( + s: &Arc>, + ) -> Result> { + s.extensions + .get::() + .ok_or(ExtensionNotFound(std::marker::PhantomData)) + } + + /// An Extractor that returns the extension of the given type. + /// If the extension is not found, + /// the handler won't be called and an error log will be print if the `tracing` feature is enabled. + pub struct Extension(pub T); + + /// An Extractor that returns the extension of the given type if it exists or `None` otherwise. + pub struct MaybeExtension(pub Option); + + impl FromConnectParts for Extension { + type Error = ExtensionNotFound; + fn from_connect_parts( + s: &Arc>, + _: &Option, + ) -> Result> { + extract_extension(s).map(Extension) + } + } + impl FromConnectParts for MaybeExtension { + type Error = Infallible; + fn from_connect_parts(s: &Arc>, _: &Option) -> Result { + Ok(MaybeExtension(extract_extension(s).ok())) + } + } + impl FromDisconnectParts for Extension { + type Error = ExtensionNotFound; + fn from_disconnect_parts( + s: &Arc>, + _: DisconnectReason, + ) -> Result> { + extract_extension(s).map(Extension) + } + } + impl FromDisconnectParts for MaybeExtension { + type Error = Infallible; + fn from_disconnect_parts( + s: &Arc>, + _: DisconnectReason, + ) -> Result { + Ok(MaybeExtension(extract_extension(s).ok())) + } + } + impl FromMessageParts for Extension { + type Error = ExtensionNotFound; + fn from_message_parts( + s: &Arc>, + _: &mut serde_json::Value, + _: &mut Vec, + _: &Option, + ) -> Result> { + extract_extension(s).map(Extension) + } + } + impl FromMessageParts for MaybeExtension { + type Error = Infallible; + fn from_message_parts( + s: &Arc>, + _: &mut serde_json::Value, + _: &mut Vec, + _: &Option, + ) -> Result { + Ok(MaybeExtension(extract_extension(s).ok())) + } + } +} diff --git a/socketioxide/src/extract/mod.rs b/socketioxide/src/extract/mod.rs index b84e92ca..ed6869d6 100644 --- a/socketioxide/src/extract/mod.rs +++ b/socketioxide/src/extract/mod.rs @@ -102,6 +102,7 @@ mod socket; mod state; pub use data::*; +pub use extensions::*; pub use socket::*; #[cfg(feature = "state")] #[cfg_attr(docsrs, doc(cfg(feature = "state")))] From ecff81a26159004ef71aa07cdd589732380b694c Mon Sep 17 00:00:00 2001 From: totodore Date: Sat, 20 Apr 2024 02:04:40 -0300 Subject: [PATCH 12/19] docs(example): update examples with `Extension` extractor --- examples/chat/src/main.rs | 17 ++- examples/private-messaging/Cargo.toml | 1 + examples/private-messaging/src/handlers.rs | 119 +++++++++++---------- examples/private-messaging/src/store.rs | 67 ++++++++++-- 4 files changed, 130 insertions(+), 74 deletions(-) diff --git a/examples/chat/src/main.rs b/examples/chat/src/main.rs index 909194a4..ea006ed0 100644 --- a/examples/chat/src/main.rs +++ b/examples/chat/src/main.rs @@ -2,7 +2,7 @@ use std::sync::atomic::AtomicUsize; use serde::{Deserialize, Serialize}; use socketioxide::{ - extract::{Data, SocketRef, State}, + extract::{Data, SocketRef, State, Extension, MaybeExtension}, SocketIo, }; use tower::ServiceBuilder; @@ -59,8 +59,7 @@ async fn main() -> Result<(), Box> { let (layer, io) = SocketIo::builder().with_state(UserCnt::new()).build_layer(); io.ns("/", |s: SocketRef| { - s.on("new message", |s: SocketRef, Data::(msg)| { - let username = s.extensions.get::().unwrap().clone(); + s.on("new message", |s: SocketRef, Data::(msg), Extension::(username)| { let msg = Res::Message { username, message: msg, @@ -86,26 +85,24 @@ async fn main() -> Result<(), Box> { }, ); - s.on("typing", |s: SocketRef| { - let username = s.extensions.get::().unwrap().clone(); + s.on("typing", |s: SocketRef, Extension::(username)| { s.broadcast() .emit("typing", Res::Username { username }) .ok(); }); - s.on("stop typing", |s: SocketRef| { - let username = s.extensions.get::().unwrap().clone(); + s.on("stop typing", |s: SocketRef, Extension::(username)| { s.broadcast() .emit("stop typing", Res::Username { username }) .ok(); }); - s.on_disconnect(|s: SocketRef, user_cnt: State| { - if let Some(username) = s.extensions.get::() { + s.on_disconnect(|s: SocketRef, user_cnt: State, MaybeExtension::(username)| { + if let Some(username) = username { let num_users = user_cnt.remove_user(); let res = Res::UserEvent { num_users, - username: username.clone(), + username, }; s.broadcast().emit("user left", res).ok(); } diff --git a/examples/private-messaging/Cargo.toml b/examples/private-messaging/Cargo.toml index e9bcb9e4..ce4d8fcb 100644 --- a/examples/private-messaging/Cargo.toml +++ b/examples/private-messaging/Cargo.toml @@ -8,6 +8,7 @@ edition = "2021" socketioxide = { path = "../../socketioxide", features = [ "extensions", "state", + "tracing", ] } axum.workspace = true tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } diff --git a/examples/private-messaging/src/handlers.rs b/examples/private-messaging/src/handlers.rs index b72dfa88..01f32789 100644 --- a/examples/private-messaging/src/handlers.rs +++ b/examples/private-messaging/src/handlers.rs @@ -1,6 +1,8 @@ +use std::sync::{atomic::Ordering, Arc}; + use anyhow::anyhow; use serde::{Deserialize, Serialize}; -use socketioxide::extract::{Data, SocketRef, State, TryData}; +use socketioxide::extract::{Data, Extension, SocketRef, State}; use uuid::Uuid; use crate::store::{Message, Messages, Session, Sessions}; @@ -22,12 +24,19 @@ struct UserConnectedRes { messages: Vec, } +#[derive(Debug, Serialize, Clone)] +struct UserDisconnectedRes { + #[serde(rename = "userID")] + user_id: Uuid, + username: String, +} + impl UserConnectedRes { fn new(session: &Session, messages: Vec) -> Self { Self { user_id: session.user_id, username: session.username.clone(), - connected: session.connected, + connected: session.connected.load(Ordering::SeqCst), messages, } } @@ -38,86 +47,80 @@ struct PrivateMessageReq { content: String, } -pub fn on_connection(s: SocketRef) { +pub fn on_connection( + s: SocketRef, + Extension::>(session): Extension>, + State(sessions): State, + State(msgs): State, +) { + s.emit("session", (*session).clone()).unwrap(); + + let users = sessions + .get_all_other_sessions(session.user_id) + .into_iter() + .map(|session| { + let messages = msgs.get_all_for_user(session.user_id); + UserConnectedRes::new(&session, messages) + }) + .collect::>(); + + s.emit("users", [users]).unwrap(); + + let res = UserConnectedRes::new(&session, vec![]); + s.broadcast().emit("user connected", res).unwrap(); + s.on( "private message", - |s: SocketRef, Data(PrivateMessageReq { to, content }), State(Messages(msg))| { - let user_id = s.extensions.get::().unwrap().user_id; + |s: SocketRef, + Data(PrivateMessageReq { to, content }), + State::(msgs), + Extension::>(session)| { let message = Message { - from: user_id, + from: session.user_id, to, content, }; - msg.write().unwrap().push(message.clone()); + msgs.add(message.clone()); s.within(to.to_string()) .emit("private message", message) .ok(); }, ); - s.on_disconnect(|s: SocketRef, State(Sessions(sessions))| { - let mut session = s.extensions.get::().unwrap().clone(); - session.connected = false; - - sessions - .write() - .unwrap() - .get_mut(&session.session_id) - .unwrap() - .connected = false; - - s.broadcast().emit("user disconnected", session).ok(); + s.on_disconnect(|s: SocketRef, Extension::>(session)| { + session.set_connected(false); + s.broadcast() + .emit( + "user disconnected", + UserDisconnectedRes { + user_id: session.user_id, + username: session.username.clone(), + }, + ) + .ok(); }); } -/// Handles the connection of a new user +/// Handles the connection of a new user. +/// Be careful to not emit anything to the user before the authentication is done. pub fn authenticate_middleware( s: SocketRef, - TryData(auth): TryData, - State(Sessions(session_state)): State, - State(Messages(msg_state)): State, + Data(auth): Data, + State(sessions): State, ) -> Result<(), anyhow::Error> { - let auth = auth?; - let mut sessions = session_state.write().unwrap(); - if let Some(session) = auth.session_id.and_then(|id| sessions.get_mut(&id)) { - session.connected = true; + let session = if let Some(session) = auth.session_id.and_then(|id| sessions.get(id)) { + session.set_connected(true); s.extensions.insert(session.clone()); + session } else { let username = auth.username.ok_or(anyhow!("invalid username"))?; - let session = Session::new(username); + let session = Arc::new(Session::new(username)); s.extensions.insert(session.clone()); - - sessions.insert(session.session_id, session); + sessions.add(session.clone()); + session }; - drop(sessions); - - let session = s.extensions.get::().unwrap(); - - s.join(session.user_id.to_string()).ok(); - s.emit("session", session.clone())?; - let users = session_state - .read() - .unwrap() - .iter() - .filter(|(id, _)| id != &&session.session_id) - .map(|(_, session)| { - let messages = msg_state - .read() - .unwrap() - .iter() - .filter(|message| message.to == session.user_id || message.from == session.user_id) - .cloned() - .collect(); - - UserConnectedRes::new(session, messages) - }) - .collect::>(); - - s.emit("users", [users])?; - - let res = UserConnectedRes::new(&session, vec![]); + s.join(session.user_id.to_string())?; - s.broadcast().emit("user connected", res)?; Ok(()) } diff --git a/examples/private-messaging/src/store.rs b/examples/private-messaging/src/store.rs index 0bb57ef2..d1553d8d 100644 --- a/examples/private-messaging/src/store.rs +++ b/examples/private-messaging/src/store.rs @@ -1,15 +1,21 @@ -use std::{collections::HashMap, sync::RwLock}; +use std::{ + collections::HashMap, + sync::{atomic::Ordering, Arc, RwLock}, +}; use serde::Serialize; +use std::sync::atomic::AtomicBool; use uuid::Uuid; /// Store Types -#[derive(Debug, Clone, Serialize)] +#[derive(Debug, Serialize)] pub struct Session { + #[serde(rename = "sessionID")] pub session_id: Uuid, + #[serde(rename = "userID")] pub user_id: Uuid, pub username: String, - pub connected: bool, + pub connected: AtomicBool, } impl Session { pub fn new(username: String) -> Self { @@ -17,7 +23,20 @@ impl Session { session_id: Uuid::new_v4(), user_id: Uuid::new_v4(), username, - connected: true, + connected: AtomicBool::new(true), + } + } + pub fn set_connected(&self, connected: bool) { + self.connected.store(connected, Ordering::SeqCst); + } +} +impl Clone for Session { + fn clone(&self) -> Self { + Self { + session_id: self.session_id.clone(), + user_id: self.user_id.clone(), + username: self.username.clone(), + connected: AtomicBool::new(self.connected.load(Ordering::SeqCst)), } } } @@ -29,6 +48,42 @@ pub struct Message { } #[derive(Default)] -pub struct Sessions(pub RwLock>); +pub struct Sessions(RwLock>>); + +impl Sessions { + pub fn get_all_other_sessions(&self, user_id: Uuid) -> Vec> { + self.0 + .read() + .unwrap() + .values() + .filter(|s| s.user_id != user_id) + .cloned() + .collect() + } + + pub fn get(&self, user_id: Uuid) -> Option> { + self.0.read().unwrap().get(&user_id).cloned() + } + + pub fn add(&self, session: Arc) { + self.0.write().unwrap().insert(session.session_id, session); + } +} #[derive(Default)] -pub struct Messages(pub RwLock>); +pub struct Messages(RwLock>); + +impl Messages { + pub fn add(&self, message: Message) { + self.0.write().unwrap().push(message); + } + + pub fn get_all_for_user(&self, user_id: Uuid) -> Vec { + self.0 + .read() + .unwrap() + .iter() + .filter(|m| m.from == user_id || m.to == user_id) + .cloned() + .collect() + } +} From 5070dd0d6d7b8aff415e2e7958965204fa8e9048 Mon Sep 17 00:00:00 2001 From: totodore Date: Sat, 20 Apr 2024 02:44:19 -0300 Subject: [PATCH 13/19] test(socketio/extract): add tests for `Extension` and `MaybeExtension` --- socketioxide/tests/extractors.rs | 105 ++++++++++++++++++++++++++++++- 1 file changed, 103 insertions(+), 2 deletions(-) diff --git a/socketioxide/tests/extractors.rs b/socketioxide/tests/extractors.rs index c0860376..8b2a9d3c 100644 --- a/socketioxide/tests/extractors.rs +++ b/socketioxide/tests/extractors.rs @@ -1,8 +1,10 @@ //! Tests for extractors +use std::convert::Infallible; use std::time::Duration; -use serde_json::json; -use socketioxide::extract::{Data, SocketRef, State, TryData}; +use serde_json::{json, Value}; +use socketioxide::extract::{Data, Extension, MaybeExtension, SocketRef, State, TryData}; +use socketioxide::handler::ConnectHandler; use tokio::sync::mpsc; use engineioxide::Packet as EioPacket; @@ -17,6 +19,11 @@ async fn timeout_rcv(srx: &mut tokio::sync::mpsc::Receiver(srx: &mut tokio::sync::mpsc::Receiver) { + tokio::time::timeout(Duration::from_millis(200), srx.recv()) + .await + .unwrap_err(); +} #[tokio::test] pub async fn state_extractor() { @@ -118,3 +125,97 @@ pub async fn try_data_extractor() { assert_ok!(stx.try_send(packet)); assert_err!(timeout_rcv(&mut rx).await); } + +#[tokio::test] +pub async fn extension_extractor() { + let (_, io) = SocketIo::new_svc(); + + fn on_test(s: SocketRef, Extension(i): Extension) { + s.emit("from_ev_test", i).unwrap(); + } + fn ns_root(s: SocketRef, Extension(i): Extension) { + s.emit("from_ns", i).unwrap(); + s.on("test", on_test); + } + fn set_ext(s: SocketRef) -> Result<(), Infallible> { + s.extensions.insert(123usize); + Ok(()) + } + + // Namespace without errors (the extension is set) + io.ns("/", ns_root.with(set_ext)); + // Namespace with errors (the extension is not set) + io.ns("/test", ns_root); + + // Extract extensions from the socket + let (tx, mut rx) = io.new_dummy_sock("/", ()).await; + assert!(matches!(timeout_rcv(&mut rx).await, EioPacket::Message(s) if s.starts_with("0"))); + assert_eq!( + timeout_rcv(&mut rx).await, + EioPacket::Message("2[\"from_ns\",123]".into()) + ); + let packet = Packet::event("/", "test", Value::Null).into(); + assert_ok!(tx.try_send(EioPacket::Message(packet))); + assert_eq!( + timeout_rcv(&mut rx).await, + EioPacket::Message("2[\"from_ev_test\",123]".into()) + ); + + // Extract unknown extensions from the socket + let (tx, mut rx) = io.new_dummy_sock("/test", ()).await; + assert!(matches!(timeout_rcv(&mut rx).await, EioPacket::Message(s) if s.starts_with("0"))); + timeout_rcv_err(&mut rx).await; + let packet = Packet::event("/test", "test", Value::Null).into(); + assert_ok!(tx.try_send(EioPacket::Message(packet))); + timeout_rcv_err(&mut rx).await; +} + +#[tokio::test] +pub async fn maybe_extension_extractor() { + let (_, io) = SocketIo::new_svc(); + + fn on_test(s: SocketRef, MaybeExtension(i): MaybeExtension) { + s.emit("from_ev_test", i).unwrap(); + } + fn ns_root(s: SocketRef, MaybeExtension(i): MaybeExtension) { + s.emit("from_ns", i).unwrap(); + s.on("test", on_test); + } + fn set_ext(s: SocketRef) -> Result<(), Infallible> { + s.extensions.insert(123usize); + Ok(()) + } + + // Namespace without errors (the extension is set) + io.ns("/", ns_root.with(set_ext)); + // Namespace with errors (the extension is not set) + io.ns("/test", ns_root); + + // Extract extensions from the socket + let (tx, mut rx) = io.new_dummy_sock("/", ()).await; + assert!(matches!(timeout_rcv(&mut rx).await, EioPacket::Message(s) if s.starts_with("0"))); + assert_eq!( + timeout_rcv(&mut rx).await, + EioPacket::Message("2[\"from_ns\",123]".into()) + ); + let packet = Packet::event("/", "test", Value::Null).into(); + assert_ok!(tx.try_send(EioPacket::Message(packet))); + assert_eq!( + timeout_rcv(&mut rx).await, + EioPacket::Message("2[\"from_ev_test\",123]".into()) + ); + + // Extract unknown extensions from the socket + let (tx, mut rx) = io.new_dummy_sock("/test", ()).await; + assert!(matches!(timeout_rcv(&mut rx).await, EioPacket::Message(s) if s.starts_with("0"))); + assert_eq!( + timeout_rcv(&mut rx).await, + EioPacket::Message("2/test,[\"from_ns\",null]".into()) + ); + let packet = Packet::event("/test", "test", Value::Null).into(); + assert_ok!(tx.try_send(EioPacket::Message(packet))); + assert_eq!( + timeout_rcv(&mut rx).await, + EioPacket::Message("2/test,[\"from_ev_test\",null]".into()) + ); +} From a744f7b05467ec3265055af6480c361ebed5e570 Mon Sep 17 00:00:00 2001 From: totodore Date: Sat, 20 Apr 2024 02:47:48 -0300 Subject: [PATCH 14/19] docs(example) fmt chat example --- examples/chat/src/main.rs | 54 ++++++++++++++++++++++----------------- 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/examples/chat/src/main.rs b/examples/chat/src/main.rs index ea006ed0..839b5f83 100644 --- a/examples/chat/src/main.rs +++ b/examples/chat/src/main.rs @@ -2,7 +2,7 @@ use std::sync::atomic::AtomicUsize; use serde::{Deserialize, Serialize}; use socketioxide::{ - extract::{Data, SocketRef, State, Extension, MaybeExtension}, + extract::{Data, Extension, MaybeExtension, SocketRef, State}, SocketIo, }; use tower::ServiceBuilder; @@ -59,13 +59,16 @@ async fn main() -> Result<(), Box> { let (layer, io) = SocketIo::builder().with_state(UserCnt::new()).build_layer(); io.ns("/", |s: SocketRef| { - s.on("new message", |s: SocketRef, Data::(msg), Extension::(username)| { - let msg = Res::Message { - username, - message: msg, - }; - s.broadcast().emit("new message", msg).ok(); - }); + s.on( + "new message", + |s: SocketRef, Data::(msg), Extension::(username)| { + let msg = Res::Message { + username, + message: msg, + }; + s.broadcast().emit("new message", msg).ok(); + }, + ); s.on( "add user", @@ -91,22 +94,27 @@ async fn main() -> Result<(), Box> { .ok(); }); - s.on("stop typing", |s: SocketRef, Extension::(username)| { - s.broadcast() - .emit("stop typing", Res::Username { username }) - .ok(); - }); + s.on( + "stop typing", + |s: SocketRef, Extension::(username)| { + s.broadcast() + .emit("stop typing", Res::Username { username }) + .ok(); + }, + ); - s.on_disconnect(|s: SocketRef, user_cnt: State, MaybeExtension::(username)| { - if let Some(username) = username { - let num_users = user_cnt.remove_user(); - let res = Res::UserEvent { - num_users, - username, - }; - s.broadcast().emit("user left", res).ok(); - } - }); + s.on_disconnect( + |s: SocketRef, user_cnt: State, MaybeExtension::(username)| { + if let Some(username) = username { + let num_users = user_cnt.remove_user(); + let res = Res::UserEvent { + num_users, + username, + }; + s.broadcast().emit("user left", res).ok(); + } + }, + ); }); let app = axum::Router::new() From 50372f28042372efd38762e44cf23cc0626ddbd3 Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 21 Apr 2024 18:17:47 -0300 Subject: [PATCH 15/19] test(socketio): fix extractors test --- socketioxide/tests/extractors.rs | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/socketioxide/tests/extractors.rs b/socketioxide/tests/extractors.rs index 88be86c4..292f1053 100644 --- a/socketioxide/tests/extractors.rs +++ b/socketioxide/tests/extractors.rs @@ -154,8 +154,7 @@ pub async fn extension_extractor() { timeout_rcv(&mut rx).await, EioPacket::Message("2[\"from_ns\",123]".into()) ); - let packet = Packet::event("/", "test", Value::Null).into(); - assert_ok!(tx.try_send(EioPacket::Message(packet))); + assert_ok!(tx.try_send(create_msg("/", "test", Value::Null))); assert_eq!( timeout_rcv(&mut rx).await, EioPacket::Message("2[\"from_ev_test\",123]".into()) @@ -165,8 +164,7 @@ pub async fn extension_extractor() { let (tx, mut rx) = io.new_dummy_sock("/test", ()).await; assert!(matches!(timeout_rcv(&mut rx).await, EioPacket::Message(s) if s.starts_with("0"))); timeout_rcv_err(&mut rx).await; - let packet = Packet::event("/test", "test", Value::Null).into(); - assert_ok!(tx.try_send(EioPacket::Message(packet))); + assert_ok!(tx.try_send(create_msg("/test", "test", Value::Null))); timeout_rcv_err(&mut rx).await; } @@ -198,8 +196,7 @@ pub async fn maybe_extension_extractor() { timeout_rcv(&mut rx).await, EioPacket::Message("2[\"from_ns\",123]".into()) ); - let packet = Packet::event("/", "test", Value::Null).into(); - assert_ok!(tx.try_send(EioPacket::Message(packet))); + assert_ok!(tx.try_send(create_msg("/", "test", Value::Null))); assert_eq!( timeout_rcv(&mut rx).await, EioPacket::Message("2[\"from_ev_test\",123]".into()) @@ -212,8 +209,7 @@ pub async fn maybe_extension_extractor() { timeout_rcv(&mut rx).await, EioPacket::Message("2/test,[\"from_ns\",null]".into()) ); - let packet = Packet::event("/test", "test", Value::Null).into(); - assert_ok!(tx.try_send(EioPacket::Message(packet))); + assert_ok!(tx.try_send(create_msg("/test", "test", Value::Null))); assert_eq!( timeout_rcv(&mut rx).await, EioPacket::Message("2/test,[\"from_ev_test\",null]".into()) From 76c72caa3bdd3cd6080eb55fb7df83d4d1df2754 Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 21 Apr 2024 18:33:13 -0300 Subject: [PATCH 16/19] doc(socketio): improve doc for socketioxide --- socketioxide/src/extensions.rs | 2 +- socketioxide/src/extract/data.rs | 2 +- socketioxide/src/extract/mod.rs | 26 ++++++++++++++++---------- socketioxide/src/handler/connect.rs | 8 ++++---- socketioxide/src/handler/disconnect.rs | 6 +++--- socketioxide/src/handler/message.rs | 8 ++++---- socketioxide/src/lib.rs | 2 +- 7 files changed, 30 insertions(+), 24 deletions(-) diff --git a/socketioxide/src/extensions.rs b/socketioxide/src/extensions.rs index a7a59259..b8d243e3 100644 --- a/socketioxide/src/extensions.rs +++ b/socketioxide/src/extensions.rs @@ -2,7 +2,7 @@ //! //! It is heavily inspired by the [`http::Extensions`] type from the `http` crate. //! -//! The main difference is that the inner [`HashMap`](std::collections::HashMap) is wrapped with an [`RwLock`] +//! The main difference is that the inner [`HashMap`] is wrapped with an [`RwLock`] //! to allow concurrent access. Moreover, any value extracted from the map is cloned before being returned. //! //! This is necessary because [`Extensions`] are shared between all the threads that handle the same socket. diff --git a/socketioxide/src/extract/data.rs b/socketioxide/src/extract/data.rs index 5ccb4b12..51ba5292 100644 --- a/socketioxide/src/extract/data.rs +++ b/socketioxide/src/extract/data.rs @@ -18,7 +18,7 @@ fn upwrap_array(v: &mut Value) { } /// An Extractor that returns the serialized auth data without checking errors. -/// If a deserialization error occurs, the [`ConnectHandler`](super::ConnectHandler) won't be called +/// If a deserialization error occurs, the [`ConnectHandler`](crate::handler::ConnectHandler) won't be called /// and an error log will be print if the `tracing` feature is enabled. pub struct Data(pub T); impl FromConnectParts for Data diff --git a/socketioxide/src/extract/mod.rs b/socketioxide/src/extract/mod.rs index ed6869d6..b7adcb3e 100644 --- a/socketioxide/src/extract/mod.rs +++ b/socketioxide/src/extract/mod.rs @@ -1,18 +1,16 @@ -//! ### Extractors for [`ConnectHandler`](super::ConnectHandler), [`ConnectMiddleware`](super::ConnectMiddleware), -//! [`MessageHandler`](super::MessageHandler) -//! and [`DisconnectHandler`](super::DisconnectHandler). +//! ### Extractors for [`ConnectHandler`], [`ConnectMiddleware`], [`MessageHandler`] and [`DisconnectHandler`](crate::handler::DisconnectHandler). //! //! They can be used to extract data from the context of the handler and get specific params. Here are some examples of extractors: //! * [`Data`]: extracts and deserialize to json any data, if a deserialization error occurs the handler won't be called: -//! - for [`ConnectHandler`](super::ConnectHandler): extracts and deserialize to json the auth data -//! - for [`ConnectMiddleware`](super::ConnectMiddleware): extract and deserialize to json the auth data. +//! - for [`ConnectHandler`]: extracts and deserialize to json the auth data +//! - for [`ConnectMiddleware`]: extract and deserialize to json the auth data. //! In case of error, the middleware chain stops and a `connect_error` event is sent. -//! - for [`MessageHandler`](super::MessageHandler): extracts and deserialize to json the message data +//! - for [`MessageHandler`]: extracts and deserialize to json the message data //! * [`TryData`]: extracts and deserialize to json any data but with a `Result` type in case of error: -//! - for [`ConnectHandler`](super::ConnectHandler) and [`ConnectMiddleware`](super::ConnectMiddleware): +//! - for [`ConnectHandler`] and [`ConnectMiddleware`]: //! extracts and deserialize to json the auth data -//! - for [`MessageHandler`](super::MessageHandler): extracts and deserialize to json the message data -//! * [`SocketRef`]: extracts a reference to the [`Socket`] +//! - for [`MessageHandler`]: extracts and deserialize to json the message data +//! * [`SocketRef`]: extracts a reference to the [`Socket`](crate::socket::Socket) //! * [`Bin`]: extract a binary payload for a given message. Because it consumes the event it should be the last argument //! * [`AckSender`]: Can be used to send an ack response to the current message event //! * [`ProtocolVersion`](crate::ProtocolVersion): extracts the protocol version @@ -26,9 +24,17 @@ //! * [`MaybeHttpExtension`]: extracts an http extension of the given type if it exists or `None` otherwise. //! //! ### You can also implement your own Extractor with the [`FromConnectParts`], [`FromMessageParts`] and [`FromDisconnectParts`] traits -//! When implementing these traits, if you clone the [`Arc`] make sure that it is dropped at least when the socket is disconnected. +//! When implementing these traits, if you clone the [`Arc`](crate::socket::Socket) make sure that it is dropped at least when the socket is disconnected. //! Otherwise it will create a memory leak. It is why the [`SocketRef`] extractor is used instead of cloning the socket for common usage. //! +//! [`FromConnectParts`]: crate::handler::FromConnectParts +//! [`FromMessageParts`]: crate::handler::FromMessageParts +//! [`FromDisconnectParts`]: crate::handler::FromDisconnectParts +//! [`ConnectHandler`]: crate::handler::ConnectHandler +//! [`ConnectMiddleware`]: crate::handler::ConnectMiddleware +//! [`MessageHandler`]: crate::handler::MessageHandler +//! [`DisconnectHandler`]: crate::handler::DisconnectHandler +//! //! #### Example that extracts a user id from the query params //! ```rust //! # use bytes::Bytes; diff --git a/socketioxide/src/handler/connect.rs b/socketioxide/src/handler/connect.rs index 07e34419..888a3b30 100644 --- a/socketioxide/src/handler/connect.rs +++ b/socketioxide/src/handler/connect.rs @@ -2,7 +2,7 @@ //! It has a flexible axum-like API, you can put any arguments as long as it implements the [`FromConnectParts`] trait. //! //! You can also implement the [`FromConnectParts`] trait for your own types. -//! See the [`extract`](super::extract) module doc for more details on available extractors. +//! See the [`extract`](crate::extract) module doc for more details on available extractors. //! //! Handlers can be _optionally_ async. //! @@ -141,7 +141,7 @@ pub(crate) trait ErasedConnectHandler: Send + Sync + 'static { /// in this case the [`ConnectHandler`] is not called. /// /// * See the [`connect`](super::connect) module doc for more details on connect handler. -/// * See the [`extract`](super::extract) module doc for more details on available extractors. +/// * See the [`extract`](crate::extract) module doc for more details on available extractors. pub trait FromConnectParts: Sized { /// The error type returned by the extractor type Error: std::error::Error + Send + 'static; @@ -156,7 +156,7 @@ pub trait FromConnectParts: Sized { /// They must implement the [`FromConnectParts`] trait and return `Result<(), E> where E: Display`. /// /// * See the [`connect`](super::connect) module doc for more details on connect middlewares. -/// * See the [`extract`](super::extract) module doc for more details on available extractors. +/// * See the [`extract`](crate::extract) module doc for more details on available extractors. pub trait ConnectMiddleware: Send + Sync + 'static { /// Call the middleware with the given arguments. fn call<'a>( @@ -175,7 +175,7 @@ pub trait ConnectMiddleware: Send + Sync + 'static { /// It is implemented for closures with up to 16 arguments. They must implement the [`FromConnectParts`] trait. /// /// * See the [`connect`](super::connect) module doc for more details on connect handler. -/// * See the [`extract`](super::extract) module doc for more details on available extractors. +/// * See the [`extract`](crate::extract) module doc for more details on available extractors. pub trait ConnectHandler: Send + Sync + 'static { /// Call the handler with the given arguments. fn call(&self, s: Arc>, auth: Option); diff --git a/socketioxide/src/handler/disconnect.rs b/socketioxide/src/handler/disconnect.rs index b955f736..b63adb15 100644 --- a/socketioxide/src/handler/disconnect.rs +++ b/socketioxide/src/handler/disconnect.rs @@ -2,7 +2,7 @@ //! It has a flexible axum-like API, you can put any arguments as long as it implements the [`FromDisconnectParts`] trait. //! //! You can also implement the [`FromDisconnectParts`] trait for your own types. -//! See the [`extract`](super::extract) module doc for more details on available extractors. +//! See the [`extract`](crate::extract) module doc for more details on available extractors. //! //! Handlers can be _optionally_ async. //! @@ -97,7 +97,7 @@ where /// in this case the [`DisconnectHandler`] is not called. /// /// * See the [`disconnect`](super::disconnect) module doc for more details on disconnect handler. -/// * See the [`extract`](super::extract) module doc for more details on available extractors. +/// * See the [`extract`](crate::extract) module doc for more details on available extractors. pub trait FromDisconnectParts: Sized { /// The error type returned by the extractor type Error: std::error::Error + 'static; @@ -114,7 +114,7 @@ pub trait FromDisconnectParts: Sized { /// It is implemented for closures with up to 16 arguments. They must implement the [`FromDisconnectParts`] trait. /// /// * See the [`disconnect`](super::disconnect) module doc for more details on disconnect handler. -/// * See the [`extract`](super::extract) module doc for more details on available extractors. +/// * See the [`extract`](crate::extract) module doc for more details on available extractors. pub trait DisconnectHandler: Send + Sync + 'static { /// Call the handler with the given arguments. fn call(&self, s: Arc>, reason: DisconnectReason); diff --git a/socketioxide/src/handler/message.rs b/socketioxide/src/handler/message.rs index e754f576..5b4ab255 100644 --- a/socketioxide/src/handler/message.rs +++ b/socketioxide/src/handler/message.rs @@ -5,7 +5,7 @@ //! All the types that implement [`FromMessageParts`] also implement [`FromMessage`]. //! //! You can also implement the [`FromMessageParts`] and [`FromMessage`] traits for your own types. -//! See the [`extract`](super::extract) module doc for more details on available extractors. +//! See the [`extract`](crate::extract) module doc for more details on available extractors. //! //! Handlers can be _optionally_ async. //! @@ -93,7 +93,7 @@ pub(crate) trait ErasedMessageHandler: Send + Sync + 'static { /// It is implemented for closures with up to 16 arguments. They must implement the [`FromMessageParts`] trait or the [`FromMessage`] trait for the last one. /// /// * See the [`message`](super::message) module doc for more details on message handler. -/// * See the [`extract`](super::extract) module doc for more details on available extractors. +/// * See the [`extract`](crate::extract) module doc for more details on available extractors. #[cfg_attr( nightly_error_messages, diagnostic::on_unimplemented( @@ -150,7 +150,7 @@ mod private { /// The `Result` associated type is used to return an error if the extraction fails, in this case the handler is not called. /// /// * See the [`message`](super::message) module doc for more details on message handler. -/// * See the [`extract`](super::extract) module doc for more details on available extractors. +/// * See the [`extract`](crate::extract) module doc for more details on available extractors. #[cfg_attr( nightly_error_messages, diagnostic::on_unimplemented( @@ -175,7 +175,7 @@ pub trait FromMessageParts: Sized { /// The `Result` associated type is used to return an error if the extraction fails, in this case the handler is not called. /// /// * See the [`message`](super::message) module doc for more details on message handler. -/// * See the [`extract`](super::extract) module doc for more details on available extractors. +/// * See the [`extract`](crate::extract) module doc for more details on available extractors. #[cfg_attr( nightly_error_messages, diagnostic::on_unimplemented( diff --git a/socketioxide/src/lib.rs b/socketioxide/src/lib.rs index 5f34544d..74721cb8 100644 --- a/socketioxide/src/lib.rs +++ b/socketioxide/src/lib.rs @@ -137,7 +137,7 @@ //! * Check the [`handler::connect`] module doc for more details on the connect handler and connect middlewares. //! * Check the [`handler::message`] module doc for more details on the message handler. //! * Check the [`handler::disconnect`] module doc for more details on the disconnect handler. -//! * Check the [`handler::extract`] module doc for more details on the extractors. +//! * Check the [`extract`] module doc for more details on the extractors. //! //! ## Extractors //! Handlers params are called extractors and are used to extract data from the incoming connection/message. They are inspired by the axum extractors. From df94f5ceca1e636642dac3c93824d2b3d5cb6e82 Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 21 Apr 2024 18:37:21 -0300 Subject: [PATCH 17/19] test(socketio): increase timeout --- socketioxide/tests/extractors.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/socketioxide/tests/extractors.rs b/socketioxide/tests/extractors.rs index 292f1053..831771a3 100644 --- a/socketioxide/tests/extractors.rs +++ b/socketioxide/tests/extractors.rs @@ -14,13 +14,13 @@ mod fixture; mod utils; async fn timeout_rcv(srx: &mut tokio::sync::mpsc::Receiver) -> T { - tokio::time::timeout(Duration::from_millis(200), srx.recv()) + tokio::time::timeout(Duration::from_millis(500), srx.recv()) .await .unwrap() .unwrap() } async fn timeout_rcv_err(srx: &mut tokio::sync::mpsc::Receiver) { - tokio::time::timeout(Duration::from_millis(200), srx.recv()) + tokio::time::timeout(Duration::from_millis(500), srx.recv()) .await .unwrap_err(); } From d5c01f7c9caa6cbd9bce4cb9d3aec474bf473a68 Mon Sep 17 00:00:00 2001 From: totodore Date: Tue, 21 May 2024 15:37:37 +0200 Subject: [PATCH 18/19] doc(socketio): improve doc --- socketioxide/src/extensions.rs | 3 +++ socketioxide/src/lib.rs | 1 + 2 files changed, 4 insertions(+) diff --git a/socketioxide/src/extensions.rs b/socketioxide/src/extensions.rs index b8d243e3..126030cf 100644 --- a/socketioxide/src/extensions.rs +++ b/socketioxide/src/extensions.rs @@ -6,6 +6,9 @@ //! to allow concurrent access. Moreover, any value extracted from the map is cloned before being returned. //! //! This is necessary because [`Extensions`] are shared between all the threads that handle the same socket. +//! +//! You can use the [`Extension`](crate::extract::Extension) or +//! [`MaybeExtension`](crate::extract::MaybeExtension) extractor to extract an extension of the given type. use std::collections::HashMap; use std::fmt; diff --git a/socketioxide/src/lib.rs b/socketioxide/src/lib.rs index 74721cb8..868c33e4 100644 --- a/socketioxide/src/lib.rs +++ b/socketioxide/src/lib.rs @@ -188,6 +188,7 @@ //! Because the socket is not yet connected to the namespace, //! you can't send messages to it from the middleware. //! +//! //! See the [`handler::connect`](handler::connect#middleware) module doc for more details on middlewares and examples. //! //! ## [Emiting data](#emiting-data) From afed9fbbe103d78769d99ef82bd04b0b0d3cbaf7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9odore=20Pr=C3=A9vot?= Date: Fri, 24 May 2024 15:38:53 +0000 Subject: [PATCH 19/19] feat(adapter): remove adapter generic param and use boxed --- socketioxide/src/ack.rs | 12 +- socketioxide/src/adapter.rs | 220 ++++++++++++++----------- socketioxide/src/client.rs | 36 ++-- socketioxide/src/errors.rs | 1 - socketioxide/src/extract/data.rs | 30 ++-- socketioxide/src/extract/extensions.rs | 68 ++++---- socketioxide/src/extract/mod.rs | 8 +- socketioxide/src/extract/socket.rs | 78 ++++----- socketioxide/src/extract/state.rs | 18 +- socketioxide/src/handler/connect.rs | 168 +++++++++++-------- socketioxide/src/handler/disconnect.rs | 45 +++-- socketioxide/src/handler/message.rs | 68 ++++---- socketioxide/src/handler/mod.rs | 6 +- socketioxide/src/io.rs | 75 +++++---- socketioxide/src/layer.rs | 21 +-- socketioxide/src/ns.rs | 59 ++++--- socketioxide/src/operators.rs | 66 ++++---- socketioxide/src/service.rs | 41 ++--- socketioxide/src/socket.rs | 58 ++++--- socketioxide/tests/fixture.rs | 2 +- 20 files changed, 572 insertions(+), 508 deletions(-) diff --git a/socketioxide/src/ack.rs b/socketioxide/src/ack.rs index 699f8f74..171511cb 100644 --- a/socketioxide/src/ack.rs +++ b/socketioxide/src/ack.rs @@ -19,7 +19,7 @@ use serde::de::DeserializeOwned; use serde_json::Value; use tokio::{sync::oneshot::Receiver, time::Timeout}; -use crate::{adapter::Adapter, errors::AckError, extract::SocketRef, packet::Packet, SocketError}; +use crate::{errors::AckError, extract::SocketRef, packet::Packet, SocketError}; /// An acknowledgement sent by the client. /// It contains the data sent by the client and the binary payloads if there are any. @@ -145,9 +145,9 @@ impl AckInnerStream { /// /// The [`AckInnerStream`] will wait for the default timeout specified in the config /// (5s by default) if no custom timeout is specified. - pub fn broadcast( + pub fn broadcast( packet: Packet<'static>, - sockets: Vec>, + sockets: Vec, duration: Option, ) -> Self { let rxs = FuturesUnordered::new(); @@ -311,13 +311,13 @@ mod test { use engineioxide::sid::Sid; use futures_util::StreamExt; - use crate::{adapter::LocalAdapter, ns::Namespace, socket::Socket}; + use crate::{ns::Namespace, socket::Socket}; use super::*; - fn create_socket() -> Arc> { + fn create_socket() -> Arc { let sid = Sid::new(); - let ns = Namespace::::new_dummy([sid]).into(); + let ns = Namespace::new_dummy([sid]).into(); let socket = Socket::new_dummy(sid, ns); socket.into() } diff --git a/socketioxide/src/adapter.rs b/socketioxide/src/adapter.rs index 0e875b19..5a998892 100644 --- a/socketioxide/src/adapter.rs +++ b/socketioxide/src/adapter.rs @@ -6,7 +6,6 @@ use std::{ borrow::Cow, collections::{HashMap, HashSet}, - convert::Infallible, sync::{RwLock, Weak}, time::Duration, }; @@ -48,32 +47,28 @@ pub struct BroadcastOptions { pub sid: Option, } //TODO: Make an AsyncAdapter trait -/// An adapter is responsible for managing the state of the server. +/// An adapter is responsible for managing the state of the server. There is one adapter per namespace. /// This adapter can be implemented to share the state between multiple servers. /// The default adapter is the [`LocalAdapter`], which stores the state in memory. pub trait Adapter: std::fmt::Debug + Send + Sync + 'static { - /// An error that can occur when using the adapter. The default [`LocalAdapter`] has an [`Infallible`] error. - type Error: std::error::Error + Into + Send + Sync + 'static; - - /// Create a new adapter and give the namespace ref to retrieve sockets. - fn new(ns: Weak>) -> Self - where - Self: Sized; + /// Returns a boxed clone of the adapter. + /// It is used to create a new empty instance of the adapter for a new namespace. + fn boxed_clone(&self) -> Box; /// Initializes the adapter. - fn init(&self) -> Result<(), Self::Error>; + fn init(&mut self, ns: Weak) -> Result<(), AdapterError>; /// Closes the adapter. - fn close(&self) -> Result<(), Self::Error>; + fn close(&self) -> Result<(), AdapterError>; /// Returns the number of servers. - fn server_count(&self) -> Result; + fn server_count(&self) -> Result; /// Adds the socket to all the rooms. - fn add_all(&self, sid: Sid, rooms: impl RoomParam) -> Result<(), Self::Error>; + fn add_all(&self, sid: Sid, rooms: Vec) -> Result<(), AdapterError>; /// Removes the socket from the rooms. - fn del(&self, sid: Sid, rooms: impl RoomParam) -> Result<(), Self::Error>; + fn del(&self, sid: Sid, rooms: Vec) -> Result<(), AdapterError>; /// Removes the socket from all the rooms. - fn del_all(&self, sid: Sid) -> Result<(), Self::Error>; + fn del_all(&self, sid: Sid) -> Result<(), AdapterError>; /// Broadcasts the packet to the sockets that match the [`BroadcastOptions`]. fn broadcast(&self, packet: Packet<'_>, opts: BroadcastOptions) -> Result<(), BroadcastError>; @@ -87,28 +82,24 @@ pub trait Adapter: std::fmt::Debug + Send + Sync + 'static { ) -> AckInnerStream; /// Returns the sockets ids that match the [`BroadcastOptions`]. - fn sockets(&self, rooms: impl RoomParam) -> Result, Self::Error>; + fn sockets(&self, rooms: Vec) -> Result, AdapterError>; /// Returns the rooms of the socket. - fn socket_rooms(&self, sid: Sid) -> Result, Self::Error>; + fn socket_rooms(&self, sid: Sid) -> Result, AdapterError>; /// Returns the sockets that match the [`BroadcastOptions`]. - fn fetch_sockets(&self, opts: BroadcastOptions) -> Result>, Self::Error> - where - Self: Sized; + fn fetch_sockets(&self, opts: BroadcastOptions) -> Result, AdapterError>; /// Adds the sockets that match the [`BroadcastOptions`] to the rooms. - fn add_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) - -> Result<(), Self::Error>; + fn add_sockets(&self, opts: BroadcastOptions, rooms: Vec) -> Result<(), AdapterError>; /// Removes the sockets that match the [`BroadcastOptions`] from the rooms. - fn del_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) - -> Result<(), Self::Error>; + fn del_sockets(&self, opts: BroadcastOptions, rooms: Vec) -> Result<(), AdapterError>; /// Disconnects the sockets that match the [`BroadcastOptions`]. fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), Vec>; /// Returns all the rooms for this adapter. - fn rooms(&self) -> Result, Self::Error>; + fn rooms(&self) -> Result, AdapterError>; //TODO: implement // fn server_side_emit(&self, packet: Packet, opts: BroadcastOptions) -> Result; @@ -117,33 +108,23 @@ pub trait Adapter: std::fmt::Debug + Send + Sync + 'static { } /// The default adapter. Store the state in memory. -#[derive(Debug)] +#[derive(Debug, Default)] pub struct LocalAdapter { rooms: RwLock>>, - ns: Weak>, -} - -impl From for AdapterError { - fn from(_: Infallible) -> AdapterError { - unreachable!() - } + ns: Weak, } impl Adapter for LocalAdapter { - type Error = Infallible; - - fn new(ns: Weak>) -> Self { - Self { - rooms: HashMap::new().into(), - ns, - } + fn boxed_clone(&self) -> Box { + Box::new(Self::new()) } - fn init(&self) -> Result<(), Infallible> { + fn init(&mut self, ns: Weak) -> Result<(), AdapterError> { + self.ns = ns; Ok(()) } - fn close(&self) -> Result<(), Infallible> { + fn close(&self) -> Result<(), AdapterError> { #[cfg(feature = "tracing")] tracing::debug!("closing local adapter: {}", self.ns.upgrade().unwrap().path); let mut rooms = self.rooms.write().unwrap(); @@ -152,21 +133,21 @@ impl Adapter for LocalAdapter { Ok(()) } - fn server_count(&self) -> Result { + fn server_count(&self) -> Result { Ok(1) } - fn add_all(&self, sid: Sid, rooms: impl RoomParam) -> Result<(), Infallible> { + fn add_all(&self, sid: Sid, rooms: Vec) -> Result<(), AdapterError> { let mut rooms_map = self.rooms.write().unwrap(); - for room in rooms.into_room_iter() { + for room in rooms { rooms_map.entry(room).or_default().insert(sid); } Ok(()) } - fn del(&self, sid: Sid, rooms: impl RoomParam) -> Result<(), Infallible> { + fn del(&self, sid: Sid, rooms: Vec) -> Result<(), AdapterError> { let mut rooms_map = self.rooms.write().unwrap(); - for room in rooms.into_room_iter() { + for room in rooms { if let Some(room) = rooms_map.get_mut(&room) { room.remove(&sid); } @@ -174,7 +155,7 @@ impl Adapter for LocalAdapter { Ok(()) } - fn del_all(&self, sid: Sid) -> Result<(), Infallible> { + fn del_all(&self, sid: Sid) -> Result<(), AdapterError> { let mut rooms_map = self.rooms.write().unwrap(); for room in rooms_map.values_mut() { room.remove(&sid); @@ -214,7 +195,7 @@ impl Adapter for LocalAdapter { AckInnerStream::broadcast(packet, sockets, timeout) } - fn sockets(&self, rooms: impl RoomParam) -> Result, Infallible> { + fn sockets(&self, rooms: Vec) -> Result, AdapterError> { let mut opts = BroadcastOptions::default(); opts.rooms.extend(rooms.into_room_iter()); Ok(self @@ -225,7 +206,7 @@ impl Adapter for LocalAdapter { } //TODO: make this operation O(1) - fn socket_rooms(&self, sid: Sid) -> Result>, Infallible> { + fn socket_rooms(&self, sid: Sid) -> Result>, AdapterError> { let rooms_map = self.rooms.read().unwrap(); Ok(rooms_map .iter() @@ -234,11 +215,11 @@ impl Adapter for LocalAdapter { .collect()) } - fn fetch_sockets(&self, opts: BroadcastOptions) -> Result>, Infallible> { + fn fetch_sockets(&self, opts: BroadcastOptions) -> Result, AdapterError> { Ok(self.apply_opts(opts)) } - fn add_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) -> Result<(), Infallible> { + fn add_sockets(&self, opts: BroadcastOptions, rooms: Vec) -> Result<(), AdapterError> { let rooms: Vec = rooms.into_room_iter().collect(); for socket in self.apply_opts(opts) { self.add_all(socket.id, rooms.clone()).unwrap(); @@ -246,7 +227,7 @@ impl Adapter for LocalAdapter { Ok(()) } - fn del_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) -> Result<(), Infallible> { + fn del_sockets(&self, opts: BroadcastOptions, rooms: Vec) -> Result<(), AdapterError> { let rooms: Vec = rooms.into_room_iter().collect(); for socket in self.apply_opts(opts) { self.del(socket.id, rooms.clone()).unwrap(); @@ -270,14 +251,18 @@ impl Adapter for LocalAdapter { } } - fn rooms(&self) -> Result, Self::Error> { + fn rooms(&self) -> Result, AdapterError> { Ok(self.rooms.read().unwrap().keys().cloned().collect()) } } impl LocalAdapter { + /// Creates a new [LocalAdapter]. + pub fn new() -> Self { + Self::default() + } /// Applies the given `opts` and return the sockets that match. - fn apply_opts(&self, opts: BroadcastOptions) -> Vec> { + fn apply_opts(&self, opts: BroadcastOptions) -> Vec { let rooms = opts.rooms; let except = self.get_except_sids(&opts.except); @@ -335,19 +320,26 @@ mod test { }; } + fn local_adapter(ns: &Arc) -> LocalAdapter { + let mut adapter = LocalAdapter::new(); + adapter.init(Arc::downgrade(ns)).unwrap(); + adapter + } #[tokio::test] async fn test_server_count() { - let ns = Namespace::new_dummy([]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); + let ns = Namespace::new_dummy([]).into(); + let adapter = local_adapter(&ns); assert_eq!(adapter.server_count().unwrap(), 1); } #[tokio::test] async fn test_add_all() { let socket = Sid::new(); - let ns = Namespace::new_dummy([socket]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket, ["room1", "room2"]).unwrap(); + let ns = Namespace::new_dummy([socket]).into(); + let adapter = local_adapter(&ns); + adapter + .add_all(socket, vec!["room1".into(), "room2".into()]) + .unwrap(); let rooms_map = adapter.rooms.read().unwrap(); assert_eq!(rooms_map.len(), 2); assert_eq!(rooms_map.get("room1").unwrap().len(), 1); @@ -357,10 +349,12 @@ mod test { #[tokio::test] async fn test_del() { let socket = Sid::new(); - let ns = Namespace::new_dummy([socket]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket, ["room1", "room2"]).unwrap(); - adapter.del(socket, "room1").unwrap(); + let ns = Namespace::new_dummy([socket]).into(); + let adapter = local_adapter(&ns); + adapter + .add_all(socket, vec!["room1".into(), "room2".into()]) + .unwrap(); + adapter.del(socket, vec!["room1".into()]).unwrap(); let rooms_map = adapter.rooms.read().unwrap(); assert_eq!(rooms_map.len(), 2); assert_eq!(rooms_map.get("room1").unwrap().len(), 0); @@ -370,9 +364,11 @@ mod test { #[tokio::test] async fn test_del_all() { let socket = Sid::new(); - let ns = Namespace::new_dummy([socket]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket, ["room1", "room2"]).unwrap(); + let ns = Namespace::new_dummy([socket]).into(); + let adapter = local_adapter(&ns); + adapter + .add_all(socket, vec!["room1".into(), "room2".into()]) + .unwrap(); adapter.del_all(socket).unwrap(); let rooms_map = adapter.rooms.read().unwrap(); assert_eq!(rooms_map.len(), 2); @@ -385,11 +381,13 @@ mod test { let sid1 = Sid::new(); let sid2 = Sid::new(); let sid3 = Sid::new(); - let ns = Namespace::new_dummy([sid1, sid2, sid3]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(sid1, ["room1", "room2"]).unwrap(); - adapter.add_all(sid2, ["room1"]).unwrap(); - adapter.add_all(sid3, ["room2"]).unwrap(); + let ns = Namespace::new_dummy([sid1, sid2, sid3]).into(); + let adapter = local_adapter(&ns); + adapter + .add_all(sid1, vec!["room1".into(), "room2".into()]) + .unwrap(); + adapter.add_all(sid2, vec!["room1".into()]).unwrap(); + adapter.add_all(sid3, vec!["room2".into()]).unwrap(); assert!(adapter .socket_rooms(sid1) .unwrap() @@ -405,16 +403,16 @@ mod test { #[tokio::test] async fn test_add_socket() { let socket = Sid::new(); - let ns = Namespace::new_dummy([socket]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket, ["room1"]).unwrap(); + let ns = Namespace::new_dummy([socket]).into(); + let adapter = local_adapter(&ns); + adapter.add_all(socket, vec!["room1".into()]).unwrap(); let mut opts = BroadcastOptions { sid: Some(socket), ..Default::default() }; opts.rooms = hash_set!["room1".into()]; - adapter.add_sockets(opts, "room2").unwrap(); + adapter.add_sockets(opts, vec!["room2".into()]).unwrap(); let rooms_map = adapter.rooms.read().unwrap(); assert_eq!(rooms_map.len(), 2); @@ -425,16 +423,16 @@ mod test { #[tokio::test] async fn test_del_socket() { let socket = Sid::new(); - let ns = Namespace::new_dummy([socket]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket, ["room1"]).unwrap(); + let ns = Namespace::new_dummy([socket]).into(); + let adapter = local_adapter(&ns); + adapter.add_all(socket, vec!["room1".into()]).unwrap(); let mut opts = BroadcastOptions { sid: Some(socket), ..Default::default() }; opts.rooms = hash_set!["room1".into()]; - adapter.add_sockets(opts, "room2").unwrap(); + adapter.add_sockets(opts, vec!["room2".into()]).unwrap(); { let rooms_map = adapter.rooms.read().unwrap(); @@ -449,7 +447,7 @@ mod test { ..Default::default() }; opts.rooms = hash_set!["room1".into()]; - adapter.del_sockets(opts, "room2").unwrap(); + adapter.del_sockets(opts, vec!["room2".into()]).unwrap(); { let rooms_map = adapter.rooms.read().unwrap(); @@ -465,23 +463,29 @@ mod test { let socket0 = Sid::new(); let socket1 = Sid::new(); let socket2 = Sid::new(); - let ns = Namespace::new_dummy([socket0, socket1, socket2]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket0, ["room1", "room2"]).unwrap(); - adapter.add_all(socket1, ["room1", "room3"]).unwrap(); - adapter.add_all(socket2, ["room2", "room3"]).unwrap(); + let ns = Namespace::new_dummy([socket0, socket1, socket2]).into(); + let adapter = local_adapter(&ns); + adapter + .add_all(socket0, vec!["room1".into(), "room2".into()]) + .unwrap(); + adapter + .add_all(socket1, vec!["room1".into(), "room3".into()]) + .unwrap(); + adapter + .add_all(socket2, vec!["room2".into(), "room3".into()]) + .unwrap(); - let sockets = adapter.sockets("room1").unwrap(); + let sockets = adapter.sockets(vec!["room1".into()]).unwrap(); assert_eq!(sockets.len(), 2); assert!(sockets.contains(&socket0)); assert!(sockets.contains(&socket1)); - let sockets = adapter.sockets("room2").unwrap(); + let sockets = adapter.sockets(vec!["room2".into()]).unwrap(); assert_eq!(sockets.len(), 2); assert!(sockets.contains(&socket0)); assert!(sockets.contains(&socket2)); - let sockets = adapter.sockets("room3").unwrap(); + let sockets = adapter.sockets(vec!["room3".into()]).unwrap(); assert_eq!(sockets.len(), 2); assert!(sockets.contains(&socket1)); assert!(sockets.contains(&socket2)); @@ -492,16 +496,25 @@ mod test { let socket0 = Sid::new(); let socket1 = Sid::new(); let socket2 = Sid::new(); - let ns = Namespace::new_dummy([socket0, socket1, socket2]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); + let ns = Namespace::new_dummy([socket0, socket1, socket2]).into(); + let adapter = local_adapter(&ns); adapter - .add_all(socket0, ["room1", "room2", "room4"]) + .add_all( + socket0, + vec!["room1".into(), "room2".into(), "room4".into()], + ) .unwrap(); adapter - .add_all(socket1, ["room1", "room3", "room5"]) + .add_all( + socket1, + vec!["room1".into(), "room3".into(), "room5".into()], + ) .unwrap(); adapter - .add_all(socket2, ["room2", "room3", "room6"]) + .add_all( + socket2, + vec!["room2".into(), "room3".into(), "room6".into()], + ) .unwrap(); let mut opts = BroadcastOptions { @@ -511,7 +524,7 @@ mod test { opts.rooms = hash_set!["room5".into()]; adapter.disconnect_socket(opts).unwrap(); - let sockets = adapter.sockets("room2").unwrap(); + let sockets = adapter.sockets(vec!["room2".into()]).unwrap(); assert_eq!(sockets.len(), 2); assert!(sockets.contains(&socket2)); assert!(sockets.contains(&socket0)); @@ -521,15 +534,22 @@ mod test { let socket0 = Sid::new(); let socket1 = Sid::new(); let socket2 = Sid::new(); - let ns = Namespace::new_dummy([socket0, socket1, socket2]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); + let ns = Namespace::new_dummy([socket0, socket1, socket2]).into(); + let adapter = local_adapter(&ns); // Add socket 0 to room1 and room2 - adapter.add_all(socket0, ["room1", "room2"]).unwrap(); + adapter + .add_all(socket0, vec!["room1".into(), "room2".into()]) + .unwrap(); // Add socket 1 to room1 and room3 - adapter.add_all(socket1, ["room1", "room3"]).unwrap(); + adapter + .add_all(socket1, vec!["room1".into(), "room3".into()]) + .unwrap(); // Add socket 2 to room2 and room3 adapter - .add_all(socket2, ["room1", "room2", "room3"]) + .add_all( + socket2, + vec!["room1".into(), "room2".into(), "room3".into()], + ) .unwrap(); // socket 2 is the sender diff --git a/socketioxide/src/client.rs b/socketioxide/src/client.rs index ef60cf7e..62db0076 100644 --- a/socketioxide/src/client.rs +++ b/socketioxide/src/client.rs @@ -11,7 +11,6 @@ use futures_util::{FutureExt, TryFutureExt}; use engineioxide::sid::Sid; use tokio::sync::oneshot; -use crate::adapter::Adapter; use crate::handler::ConnectHandler; use crate::socket::DisconnectReason; use crate::ProtocolVersion; @@ -22,20 +21,21 @@ use crate::{ SocketIoConfig, }; -#[derive(Debug)] -pub struct Client { +pub struct Client { pub(crate) config: Arc, - ns: RwLock, Arc>>>, + ns: RwLock, Arc>>, + state: Arc, } -impl Client { +impl Client { pub fn new(config: Arc) -> Self { - #[cfg(feature = "state")] - crate::state::freeze_state(); + // #[cfg(feature = "state")] + // crate::state::freeze_state(); Self { config, ns: RwLock::new(HashMap::new()), + state: Arc::new(state::TypeMap::new()), } } @@ -113,12 +113,12 @@ impl Client { /// Adds a new namespace handler pub fn add_ns(&self, path: Cow<'static, str>, callback: C) where - C: ConnectHandler, + C: ConnectHandler, T: Send + Sync + 'static, { #[cfg(feature = "tracing")] tracing::debug!("adding namespace {}", path); - let ns = Namespace::new(path.clone(), callback); + let ns = Namespace::new(path.clone(), callback, self.config.adapter.boxed_clone()); self.ns.write().unwrap().insert(path, ns); } @@ -137,7 +137,7 @@ impl Client { } } - pub fn get_ns(&self, path: &str) -> Option>> { + pub fn get_ns(&self, path: &str) -> Option> { self.ns.read().unwrap().get(path).cloned() } @@ -215,7 +215,7 @@ pub struct SocketData { pub connect_recv_tx: Mutex>>, } -impl EngineIoHandler for Client { +impl EngineIoHandler for Client { type Data = SocketData; #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, socket), fields(sid = socket.id.to_string())))] @@ -333,6 +333,16 @@ impl EngineIoHandler for Client { } } +impl std::fmt::Debug for Client { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Client") + .field("config", &self.config) + .field("ns", &self.ns) + .field("state", &self.state) + .finish() + } +} + /// Utility that applies an incoming binary payload to a partial binary packet /// waiting to be filled with all the payloads /// @@ -363,12 +373,12 @@ mod test { use crate::adapter::LocalAdapter; const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(10); - fn create_client() -> super::Client { + fn create_client() -> super::Client { let config = crate::SocketIoConfig { connect_timeout: CONNECT_TIMEOUT, ..Default::default() }; - let client = Client::::new(std::sync::Arc::new(config)); + let client = Client::new(std::sync::Arc::new(config)); client.add_ns("/".into(), || {}); client } diff --git a/socketioxide/src/errors.rs b/socketioxide/src/errors.rs index 63e446f8..ead60030 100644 --- a/socketioxide/src/errors.rs +++ b/socketioxide/src/errors.rs @@ -1,7 +1,6 @@ use engineioxide::{sid::Sid, socket::DisconnectReason as EIoDisconnectReason}; use std::fmt::{Debug, Display}; use tokio::{sync::mpsc::error::TrySendError, time::error::Elapsed}; - /// Error type for socketio #[derive(thiserror::Error, Debug)] pub enum Error { diff --git a/socketioxide/src/extract/data.rs b/socketioxide/src/extract/data.rs index 51ba5292..d47d7bc5 100644 --- a/socketioxide/src/extract/data.rs +++ b/socketioxide/src/extract/data.rs @@ -2,7 +2,7 @@ use std::convert::Infallible; use std::sync::Arc; use crate::handler::{FromConnectParts, FromMessage, FromMessageParts}; -use crate::{adapter::Adapter, socket::Socket}; +use crate::socket::Socket; use bytes::Bytes; use serde::de::DeserializeOwned; use serde_json::Value; @@ -21,27 +21,29 @@ fn upwrap_array(v: &mut Value) { /// If a deserialization error occurs, the [`ConnectHandler`](crate::handler::ConnectHandler) won't be called /// and an error log will be print if the `tracing` feature is enabled. pub struct Data(pub T); -impl FromConnectParts for Data +impl FromConnectParts for Data where T: DeserializeOwned, - A: Adapter, { type Error = serde_json::Error; - fn from_connect_parts(_: &Arc>, auth: &Option) -> Result { + fn from_connect_parts( + _: &Arc, + auth: &Option, + _: &Arc, + ) -> Result { auth.as_ref() .map(|a| serde_json::from_str::(a)) .unwrap_or(serde_json::from_str::("{}")) .map(Data) } } -impl FromMessageParts for Data +impl FromMessageParts for Data where T: DeserializeOwned, - A: Adapter, { type Error = serde_json::Error; fn from_message_parts( - _: &Arc>, + _: &Arc, v: &mut serde_json::Value, _: &mut Vec, _: &Option, @@ -54,13 +56,12 @@ where /// An Extractor that returns the deserialized data related to the event. pub struct TryData(pub Result); -impl FromConnectParts for TryData +impl FromConnectParts for TryData where T: DeserializeOwned, - A: Adapter, { type Error = Infallible; - fn from_connect_parts(_: &Arc>, auth: &Option) -> Result { + fn from_connect_parts(_: &Arc, auth: &Option) -> Result { let v: Result = auth .as_ref() .map(|a| serde_json::from_str(a)) @@ -68,14 +69,13 @@ where Ok(TryData(v)) } } -impl FromMessageParts for TryData +impl FromMessageParts for TryData where T: DeserializeOwned, - A: Adapter, { type Error = Infallible; fn from_message_parts( - _: &Arc>, + _: &Arc, v: &mut serde_json::Value, _: &mut Vec, _: &Option, @@ -88,10 +88,10 @@ where /// An Extractor that returns the binary data of the message. /// If there is no binary data, it will contain an empty vec. pub struct Bin(pub Vec); -impl FromMessage for Bin { +impl FromMessage for Bin { type Error = Infallible; fn from_message( - _: Arc>, + _: Arc, _: serde_json::Value, bin: Vec, _: Option, diff --git a/socketioxide/src/extract/extensions.rs b/socketioxide/src/extract/extensions.rs index 9a898503..5f9009e6 100644 --- a/socketioxide/src/extract/extensions.rs +++ b/socketioxide/src/extract/extensions.rs @@ -1,7 +1,6 @@ use std::convert::Infallible; use std::sync::Arc; -use crate::adapter::Adapter; use crate::handler::{FromConnectParts, FromDisconnectParts, FromMessageParts}; use crate::socket::{DisconnectReason, Socket}; use bytes::Bytes; @@ -30,7 +29,7 @@ impl std::fmt::Debug for ExtensionNotFound { impl std::error::Error for ExtensionNotFound {} fn extract_http_extension( - s: &Arc>, + s: &Arc, ) -> Result> { s.req_parts() .extensions @@ -44,45 +43,48 @@ pub struct HttpExtension(pub T); /// An Extractor that returns a clone extension from the request parts if it exists. pub struct MaybeHttpExtension(pub Option); -impl FromConnectParts for HttpExtension { +impl FromConnectParts for HttpExtension { type Error = ExtensionNotFound; fn from_connect_parts( - s: &Arc>, + s: &Arc, _: &Option, + _: &Arc, ) -> Result> { extract_http_extension(s).map(HttpExtension) } } -impl FromConnectParts for MaybeHttpExtension { +impl FromConnectParts for MaybeHttpExtension { type Error = Infallible; - fn from_connect_parts(s: &Arc>, _: &Option) -> Result { + fn from_connect_parts( + s: &Arc, + _: &Option, + _: &Arc, + ) -> Result { Ok(MaybeHttpExtension(extract_http_extension(s).ok())) } } -impl FromDisconnectParts for HttpExtension { +impl FromDisconnectParts for HttpExtension { type Error = ExtensionNotFound; fn from_disconnect_parts( - s: &Arc>, + s: &Arc, _: DisconnectReason, ) -> Result> { extract_http_extension(s).map(HttpExtension) } } -impl FromDisconnectParts - for MaybeHttpExtension -{ +impl FromDisconnectParts for MaybeHttpExtension { type Error = Infallible; - fn from_disconnect_parts(s: &Arc>, _: DisconnectReason) -> Result { + fn from_disconnect_parts(s: &Arc, _: DisconnectReason) -> Result { Ok(MaybeHttpExtension(extract_http_extension(s).ok())) } } -impl FromMessageParts for HttpExtension { +impl FromMessageParts for HttpExtension { type Error = ExtensionNotFound; fn from_message_parts( - s: &Arc>, + s: &Arc, _: &mut serde_json::Value, _: &mut Vec, _: &Option, @@ -90,10 +92,10 @@ impl FromMessageParts for HttpE extract_http_extension(s).map(HttpExtension) } } -impl FromMessageParts for MaybeHttpExtension { +impl FromMessageParts for MaybeHttpExtension { type Error = Infallible; fn from_message_parts( - s: &Arc>, + s: &Arc, _: &mut serde_json::Value, _: &mut Vec, _: &Option, @@ -110,7 +112,7 @@ mod extensions_extract { use super::*; fn extract_extension( - s: &Arc>, + s: &Arc, ) -> Result> { s.extensions .get::() @@ -125,43 +127,45 @@ mod extensions_extract { /// An Extractor that returns the extension of the given type if it exists or `None` otherwise. pub struct MaybeExtension(pub Option); - impl FromConnectParts for Extension { + impl FromConnectParts for Extension { type Error = ExtensionNotFound; fn from_connect_parts( - s: &Arc>, + s: &Arc, _: &Option, + _: &Arc, ) -> Result> { extract_extension(s).map(Extension) } } - impl FromConnectParts for MaybeExtension { + impl FromConnectParts for MaybeExtension { type Error = Infallible; - fn from_connect_parts(s: &Arc>, _: &Option) -> Result { + fn from_connect_parts( + s: &Arc, + _: &Option, + _: &Arc, + ) -> Result { Ok(MaybeExtension(extract_extension(s).ok())) } } - impl FromDisconnectParts for Extension { + impl FromDisconnectParts for Extension { type Error = ExtensionNotFound; fn from_disconnect_parts( - s: &Arc>, + s: &Arc, _: DisconnectReason, ) -> Result> { extract_extension(s).map(Extension) } } - impl FromDisconnectParts for MaybeExtension { + impl FromDisconnectParts for MaybeExtension { type Error = Infallible; - fn from_disconnect_parts( - s: &Arc>, - _: DisconnectReason, - ) -> Result { + fn from_disconnect_parts(s: &Arc, _: DisconnectReason) -> Result { Ok(MaybeExtension(extract_extension(s).ok())) } } - impl FromMessageParts for Extension { + impl FromMessageParts for Extension { type Error = ExtensionNotFound; fn from_message_parts( - s: &Arc>, + s: &Arc, _: &mut serde_json::Value, _: &mut Vec, _: &Option, @@ -169,10 +173,10 @@ mod extensions_extract { extract_extension(s).map(Extension) } } - impl FromMessageParts for MaybeExtension { + impl FromMessageParts for MaybeExtension { type Error = Infallible; fn from_message_parts( - s: &Arc>, + s: &Arc, _: &mut serde_json::Value, _: &mut Vec, _: &Option, diff --git a/socketioxide/src/extract/mod.rs b/socketioxide/src/extract/mod.rs index b7adcb3e..9f5ac96f 100644 --- a/socketioxide/src/extract/mod.rs +++ b/socketioxide/src/extract/mod.rs @@ -56,9 +56,9 @@ //! } //! impl std::error::Error for UserIdNotFound {} //! -//! impl FromConnectParts for UserId { +//! impl FromConnectParts for UserId { //! type Error = Infallible; -//! fn from_connect_parts(s: &Arc>, _: &Option) -> Result { +//! fn from_connect_parts(s: &Arc, _: &Option) -> Result { //! // In a real app it would be better to parse the query params with a crate like `url` //! let uri = &s.req_parts().uri; //! let uid = uri @@ -72,11 +72,11 @@ //! //! // Here, if the user id is not found, the handler won't be called //! // and a tracing `error` log will be printed (if the `tracing` feature is enabled) -//! impl FromMessageParts for UserId { +//! impl FromMessageParts for UserId { //! type Error = UserIdNotFound; //! //! fn from_message_parts( -//! s: &Arc>, +//! s: &Arc, //! _: &mut serde_json::Value, //! _: &mut Vec, //! _: &Option, diff --git a/socketioxide/src/extract/socket.rs b/socketioxide/src/extract/socket.rs index 0a08ba27..44609348 100644 --- a/socketioxide/src/extract/socket.rs +++ b/socketioxide/src/extract/socket.rs @@ -4,28 +4,28 @@ use std::sync::Arc; use crate::errors::{DisconnectError, SendError}; use crate::handler::{FromConnectParts, FromDisconnectParts, FromMessageParts}; use crate::socket::DisconnectReason; -use crate::{ - adapter::{Adapter, LocalAdapter}, - packet::Packet, - socket::Socket, -}; +use crate::{packet::Packet, socket::Socket}; use bytes::Bytes; use serde::Serialize; /// An Extractor that returns a reference to a [`Socket`]. #[derive(Debug)] -pub struct SocketRef(Arc>); +pub struct SocketRef(Arc); -impl FromConnectParts for SocketRef { +impl FromConnectParts for SocketRef { type Error = Infallible; - fn from_connect_parts(s: &Arc>, _: &Option) -> Result { + fn from_connect_parts( + s: &Arc, + _: &Option, + _: &Arc, + ) -> Result { Ok(SocketRef(s.clone())) } } -impl FromMessageParts for SocketRef { +impl FromMessageParts for SocketRef { type Error = Infallible; fn from_message_parts( - s: &Arc>, + s: &Arc, _: &mut serde_json::Value, _: &mut Vec, _: &Option, @@ -33,40 +33,40 @@ impl FromMessageParts for SocketRef { Ok(SocketRef(s.clone())) } } -impl FromDisconnectParts for SocketRef { +impl FromDisconnectParts for SocketRef { type Error = Infallible; - fn from_disconnect_parts(s: &Arc>, _: DisconnectReason) -> Result { + fn from_disconnect_parts(s: &Arc, _: DisconnectReason) -> Result { Ok(SocketRef(s.clone())) } } -impl std::ops::Deref for SocketRef { - type Target = Socket; +impl std::ops::Deref for SocketRef { + type Target = Socket; #[inline(always)] fn deref(&self) -> &Self::Target { &self.0 } } -impl PartialEq for SocketRef { +impl PartialEq for SocketRef { #[inline(always)] fn eq(&self, other: &Self) -> bool { self.0.id == other.0.id } } -impl From>> for SocketRef { +impl From> for SocketRef { #[inline(always)] - fn from(socket: Arc>) -> Self { + fn from(socket: Arc) -> Self { Self(socket) } } -impl Clone for SocketRef { +impl Clone for SocketRef { fn clone(&self) -> Self { Self(self.0.clone()) } } -impl SocketRef { +impl SocketRef { /// Disconnect the socket from the current namespace, /// /// It will also call the disconnect handler if it is set. @@ -79,15 +79,15 @@ impl SocketRef { /// An Extractor to send an ack response corresponding to the current event. /// If the client sent a normal message without expecting an ack, the ack callback will do nothing. #[derive(Debug)] -pub struct AckSender { +pub struct AckSender { binary: Vec, - socket: Arc>, + socket: Arc, ack_id: Option, } -impl FromMessageParts for AckSender { +impl FromMessageParts for AckSender { type Error = Infallible; fn from_message_parts( - s: &Arc>, + s: &Arc, _: &mut serde_json::Value, _: &mut Vec, ack_id: &Option, @@ -95,8 +95,8 @@ impl FromMessageParts for AckSender { Ok(Self::new(s.clone(), *ack_id)) } } -impl AckSender { - pub(crate) fn new(socket: Arc>, ack_id: Option) -> Self { +impl AckSender { + pub(crate) fn new(socket: Arc, ack_id: Option) -> Self { Self { binary: vec![], socket, @@ -137,16 +137,16 @@ impl AckSender { } } -impl FromConnectParts for crate::ProtocolVersion { +impl FromConnectParts for crate::ProtocolVersion { type Error = Infallible; - fn from_connect_parts(s: &Arc>, _: &Option) -> Result { + fn from_connect_parts(s: &Arc, _: &Option) -> Result { Ok(s.protocol()) } } -impl FromMessageParts for crate::ProtocolVersion { +impl FromMessageParts for crate::ProtocolVersion { type Error = Infallible; fn from_message_parts( - s: &Arc>, + s: &Arc, _: &mut serde_json::Value, _: &mut Vec, _: &Option, @@ -154,23 +154,23 @@ impl FromMessageParts for crate::ProtocolVersion { Ok(s.protocol()) } } -impl FromDisconnectParts for crate::ProtocolVersion { +impl FromDisconnectParts for crate::ProtocolVersion { type Error = Infallible; - fn from_disconnect_parts(s: &Arc>, _: DisconnectReason) -> Result { + fn from_disconnect_parts(s: &Arc, _: DisconnectReason) -> Result { Ok(s.protocol()) } } -impl FromConnectParts for crate::TransportType { +impl FromConnectParts for crate::TransportType { type Error = Infallible; - fn from_connect_parts(s: &Arc>, _: &Option) -> Result { + fn from_connect_parts(s: &Arc, _: &Option) -> Result { Ok(s.transport_type()) } } -impl FromMessageParts for crate::TransportType { +impl FromMessageParts for crate::TransportType { type Error = Infallible; fn from_message_parts( - s: &Arc>, + s: &Arc, _: &mut serde_json::Value, _: &mut Vec, _: &Option, @@ -178,17 +178,17 @@ impl FromMessageParts for crate::TransportType { Ok(s.transport_type()) } } -impl FromDisconnectParts for crate::TransportType { +impl FromDisconnectParts for crate::TransportType { type Error = Infallible; - fn from_disconnect_parts(s: &Arc>, _: DisconnectReason) -> Result { + fn from_disconnect_parts(s: &Arc, _: DisconnectReason) -> Result { Ok(s.transport_type()) } } -impl FromDisconnectParts for DisconnectReason { +impl FromDisconnectParts for DisconnectReason { type Error = Infallible; fn from_disconnect_parts( - _: &Arc>, + _: &Arc, reason: DisconnectReason, ) -> Result { Ok(reason) diff --git a/socketioxide/src/extract/state.rs b/socketioxide/src/extract/state.rs index 4c7f67ad..3b639114 100644 --- a/socketioxide/src/extract/state.rs +++ b/socketioxide/src/extract/state.rs @@ -4,7 +4,6 @@ use crate::state::get_state; use std::ops::Deref; use std::sync::Arc; -use crate::adapter::Adapter; use crate::handler::{FromConnectParts, FromDisconnectParts, FromMessageParts}; use crate::socket::{DisconnectReason, Socket}; @@ -60,21 +59,24 @@ impl std::fmt::Debug for StateNotFound { } impl std::error::Error for StateNotFound {} -impl FromConnectParts for State { +impl FromConnectParts for State { type Error = StateNotFound; fn from_connect_parts( - _: &Arc>, + _: &Arc, _: &Option, + state: Arc, ) -> Result> { - get_state::() + state + .get::() + .clone() .map(State) .ok_or(StateNotFound(std::marker::PhantomData)) } } -impl FromDisconnectParts for State { +impl FromDisconnectParts for State { type Error = StateNotFound; fn from_disconnect_parts( - _: &Arc>, + _: &Arc, _: DisconnectReason, ) -> Result> { get_state::() @@ -82,10 +84,10 @@ impl FromDisconnectParts for State { .ok_or(StateNotFound(std::marker::PhantomData)) } } -impl FromMessageParts for State { +impl FromMessageParts for State { type Error = StateNotFound; fn from_message_parts( - _: &Arc>, + _: &Arc, _: &mut serde_json::Value, _: &mut Vec, _: &Option, diff --git a/socketioxide/src/handler/connect.rs b/socketioxide/src/handler/connect.rs index 888a3b30..3ec04744 100644 --- a/socketioxide/src/handler/connect.rs +++ b/socketioxide/src/handler/connect.rs @@ -117,22 +117,23 @@ use std::sync::Arc; use futures_core::Future; -use crate::{adapter::Adapter, socket::Socket}; +use crate::socket::Socket; use super::MakeErasedHandler; /// A Type Erased [`ConnectHandler`] so it can be stored in a HashMap -pub(crate) type BoxedConnectHandler = Box>; +pub(crate) type BoxedConnectHandler = Box; type MiddlewareRes = Result<(), Box>; type MiddlewareResFut<'a> = Pin + Send + 'a>>; -pub(crate) trait ErasedConnectHandler: Send + Sync + 'static { - fn call(&self, s: Arc>, auth: Option); +pub(crate) trait ErasedConnectHandler: Send + Sync + 'static { + fn call(&self, s: Arc, auth: Option, state: Arc); fn call_middleware<'a>( &'a self, - s: Arc>, + s: Arc, auth: &'a Option, + state: &'a Arc, ) -> MiddlewareResFut<'a>; } @@ -142,13 +143,17 @@ pub(crate) trait ErasedConnectHandler: Send + Sync + 'static { /// /// * See the [`connect`](super::connect) module doc for more details on connect handler. /// * See the [`extract`](crate::extract) module doc for more details on available extractors. -pub trait FromConnectParts: Sized { +pub trait FromConnectParts: Sized { /// The error type returned by the extractor type Error: std::error::Error + Send + 'static; /// Extract the arguments from the connect event. /// If it fails, the handler is not called - fn from_connect_parts(s: &Arc>, auth: &Option) -> Result; + fn from_connect_parts( + s: &Arc, + auth: &Option, + state: &Arc, + ) -> Result; } /// Define a middleware for the connect event. @@ -157,16 +162,17 @@ pub trait FromConnectParts: Sized { /// /// * See the [`connect`](super::connect) module doc for more details on connect middlewares. /// * See the [`extract`](crate::extract) module doc for more details on available extractors. -pub trait ConnectMiddleware: Send + Sync + 'static { +pub trait ConnectMiddleware: Send + Sync + 'static { /// Call the middleware with the given arguments. fn call<'a>( &'a self, - s: Arc>, + s: Arc, auth: &'a Option, + state: &'a Arc, ) -> impl Future + Send; #[doc(hidden)] - fn phantom(&self) -> std::marker::PhantomData<(A, T)> { + fn phantom(&self) -> std::marker::PhantomData { std::marker::PhantomData } } @@ -176,15 +182,16 @@ pub trait ConnectMiddleware: Send + Sync + 'static { /// /// * See the [`connect`](super::connect) module doc for more details on connect handler. /// * See the [`extract`](crate::extract) module doc for more details on available extractors. -pub trait ConnectHandler: Send + Sync + 'static { +pub trait ConnectHandler: Send + Sync + 'static { /// Call the handler with the given arguments. - fn call(&self, s: Arc>, auth: Option); + fn call(&self, s: Arc, auth: Option, state: Arc); /// Call the middleware with the given arguments. fn call_middleware<'a>( &'a self, - _: Arc>, + _: Arc, _: &'a Option, + _: &'a Arc, ) -> MiddlewareResFut<'a> { Box::pin(async move { Ok(()) }) } @@ -233,10 +240,10 @@ pub trait ConnectHandler: Send + Sync + 'static { /// let (_, io) = SocketIo::new_layer(); /// io.ns("/", handler.with(middleware).with(other_middleware)); /// ``` - fn with(self, middleware: M) -> impl ConnectHandler + fn with(self, middleware: M) -> impl ConnectHandler where Self: Sized, - M: ConnectMiddleware + Send + Sync + 'static, + M: ConnectMiddleware + Send + Sync + 'static, T: Send + Sync + 'static, T1: Send + Sync + 'static, { @@ -252,10 +259,10 @@ pub trait ConnectHandler: Send + Sync + 'static { std::marker::PhantomData } } -struct LayeredConnectHandler { +struct LayeredConnectHandler { handler: H, middleware: M, - phantom: std::marker::PhantomData<(A, T, T1)>, + phantom: std::marker::PhantomData<(T, T1)>, } struct ConnectMiddlewareLayer { middleware: M, @@ -263,57 +270,58 @@ struct ConnectMiddlewareLayer { phantom: std::marker::PhantomData<(T, T1)>, } -impl MakeErasedHandler +impl MakeErasedHandler where - H: ConnectHandler + Send + Sync + 'static, + H: ConnectHandler + Send + Sync + 'static, T: Send + Sync + 'static, { - pub fn new_ns_boxed(inner: H) -> Box> { + pub fn new_ns_boxed(inner: H) -> Box { Box::new(MakeErasedHandler::new(inner)) } } -impl ErasedConnectHandler for MakeErasedHandler +impl ErasedConnectHandler for MakeErasedHandler where - H: ConnectHandler + Send + Sync + 'static, + H: ConnectHandler + Send + Sync + 'static, T: Send + Sync + 'static, { - fn call(&self, s: Arc>, auth: Option) { - self.handler.call(s, auth); + fn call(&self, s: Arc, auth: Option, state: Arc) { + self.handler.call(s, auth, state); } fn call_middleware<'a>( &'a self, - s: Arc>, + s: Arc, auth: &'a Option, + state: &'a Arc, ) -> MiddlewareResFut<'a> { - self.handler.call_middleware(s, auth) + self.handler.call_middleware(s, auth, state) } } -impl ConnectHandler for LayeredConnectHandler +impl ConnectHandler for LayeredConnectHandler where - A: Adapter, - H: ConnectHandler + Send + Sync + 'static, - M: ConnectMiddleware + Send + Sync + 'static, + H: ConnectHandler + Send + Sync + 'static, + M: ConnectMiddleware + Send + Sync + 'static, T: Send + Sync + 'static, T1: Send + Sync + 'static, { - fn call(&self, s: Arc>, auth: Option) { - self.handler.call(s, auth); + fn call(&self, s: Arc, auth: Option, state: Arc) { + self.handler.call(s, auth, state); } fn call_middleware<'a>( &'a self, - s: Arc>, + s: Arc, auth: &'a Option, + state: &'a Arc, ) -> MiddlewareResFut<'a> { - Box::pin(async move { self.middleware.call(s, auth).await }) + Box::pin(async move { self.middleware.call(s, auth, state).await }) } - fn with(self, next: M2) -> impl ConnectHandler + fn with(self, next: M2) -> impl ConnectHandler where - M2: ConnectMiddleware + Send + Sync + 'static, + M2: ConnectMiddleware + Send + Sync + 'static, T2: Send + Sync + 'static, { LayeredConnectHandler { @@ -327,30 +335,38 @@ where } } } -impl ConnectMiddleware for LayeredConnectHandler +impl ConnectMiddleware for LayeredConnectHandler where - A: Adapter, - H: ConnectHandler + Send + Sync + 'static, - N: ConnectMiddleware + Send + Sync + 'static, + H: ConnectHandler + Send + Sync + 'static, + N: ConnectMiddleware + Send + Sync + 'static, T: Send + Sync + 'static, T1: Send + Sync + 'static, { - async fn call<'a>(&'a self, s: Arc>, auth: &'a Option) -> MiddlewareRes { + async fn call<'a>( + &'a self, + s: Arc, + auth: &'a Option, + state: &'a Arc, + ) -> MiddlewareRes { self.middleware.call(s, auth).await } } -impl ConnectMiddleware for ConnectMiddlewareLayer +impl ConnectMiddleware for ConnectMiddlewareLayer where - A: Adapter, - M: ConnectMiddleware + Send + Sync + 'static, - N: ConnectMiddleware + Send + Sync + 'static, + M: ConnectMiddleware + Send + Sync + 'static, + N: ConnectMiddleware + Send + Sync + 'static, T: Send + Sync + 'static, T1: Send + Sync + 'static, { - async fn call<'a>(&'a self, s: Arc>, auth: &'a Option) -> MiddlewareRes { - self.middleware.call(s.clone(), auth).await?; - self.next.call(s, auth).await + async fn call<'a>( + &'a self, + s: Arc, + auth: &'a Option, + state: &'a Arc, + ) -> MiddlewareRes { + self.middleware.call(s.clone(), auth, state).await?; + self.next.call(s, auth, state).await } } @@ -366,16 +382,20 @@ macro_rules! impl_handler_async { [$($ty:ident),*] ) => { #[allow(non_snake_case, unused)] - impl ConnectHandler for F + impl ConnectHandler<(private::Async, $($ty,)*)> for F where F: FnOnce($($ty,)*) -> Fut + Send + Sync + Clone + 'static, Fut: Future + Send + 'static, - A: Adapter, - $( $ty: FromConnectParts + Send, )* + $( $ty: FromConnectParts + Send, )* { - fn call(&self, s: Arc>, auth: Option) { + fn call( + &self, + s: Arc, + auth: Option, + state: Arc, + ) { $( - let $ty = match $ty::from_connect_parts(&s, &auth) { + let $ty = match $ty::from_connect_parts(&s, &auth, &state) { Ok(v) => v, Err(_e) => { #[cfg(feature = "tracing")] @@ -398,15 +418,19 @@ macro_rules! impl_handler { [$($ty:ident),*] ) => { #[allow(non_snake_case, unused)] - impl ConnectHandler for F + impl ConnectHandler<(private::Sync, $($ty,)*)> for F where F: FnOnce($($ty,)*) + Send + Sync + Clone + 'static, - A: Adapter, - $( $ty: FromConnectParts + Send, )* + $( $ty: FromConnectParts + Send, )* { - fn call(&self, s: Arc>, auth: Option) { + fn call( + &self, + s: Arc, + auth: Option, + state: Arc, + ) { $( - let $ty = match $ty::from_connect_parts(&s, &auth) { + let $ty = match $ty::from_connect_parts(&s, &auth, &state) { Ok(v) => v, Err(_e) => { #[cfg(feature = "tracing")] @@ -427,17 +451,21 @@ macro_rules! impl_middleware_async { [$($ty:ident),*] ) => { #[allow(non_snake_case, unused)] - impl ConnectMiddleware for F + impl ConnectMiddleware<(private::Async, $($ty,)*)> for F where F: FnOnce($($ty,)*) -> Fut + Send + Sync + Clone + 'static, Fut: Future> + Send + 'static, - A: Adapter, E: std::fmt::Display + Send + 'static, - $( $ty: FromConnectParts + Send, )* + $( $ty: FromConnectParts + Send, )* { - async fn call<'a>(&'a self, s: Arc>, auth: &'a Option) -> MiddlewareRes { + async fn call<'a>( + &'a self, + s: Arc, + auth: &'a Option, + state: &'a Arc, + ) -> MiddlewareRes { $( - let $ty = match $ty::from_connect_parts(&s, &auth) { + let $ty = match $ty::from_connect_parts(&s, &auth, &state) { Ok(v) => v, Err(e) => { #[cfg(feature = "tracing")] @@ -465,16 +493,20 @@ macro_rules! impl_middleware { [$($ty:ident),*] ) => { #[allow(non_snake_case, unused)] - impl ConnectMiddleware for F + impl ConnectMiddleware<(private::Sync, $($ty,)*)> for F where F: FnOnce($($ty,)*) -> Result<(), E> + Send + Sync + Clone + 'static, - A: Adapter, E: std::fmt::Display + Send + 'static, - $( $ty: FromConnectParts + Send, )* + $( $ty: FromConnectParts + Send, )* { - async fn call<'a>(&'a self, s: Arc>, auth: &'a Option) -> MiddlewareRes { + async fn call<'a>( + &'a self, + s: Arc, + auth: &'a Option, + state: &'a Arc, + ) -> MiddlewareRes { $( - let $ty = match $ty::from_connect_parts(&s, &auth) { + let $ty = match $ty::from_connect_parts(&s, &auth, &state) { Ok(v) => v, Err(e) => { #[cfg(feature = "tracing")] diff --git a/socketioxide/src/handler/disconnect.rs b/socketioxide/src/handler/disconnect.rs index b63adb15..86a1df94 100644 --- a/socketioxide/src/handler/disconnect.rs +++ b/socketioxide/src/handler/disconnect.rs @@ -58,36 +58,33 @@ use std::sync::Arc; use futures_core::Future; -use crate::{ - adapter::Adapter, - socket::{DisconnectReason, Socket}, -}; +use crate::socket::{DisconnectReason, Socket}; use super::MakeErasedHandler; /// A Type Erased [`DisconnectHandler`] so it can be stored in a HashMap -pub(crate) type BoxedDisconnectHandler = Box>; -pub(crate) trait ErasedDisconnectHandler: Send + Sync + 'static { - fn call(&self, s: Arc>, reason: DisconnectReason); +pub(crate) type BoxedDisconnectHandler = Box; +pub(crate) trait ErasedDisconnectHandler: Send + Sync + 'static { + fn call(&self, s: Arc, reason: DisconnectReason); } -impl MakeErasedHandler +impl MakeErasedHandler where T: Send + Sync + 'static, - H: DisconnectHandler + Send + Sync + 'static, + H: DisconnectHandler + Send + Sync + 'static, { - pub fn new_disconnect_boxed(inner: H) -> Box> { + pub fn new_disconnect_boxed(inner: H) -> Box { Box::new(MakeErasedHandler::new(inner)) } } -impl ErasedDisconnectHandler for MakeErasedHandler +impl ErasedDisconnectHandler for MakeErasedHandler where - H: DisconnectHandler + Send + Sync + 'static, + H: DisconnectHandler + Send + Sync + 'static, T: Send + Sync + 'static, { #[inline(always)] - fn call(&self, s: Arc>, reason: DisconnectReason) { + fn call(&self, s: Arc, reason: DisconnectReason) { self.handler.call(s, reason); } } @@ -98,14 +95,14 @@ where /// /// * See the [`disconnect`](super::disconnect) module doc for more details on disconnect handler. /// * See the [`extract`](crate::extract) module doc for more details on available extractors. -pub trait FromDisconnectParts: Sized { +pub trait FromDisconnectParts: Sized { /// The error type returned by the extractor type Error: std::error::Error + 'static; /// Extract the arguments from the disconnect event. /// If it fails, the handler is not called fn from_disconnect_parts( - s: &Arc>, + s: &Arc, reason: DisconnectReason, ) -> Result; } @@ -115,9 +112,9 @@ pub trait FromDisconnectParts: Sized { /// /// * See the [`disconnect`](super::disconnect) module doc for more details on disconnect handler. /// * See the [`extract`](crate::extract) module doc for more details on available extractors. -pub trait DisconnectHandler: Send + Sync + 'static { +pub trait DisconnectHandler: Send + Sync + 'static { /// Call the handler with the given arguments. - fn call(&self, s: Arc>, reason: DisconnectReason); + fn call(&self, s: Arc, reason: DisconnectReason); #[doc(hidden)] fn phantom(&self) -> std::marker::PhantomData { @@ -137,14 +134,13 @@ macro_rules! impl_handler_async { [$($ty:ident),*] ) => { #[allow(non_snake_case, unused)] - impl DisconnectHandler for F + impl DisconnectHandler<(private::Async, $($ty,)*)> for F where F: FnOnce($($ty,)*) -> Fut + Send + Sync + Clone + 'static, Fut: Future + Send + 'static, - A: Adapter, - $( $ty: FromDisconnectParts + Send, )* + $( $ty: FromDisconnectParts + Send, )* { - fn call(&self, s: Arc>, reason: DisconnectReason) { + fn call(&self, s: Arc, reason: DisconnectReason) { $( let $ty = match $ty::from_disconnect_parts(&s, reason) { Ok(v) => v, @@ -169,13 +165,12 @@ macro_rules! impl_handler { [$($ty:ident),*] ) => { #[allow(non_snake_case, unused)] - impl DisconnectHandler for F + impl DisconnectHandler<(private::Sync, $($ty,)*)> for F where F: FnOnce($($ty,)*) + Send + Sync + Clone + 'static, - A: Adapter, - $( $ty: FromDisconnectParts + Send, )* + $( $ty: FromDisconnectParts + Send, )* { - fn call(&self, s: Arc>, reason: DisconnectReason) { + fn call(&self, s: Arc, reason: DisconnectReason) { $( let $ty = match $ty::from_disconnect_parts(&s, reason) { Ok(v) => v, diff --git a/socketioxide/src/handler/message.rs b/socketioxide/src/handler/message.rs index 5b4ab255..84bf3f80 100644 --- a/socketioxide/src/handler/message.rs +++ b/socketioxide/src/handler/message.rs @@ -77,16 +77,15 @@ use bytes::Bytes; use futures_core::Future; use serde_json::Value; -use crate::adapter::Adapter; use crate::socket::Socket; use super::MakeErasedHandler; /// A Type Erased [`MessageHandler`] so it can be stored in a HashMap -pub(crate) type BoxedMessageHandler = Box>; +pub(crate) type BoxedMessageHandler = Box; -pub(crate) trait ErasedMessageHandler: Send + Sync + 'static { - fn call(&self, s: Arc>, v: Value, p: Vec, ack_id: Option); +pub(crate) trait ErasedMessageHandler: Send + Sync + 'static { + fn call(&self, s: Arc, v: Value, p: Vec, ack_id: Option); } /// Define a handler for the connect event. @@ -100,9 +99,9 @@ pub(crate) trait ErasedMessageHandler: Send + Sync + 'static { note = "Function argument is not a valid socketio extractor. \nSee `https://docs.rs/socketioxide/latest/socketioxide/extract/index.html` for details", ) )] -pub trait MessageHandler: Send + Sync + 'static { +pub trait MessageHandler: Send + Sync + 'static { /// Call the handler with the given arguments - fn call(&self, s: Arc>, v: Value, p: Vec, ack_id: Option); + fn call(&self, s: Arc, v: Value, p: Vec, ack_id: Option); #[doc(hidden)] fn phantom(&self) -> std::marker::PhantomData { @@ -110,25 +109,23 @@ pub trait MessageHandler: Send + Sync + 'static { } } -impl MakeErasedHandler +impl MakeErasedHandler where T: Send + Sync + 'static, - H: MessageHandler, - A: Adapter, + H: MessageHandler, { - pub fn new_message_boxed(inner: H) -> Box> { + pub fn new_message_boxed(inner: H) -> Box { Box::new(MakeErasedHandler::new(inner)) } } -impl ErasedMessageHandler for MakeErasedHandler +impl ErasedMessageHandler for MakeErasedHandler where T: Send + Sync + 'static, - H: MessageHandler, - A: Adapter, + H: MessageHandler, { #[inline(always)] - fn call(&self, s: Arc>, v: Value, p: Vec, ack_id: Option) { + fn call(&self, s: Arc, v: Value, p: Vec, ack_id: Option) { self.handler.call(s, v, p, ack_id); } } @@ -157,14 +154,14 @@ mod private { note = "Function argument is not a valid socketio extractor. \nSee `https://docs.rs/socketioxide/latest/socketioxide/extract/index.html` for details", ) )] -pub trait FromMessageParts: Sized { +pub trait FromMessageParts: Sized { /// The error type returned by the extractor type Error: std::error::Error + 'static; /// Extract the arguments from the message event. /// If it fails, the handler is not called. fn from_message_parts( - s: &Arc>, + s: &Arc, v: &mut Value, p: &mut Vec, ack_id: &Option, @@ -182,14 +179,14 @@ pub trait FromMessageParts: Sized { note = "Function argument is not a valid socketio extractor. \nSee `https://docs.rs/socketioxide/latest/socketioxide/extract/index.html` for details", ) )] -pub trait FromMessage: Sized { +pub trait FromMessage: Sized { /// The error type returned by the extractor type Error: std::error::Error + 'static; /// Extract the arguments from the message event. /// If it fails, the handler is not called fn from_message( - s: Arc>, + s: Arc, v: Value, p: Vec, ack_id: Option, @@ -197,14 +194,13 @@ pub trait FromMessage: Sized { } /// All the types that implement [`FromMessageParts`] also implement [`FromMessage`] -impl FromMessage for T +impl FromMessage for T where - T: FromMessageParts, - A: Adapter, + T: FromMessageParts, { type Error = T::Error; fn from_message( - s: Arc>, + s: Arc, mut v: Value, mut p: Vec, ack_id: Option, @@ -214,25 +210,23 @@ where } /// Empty Async handler -impl MessageHandler for F +impl MessageHandler<(private::Async,)> for F where F: FnOnce() -> Fut + Send + Sync + Clone + 'static, Fut: Future + Send + 'static, - A: Adapter, { - fn call(&self, _: Arc>, _: Value, _: Vec, _: Option) { + fn call(&self, _: Arc, _: Value, _: Vec, _: Option) { let fut = (self.clone())(); tokio::spawn(fut); } } /// Empty Sync handler -impl MessageHandler for F +impl MessageHandler<(private::Sync,)> for F where F: FnOnce() + Send + Sync + Clone + 'static, - A: Adapter, { - fn call(&self, _: Arc>, _: Value, _: Vec, _: Option) { + fn call(&self, _: Arc, _: Value, _: Vec, _: Option) { (self.clone())(); } } @@ -242,15 +236,14 @@ macro_rules! impl_async_handler { [$($ty:ident),*], $last:ident ) => { #[allow(non_snake_case, unused)] - impl MessageHandler for F + impl MessageHandler<(private::Async, M, $($ty,)* $last,)> for F where F: FnOnce($($ty,)* $last,) -> Fut + Send + Sync + Clone + 'static, Fut: Future + Send + 'static, - A: Adapter, - $( $ty: FromMessageParts + Send, )* - $last: FromMessage + Send, + $( $ty: FromMessageParts + Send, )* + $last: FromMessage + Send, { - fn call(&self, s: Arc>, mut v: Value, mut p: Vec, ack_id: Option) { + fn call(&self, s: Arc, mut v: Value, mut p: Vec, ack_id: Option) { $( let $ty = match $ty::from_message_parts(&s, &mut v, &mut p, &ack_id) { Ok(v) => v, @@ -281,14 +274,13 @@ macro_rules! impl_handler { [$($ty:ident),*], $last:ident ) => { #[allow(non_snake_case, unused)] - impl MessageHandler for F + impl MessageHandler<(private::Sync, M, $($ty,)* $last,)> for F where F: FnOnce($($ty,)* $last,) + Send + Sync + Clone + 'static, - A: Adapter, - $( $ty: FromMessageParts + Send, )* - $last: FromMessage + Send, + $( $ty: FromMessageParts + Send, )* + $last: FromMessage + Send, { - fn call(&self, s: Arc>, mut v: Value, mut p: Vec, ack_id: Option) { + fn call(&self, s: Arc, mut v: Value, mut p: Vec, ack_id: Option) { $( let $ty = match $ty::from_message_parts(&s, &mut v, &mut p, &ack_id) { Ok(v) => v, diff --git a/socketioxide/src/handler/mod.rs b/socketioxide/src/handler/mod.rs index 9f51a6dd..2fbdc4fc 100644 --- a/socketioxide/src/handler/mod.rs +++ b/socketioxide/src/handler/mod.rs @@ -12,16 +12,14 @@ pub use disconnect::{DisconnectHandler, FromDisconnectParts}; pub(crate) use message::BoxedMessageHandler; pub use message::{FromMessage, FromMessageParts, MessageHandler}; /// A struct used to erase the type of a [`ConnectHandler`] or [`MessageHandler`] so it can be stored in a map -pub(crate) struct MakeErasedHandler { +pub(crate) struct MakeErasedHandler { handler: H, - adapter: std::marker::PhantomData, type_: std::marker::PhantomData, } -impl MakeErasedHandler { +impl MakeErasedHandler { pub fn new(handler: H) -> Self { Self { handler, - adapter: std::marker::PhantomData, type_: std::marker::PhantomData, } } diff --git a/socketioxide/src/io.rs b/socketioxide/src/io.rs index 3f4ef4c4..1eef7c34 100644 --- a/socketioxide/src/io.rs +++ b/socketioxide/src/io.rs @@ -17,11 +17,11 @@ use crate::{ layer::SocketIoLayer, operators::{BroadcastOperators, RoomParam}, service::SocketIoService, - BroadcastError, DisconnectError, + AdapterError, BroadcastError, DisconnectError, }; /// Configuration for Socket.IO & Engine.IO -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct SocketIoConfig { /// The inner Engine.IO config pub engine_config: EngineIoConfig, @@ -35,6 +35,9 @@ pub struct SocketIoConfig { /// /// Defaults to 45 seconds. pub connect_timeout: Duration, + + /// The adapter to use for this server. + pub adapter: Box, } impl Default for SocketIoConfig { @@ -46,6 +49,17 @@ impl Default for SocketIoConfig { }, ack_timeout: Duration::from_secs(5), connect_timeout: Duration::from_secs(45), + adapter: Box::new(LocalAdapter::new()), + } + } +} +impl Clone for SocketIoConfig { + fn clone(&self) -> Self { + Self { + engine_config: self.engine_config.clone(), + ack_timeout: self.ack_timeout, + connect_timeout: self.connect_timeout, + adapter: self.adapter.boxed_clone(), } } } @@ -53,19 +67,17 @@ impl Default for SocketIoConfig { /// A builder to create a [`SocketIo`] instance. /// It contains everything to configure the socket.io server with a [`SocketIoConfig`]. /// It can be used to build either a Tower [`Layer`](tower::layer::Layer) or a [`Service`](tower::Service). -pub struct SocketIoBuilder { +pub struct SocketIoBuilder { config: SocketIoConfig, engine_config_builder: EngineIoConfigBuilder, - adapter: std::marker::PhantomData, } -impl SocketIoBuilder { +impl SocketIoBuilder { /// Creates a new [`SocketIoBuilder`] with default config pub fn new() -> Self { Self { config: SocketIoConfig::default(), engine_config_builder: EngineIoConfigBuilder::new().req_path("/socket.io".to_string()), - adapter: std::marker::PhantomData, } } @@ -154,12 +166,9 @@ impl SocketIoBuilder { } /// Sets a custom [`Adapter`] for this [`SocketIoBuilder`] - pub fn with_adapter(self) -> SocketIoBuilder { - SocketIoBuilder { - config: self.config, - engine_config_builder: self.engine_config_builder, - adapter: std::marker::PhantomData, - } + pub fn with_adapter(mut self, adapter: impl Adapter) -> SocketIoBuilder { + self.config.adapter = Box::new(adapter); + self } /// Add a custom global state for the [`SocketIo`] instance. @@ -176,7 +185,7 @@ impl SocketIoBuilder { /// Builds a [`SocketIoLayer`] and a [`SocketIo`] instance /// /// The layer can be used as a tower layer - pub fn build_layer(mut self) -> (SocketIoLayer, SocketIo) { + pub fn build_layer(mut self) -> (SocketIoLayer, SocketIo) { self.config.engine_config = self.engine_config_builder.build(); let (layer, client) = SocketIoLayer::from_config(Arc::new(self.config)); @@ -215,9 +224,9 @@ impl Default for SocketIoBuilder { /// The [`SocketIo`] instance can be cheaply cloned and moved around everywhere in your program. /// It can be used as the main handle to access the whole socket.io context. #[derive(Debug)] -pub struct SocketIo(Arc>); +pub struct SocketIo(Arc); -impl SocketIo { +impl SocketIo { /// Creates a new [`SocketIoBuilder`] with a default config #[inline(always)] pub fn builder() -> SocketIoBuilder { @@ -247,7 +256,7 @@ impl SocketIo { } } -impl SocketIo { +impl SocketIo { /// Returns a reference to the [`SocketIoConfig`] used by this [`SocketIo`] instance #[inline] pub fn config(&self) -> &SocketIoConfig { @@ -336,7 +345,7 @@ impl SocketIo { #[inline] pub fn ns(&self, path: impl Into>, callback: C) where - C: ConnectHandler, + C: ConnectHandler, T: Send + Sync + 'static, { self.0.add_ns(path.into(), callback); @@ -382,7 +391,7 @@ impl SocketIo { /// println!("found socket on /custom_ns namespace with id: {}", socket.id); /// } #[inline] - pub fn of<'a>(&self, path: impl Into<&'a str>) -> Option> { + pub fn of<'a>(&self, path: impl Into<&'a str>) -> Option { self.get_op(path.into()) } @@ -408,7 +417,7 @@ impl SocketIo { /// println!("found socket on / ns in room1 with id: {}", socket.id); /// } #[inline] - pub fn to(&self, rooms: impl RoomParam) -> BroadcastOperators { + pub fn to(&self, rooms: impl RoomParam) -> BroadcastOperators { self.get_default_op().to(rooms) } @@ -436,7 +445,7 @@ impl SocketIo { /// println!("found socket on / ns in room1 with id: {}", socket.id); /// } #[inline] - pub fn within(&self, rooms: impl RoomParam) -> BroadcastOperators { + pub fn within(&self, rooms: impl RoomParam) -> BroadcastOperators { self.get_default_op().within(rooms) } @@ -469,7 +478,7 @@ impl SocketIo { /// println!("found socket on / ns in room1 with id: {}", socket.id); /// } #[inline] - pub fn except(&self, rooms: impl RoomParam) -> BroadcastOperators { + pub fn except(&self, rooms: impl RoomParam) -> BroadcastOperators { self.get_default_op().except(rooms) } @@ -496,7 +505,7 @@ impl SocketIo { /// println!("found socket on / ns in room1 with id: {}", socket.id); /// } #[inline] - pub fn local(&self) -> BroadcastOperators { + pub fn local(&self) -> BroadcastOperators { self.get_default_op().local() } @@ -539,7 +548,7 @@ impl SocketIo { /// } /// }); #[inline] - pub fn timeout(&self, timeout: Duration) -> BroadcastOperators { + pub fn timeout(&self, timeout: Duration) -> BroadcastOperators { self.get_default_op().timeout(timeout) } @@ -568,7 +577,7 @@ impl SocketIo { /// .bin(vec![Bytes::from_static(&[1, 2, 3, 4])]) /// .emit("test", ()); #[inline] - pub fn bin(&self, binary: impl IntoIterator>) -> BroadcastOperators { + pub fn bin(&self, binary: impl IntoIterator>) -> BroadcastOperators { self.get_default_op().bin(binary) } @@ -695,7 +704,7 @@ impl SocketIo { /// println!("found socket on / ns in room1 with id: {}", socket.id); /// } #[inline] - pub fn sockets(&self) -> Result>, A::Error> { + pub fn sockets(&self) -> Result, AdapterError> { self.get_default_op().sockets() } @@ -739,7 +748,7 @@ impl SocketIo { /// // Later in your code you can for example add all sockets on the root namespace to the room1 and room3 /// io.join(["room1", "room3"]).unwrap(); #[inline] - pub fn join(self, rooms: impl RoomParam) -> Result<(), A::Error> { + pub fn join(self, rooms: impl RoomParam) -> Result<(), AdapterError> { self.get_default_op().join(rooms) } @@ -760,7 +769,7 @@ impl SocketIo { /// let rooms = io2.rooms().unwrap(); /// println!("All rooms on / namespace: {:?}", rooms); /// }); - pub fn rooms(&self) -> Result, A::Error> { + pub fn rooms(&self) -> Result, AdapterError> { self.get_default_op().rooms() } @@ -782,19 +791,19 @@ impl SocketIo { /// // Later in your code you can for example remove all sockets on the root namespace from the room1 and room3 /// io.leave(["room1", "room3"]).unwrap(); #[inline] - pub fn leave(self, rooms: impl RoomParam) -> Result<(), A::Error> { + pub fn leave(self, rooms: impl RoomParam) -> Result<(), AdapterError> { self.get_default_op().leave(rooms) } /// Gets a [`SocketRef`] by the specified [`Sid`]. #[inline] - pub fn get_socket(&self, sid: Sid) -> Option> { + pub fn get_socket(&self, sid: Sid) -> Option { self.get_default_op().get_socket(sid) } /// Returns a new operator on the given namespace #[inline(always)] - fn get_op(&self, path: &str) -> Option> { + fn get_op(&self, path: &str) -> Option { self.0 .get_ns(path) .map(|ns| BroadcastOperators::new(ns).broadcast()) @@ -806,19 +815,19 @@ impl SocketIo { /// /// If the **default namespace "/" is not found** this fn will panic! #[inline(always)] - fn get_default_op(&self) -> BroadcastOperators { + fn get_default_op(&self) -> BroadcastOperators { self.get_op("/").expect("default namespace not found") } } -impl Clone for SocketIo { +impl Clone for SocketIo { fn clone(&self) -> Self { Self(self.0.clone()) } } #[cfg(any(test, socketioxide_test))] -impl SocketIo { +impl SocketIo { /// Create a dummy socket for testing purpose with a /// receiver to get the packets sent to the client pub async fn new_dummy_sock( diff --git a/socketioxide/src/layer.rs b/socketioxide/src/layer.rs index 1a656766..bafff484 100644 --- a/socketioxide/src/layer.rs +++ b/socketioxide/src/layer.rs @@ -20,19 +20,14 @@ use std::sync::Arc; use tower::Layer; -use crate::{ - adapter::{Adapter, LocalAdapter}, - client::Client, - service::SocketIoService, - SocketIoConfig, -}; +use crate::{client::Client, service::SocketIoService, SocketIoConfig}; /// A [`Layer`] for [`SocketIoService`], acting as a middleware. -pub struct SocketIoLayer { - client: Arc>, +pub struct SocketIoLayer { + client: Arc, } -impl Clone for SocketIoLayer { +impl Clone for SocketIoLayer { fn clone(&self) -> Self { Self { client: self.client.clone(), @@ -40,8 +35,8 @@ impl Clone for SocketIoLayer { } } -impl SocketIoLayer { - pub(crate) fn from_config(config: Arc) -> (Self, Arc>) { +impl SocketIoLayer { + pub(crate) fn from_config(config: Arc) -> (Self, Arc) { let client = Arc::new(Client::new(config.clone())); let layer = Self { client: client.clone(), @@ -50,8 +45,8 @@ impl SocketIoLayer { } } -impl Layer for SocketIoLayer { - type Service = SocketIoService; +impl Layer for SocketIoLayer { + type Service = SocketIoService; fn layer(&self, inner: S) -> Self::Service { SocketIoService::with_client(inner, self.client.clone()) diff --git a/socketioxide/src/ns.rs b/socketioxide/src/ns.rs index 7f810b8a..76d4d429 100644 --- a/socketioxide/src/ns.rs +++ b/socketioxide/src/ns.rs @@ -5,34 +5,41 @@ use std::{ }; use crate::{ - adapter::Adapter, - errors::{ConnectFail, Error}, + adapter::{Adapter, LocalAdapter}, + client::SocketData, + errors::{AdapterError, ConnectFail, Error}, handler::{BoxedConnectHandler, ConnectHandler, MakeErasedHandler}, packet::{Packet, PacketData}, socket::{DisconnectReason, Socket}, SocketIoConfig, }; -use crate::{client::SocketData, errors::AdapterError}; use engineioxide::sid::Sid; -pub struct Namespace { +pub struct Namespace { pub path: Cow<'static, str>, - pub(crate) adapter: A, - handler: BoxedConnectHandler, - sockets: RwLock>>>, + pub(crate) adapter: Box, + handler: BoxedConnectHandler, + sockets: RwLock>>, } -impl Namespace { - pub fn new(path: Cow<'static, str>, handler: C) -> Arc +impl Namespace { + pub fn new( + path: Cow<'static, str>, + handler: C, + mut adapter: Box, + ) -> Arc where - C: ConnectHandler + Send + Sync + 'static, + C: ConnectHandler + Send + Sync + 'static, T: Send + Sync + 'static, { - Arc::new_cyclic(|ns| Self { - path, - handler: MakeErasedHandler::new_ns_boxed(handler), - sockets: HashMap::new().into(), - adapter: A::new(ns.clone()), + Arc::new_cyclic(move |ns: &std::sync::Weak<_>| { + adapter.init(ns.clone()).ok(); + Self { + path, + handler: MakeErasedHandler::new_ns_boxed(handler), + sockets: HashMap::new().into(), + adapter, + } }) } @@ -49,9 +56,13 @@ impl Namespace { auth: Option, config: Arc, ) -> Result<(), ConnectFail> { - let socket: Arc> = Socket::new(sid, self.clone(), esocket.clone(), config).into(); + let socket: Arc = Socket::new(sid, self.clone(), esocket.clone(), config).into(); - if let Err(e) = self.handler.call_middleware(socket.clone(), &auth).await { + if let Err(e) = self + .handler + .call_middleware(socket.clone(), &auth, &self.state) + .await + { #[cfg(feature = "tracing")] tracing::trace!(ns = self.path.as_ref(), ?socket.id, "emitting connect_error packet"); @@ -78,7 +89,7 @@ impl Namespace { } socket.set_connected(true); - self.handler.call(socket, auth); + self.handler.call(socket, auth, self.state.clone()); Ok(()) } @@ -106,7 +117,7 @@ impl Namespace { } } - pub fn get_socket(&self, sid: Sid) -> Result>, Error> { + pub fn get_socket(&self, sid: Sid) -> Result, Error> { self.sockets .read() .unwrap() @@ -115,7 +126,7 @@ impl Namespace { .ok_or(Error::SocketGone(sid)) } - pub fn get_sockets(&self) -> Vec>> { + pub fn get_sockets(&self) -> Vec> { self.sockets.read().unwrap().values().cloned().collect() } @@ -159,9 +170,9 @@ impl Namespace { } #[cfg(any(test, socketioxide_test))] -impl Namespace { +impl Namespace { pub fn new_dummy(sockets: [Sid; S]) -> Arc { - let ns = Namespace::new(Cow::Borrowed("/"), || {}); + let ns = Namespace::new(Cow::Borrowed("/"), || {}, Box::new(LocalAdapter::new())); for sid in sockets { ns.sockets .write() @@ -176,7 +187,7 @@ impl Namespace { } } -impl std::fmt::Debug for Namespace { +impl std::fmt::Debug for Namespace { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Namespace") .field("path", &self.path) @@ -187,7 +198,7 @@ impl std::fmt::Debug for Namespace { } #[cfg(feature = "tracing")] -impl Drop for Namespace { +impl Drop for Namespace { fn drop(&mut self) { #[cfg(feature = "tracing")] tracing::debug!("dropping namespace {}", self.path); diff --git a/socketioxide/src/operators.rs b/socketioxide/src/operators.rs index 46a9754f..0b814433 100644 --- a/socketioxide/src/operators.rs +++ b/socketioxide/src/operators.rs @@ -13,13 +13,11 @@ use bytes::Bytes; use engineioxide::sid::Sid; use crate::ack::{AckInnerStream, AckStream}; -use crate::adapter::LocalAdapter; -use crate::errors::{BroadcastError, DisconnectError}; +use crate::errors::{AdapterError, BroadcastError, DisconnectError, SendError}; use crate::extract::SocketRef; use crate::socket::Socket; -use crate::SendError; use crate::{ - adapter::{Adapter, BroadcastFlags, BroadcastOptions, Room}, + adapter::{BroadcastFlags, BroadcastOptions, Room}, ns::Namespace, packet::Packet, }; @@ -103,21 +101,21 @@ impl RoomParam for Sid { } /// Chainable operators to configure the message to be sent. -pub struct ConfOperators<'a, A: Adapter = LocalAdapter> { +pub struct ConfOperators<'a> { binary: Vec, timeout: Option, - socket: &'a Socket, + socket: &'a Socket, } /// Chainable operators to select sockets to send a message to and to configure the message to be sent. -pub struct BroadcastOperators { +pub struct BroadcastOperators { binary: Vec, timeout: Option, - ns: Arc>, + ns: Arc, opts: BroadcastOptions, } -impl From> for BroadcastOperators { - fn from(conf: ConfOperators<'_, A>) -> Self { +impl From> for BroadcastOperators { + fn from(conf: ConfOperators<'_>) -> Self { let opts = BroadcastOptions { sid: Some(conf.socket.id), ..Default::default() @@ -132,8 +130,8 @@ impl From> for BroadcastOperators { } // ==== impl ConfOperators operations ==== -impl<'a, A: Adapter> ConfOperators<'a, A> { - pub(crate) fn new(sender: &'a Socket) -> Self { +impl<'a> ConfOperators<'a> { + pub(crate) fn new(sender: &'a Socket) -> Self { Self { binary: vec![], timeout: None, @@ -161,7 +159,7 @@ impl<'a, A: Adapter> ConfOperators<'a, A> { /// .emit("test", data); /// }); /// }); - pub fn to(self, rooms: impl RoomParam) -> BroadcastOperators { + pub fn to(self, rooms: impl RoomParam) -> BroadcastOperators { BroadcastOperators::from(self).to(rooms) } @@ -185,7 +183,7 @@ impl<'a, A: Adapter> ConfOperators<'a, A> { /// .emit("test", data); /// }); /// }); - pub fn within(self, rooms: impl RoomParam) -> BroadcastOperators { + pub fn within(self, rooms: impl RoomParam) -> BroadcastOperators { BroadcastOperators::from(self).within(rooms) } @@ -208,7 +206,7 @@ impl<'a, A: Adapter> ConfOperators<'a, A> { /// socket.broadcast().except("room1").emit("test", data); /// }); /// }); - pub fn except(self, rooms: impl RoomParam) -> BroadcastOperators { + pub fn except(self, rooms: impl RoomParam) -> BroadcastOperators { BroadcastOperators::from(self).except(rooms) } @@ -225,7 +223,7 @@ impl<'a, A: Adapter> ConfOperators<'a, A> { /// socket.local().emit("test", data); /// }); /// }); - pub fn local(self) -> BroadcastOperators { + pub fn local(self) -> BroadcastOperators { BroadcastOperators::from(self).local() } @@ -241,7 +239,7 @@ impl<'a, A: Adapter> ConfOperators<'a, A> { /// socket.broadcast().emit("test", data); /// }); /// }); - pub fn broadcast(self) -> BroadcastOperators { + pub fn broadcast(self) -> BroadcastOperators { BroadcastOperators::from(self).broadcast() } @@ -304,7 +302,7 @@ impl<'a, A: Adapter> ConfOperators<'a, A> { } // ==== impl ConfOperators consume fns ==== -impl ConfOperators<'_, A> { +impl ConfOperators<'_> { /// Emits a message to the client and apply the previous operators on the message. /// /// If you provide array-like data (tuple, vec, arrays), it will be considered as multiple arguments. @@ -452,7 +450,7 @@ impl ConfOperators<'_, A> { /// socket.within("room1").within("room3").join(["room4", "room5"]).unwrap(); /// }); /// }); - pub fn join(self, rooms: impl RoomParam) -> Result<(), A::Error> { + pub fn join(self, rooms: impl RoomParam) -> Result<(), AdapterError> { self.socket.join(rooms) } @@ -468,12 +466,12 @@ impl ConfOperators<'_, A> { /// socket.within("room1").within("room3").leave(["room4", "room5"]).unwrap(); /// }); /// }); - pub fn leave(self, rooms: impl RoomParam) -> Result<(), A::Error> { + pub fn leave(self, rooms: impl RoomParam) -> Result<(), AdapterError> { self.socket.leave(rooms) } /// Gets all room names for a given namespace - pub fn rooms(self) -> Result, A::Error> { + pub fn rooms(self) -> Result, AdapterError> { self.socket.rooms() } @@ -495,8 +493,8 @@ impl ConfOperators<'_, A> { } } -impl BroadcastOperators { - pub(crate) fn new(ns: Arc>) -> Self { +impl BroadcastOperators { + pub(crate) fn new(ns: Arc) -> Self { Self { binary: vec![], timeout: None, @@ -504,7 +502,7 @@ impl BroadcastOperators { opts: BroadcastOptions::default(), } } - pub(crate) fn from_sock(ns: Arc>, sid: Sid) -> Self { + pub(crate) fn from_sock(ns: Arc, sid: Sid) -> Self { Self { binary: vec![], timeout: None, @@ -684,7 +682,7 @@ impl BroadcastOperators { } // ==== impl BroadcastOperators consume fns ==== -impl BroadcastOperators { +impl BroadcastOperators { /// Emits a message to all sockets selected with the previous operators. /// /// If you provide array-like data (tuple, vec, arrays), it will be considered as multiple arguments. @@ -826,7 +824,7 @@ impl BroadcastOperators { /// } /// }); /// }); - pub fn sockets(self) -> Result>, A::Error> { + pub fn sockets(self) -> Result, AdapterError> { self.ns.adapter.fetch_sockets(self.opts) } @@ -858,8 +856,10 @@ impl BroadcastOperators { /// socket.within("room1").within("room3").join(["room4", "room5"]).unwrap(); /// }); /// }); - pub fn join(self, rooms: impl RoomParam) -> Result<(), A::Error> { - self.ns.adapter.add_sockets(self.opts, rooms) + pub fn join(self, rooms: impl RoomParam) -> Result<(), AdapterError> { + self.ns + .adapter + .add_sockets(self.opts, rooms.into_room_iter().collect()) } /// Makes all sockets selected with the previous operators leave the given room(s). @@ -874,17 +874,19 @@ impl BroadcastOperators { /// socket.within("room1").within("room3").leave(["room4", "room5"]).unwrap(); /// }); /// }); - pub fn leave(self, rooms: impl RoomParam) -> Result<(), A::Error> { - self.ns.adapter.del_sockets(self.opts, rooms) + pub fn leave(self, rooms: impl RoomParam) -> Result<(), AdapterError> { + self.ns + .adapter + .del_sockets(self.opts, rooms.into_room_iter().collect()) } /// Gets all room names for a given namespace - pub fn rooms(self) -> Result, A::Error> { + pub fn rooms(self) -> Result, AdapterError> { self.ns.adapter.rooms() } /// Gets a [`SocketRef`] by the specified [`Sid`]. - pub fn get_socket(&self, sid: Sid) -> Option> { + pub fn get_socket(&self, sid: Sid) -> Option { self.ns.get_socket(sid).map(SocketRef::from).ok() } diff --git a/socketioxide/src/service.rs b/socketioxide/src/service.rs index 91bfda74..078ba790 100644 --- a/socketioxide/src/service.rs +++ b/socketioxide/src/service.rs @@ -53,31 +53,26 @@ use std::{ }; use tower::Service as TowerSvc; -use crate::{ - adapter::{Adapter, LocalAdapter}, - client::Client, - SocketIoConfig, -}; +use crate::{client::Client, SocketIoConfig}; /// A [`Tower`](TowerSvc)/[`Hyper`](HyperSvc) Service that wraps [`EngineIoService`] and /// redirect every request to it -pub struct SocketIoService { - engine_svc: EngineIoService>, S>, +pub struct SocketIoService { + engine_svc: EngineIoService, S>, } /// Tower Service implementation. -impl TowerSvc> for SocketIoService +impl TowerSvc> for SocketIoService where ReqBody: Body + Send + Unpin + std::fmt::Debug + 'static, ::Error: std::fmt::Debug, ::Data: Send, ResBody: Body + Send + 'static, S: TowerSvc, Response = Response> + Clone, - A: Adapter, { - type Response = >, S> as TowerSvc>>::Response; - type Error = >, S> as TowerSvc>>::Error; - type Future = >, S> as TowerSvc>>::Future; + type Response = , S> as TowerSvc>>::Response; + type Error = , S> as TowerSvc>>::Error; + type Future = , S> as TowerSvc>>::Future; #[inline(always)] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { @@ -90,18 +85,17 @@ where } /// Hyper 1.0 Service implementation. -impl HyperSvc> for SocketIoService +impl HyperSvc> for SocketIoService where ReqBody: Body + Send + Unpin + std::fmt::Debug + 'static, ::Error: std::fmt::Debug, ::Data: Send, ResBody: Body + Send + 'static, S: HyperSvc, Response = Response> + Clone, - A: Adapter, { - type Response = >, S> as HyperSvc>>::Response; - type Error = >, S> as HyperSvc>>::Error; - type Future = >, S> as HyperSvc>>::Future; + type Response = , S> as HyperSvc>>::Response; + type Error = , S> as HyperSvc>>::Error; + type Future = , S> as HyperSvc>>::Future; #[inline(always)] fn call(&self, req: Request) -> Self::Future { @@ -109,18 +103,15 @@ where } } -impl SocketIoService { +impl SocketIoService { /// Creates a MakeService which can be used as a hyper service #[inline(always)] - pub fn into_make_service(self) -> MakeEngineIoService>, S> { + pub fn into_make_service(self) -> MakeEngineIoService, S> { self.engine_svc.into_make_service() } /// Creates a new [`EngineIoService`] with a custom inner service and a custom config. - pub(crate) fn with_config_inner( - inner: S, - config: Arc, - ) -> (Self, Arc>) { + pub(crate) fn with_config_inner(inner: S, config: Arc) -> (Self, Arc) { let engine_config = config.engine_config.clone(); let client = Arc::new(Client::new(config)); let svc = EngineIoService::with_config_inner(inner, client.clone(), engine_config); @@ -129,14 +120,14 @@ impl SocketIoService { /// Creates a new [`EngineIoService`] with a custom inner service and an existing client /// It is mainly used with a [`SocketIoLayer`](crate::layer::SocketIoLayer) that owns the client - pub(crate) fn with_client(inner: S, client: Arc>) -> Self { + pub(crate) fn with_client(inner: S, client: Arc) -> Self { let engine_config = client.config.engine_config.clone(); let svc = EngineIoService::with_config_inner(inner, client, engine_config); Self { engine_svc: svc } } } -impl Clone for SocketIoService { +impl Clone for SocketIoService { fn clone(&self) -> Self { Self { engine_svc: self.engine_svc.clone(), diff --git a/socketioxide/src/socket.rs b/socketioxide/src/socket.rs index b082d8c0..0d14a4ac 100644 --- a/socketioxide/src/socket.rs +++ b/socketioxide/src/socket.rs @@ -23,7 +23,7 @@ use crate::extensions::Extensions; use crate::{ ack::{AckInnerStream, AckResponse, AckResult, AckStream}, - adapter::{Adapter, LocalAdapter, Room}, + adapter::Room, errors::{DisconnectError, Error, SendError}, handler::{ BoxedDisconnectHandler, BoxedMessageHandler, DisconnectHandler, MakeErasedHandler, @@ -127,11 +127,11 @@ impl<'a> PermitExt<'a> for Permit<'a> { /// A Socket represents a client connected to a namespace. /// It is used to send and receive messages from the client, join and leave rooms, etc. /// The socket struct itself should not be used directly, but through a [`SocketRef`](crate::extract::SocketRef). -pub struct Socket { +pub struct Socket { pub(crate) config: Arc, - pub(crate) ns: Arc>, - message_handlers: RwLock, BoxedMessageHandler>>, - disconnect_handler: Mutex>>, + pub(crate) ns: Arc, + message_handlers: RwLock, BoxedMessageHandler>>, + disconnect_handler: Mutex>, ack_message: Mutex>>>, ack_counter: AtomicI64, connected: AtomicBool, @@ -149,10 +149,10 @@ pub struct Socket { esocket: Arc>, } -impl Socket { +impl Socket { pub(crate) fn new( sid: Sid, - ns: Arc>, + ns: Arc, esocket: Arc>, config: Arc, ) -> Self { @@ -225,7 +225,7 @@ impl Socket { /// ``` pub fn on(&self, event: impl Into>, handler: H) where - H: MessageHandler, + H: MessageHandler, T: Send + Sync + 'static, { self.message_handlers @@ -259,7 +259,7 @@ impl Socket { /// }); pub fn on_disconnect(&self, callback: C) where - C: DisconnectHandler + Send + Sync + 'static, + C: DisconnectHandler + Send + Sync + 'static, T: Send + Sync + 'static, { let handler = MakeErasedHandler::new_disconnect_boxed(callback); @@ -411,8 +411,10 @@ impl Socket { /// ## Errors /// When using a distributed adapter, it can return an [`Adapter::Error`] which is mostly related to network errors. /// For the default [`LocalAdapter`] it is always an [`Infallible`](std::convert::Infallible) error - pub fn join(&self, rooms: impl RoomParam) -> Result<(), A::Error> { - self.ns.adapter.add_all(self.id, rooms) + pub fn join(&self, rooms: impl RoomParam) -> Result<(), AdapterError> { + self.ns + .adapter + .add_all(self.id, rooms.into_room_iter().collect()) } /// Leaves the given rooms. @@ -421,15 +423,17 @@ impl Socket { /// ## Errors /// When using a distributed adapter, it can return an [`Adapter::Error`] which is mostly related to network errors. /// For the default [`LocalAdapter`] it is always an [`Infallible`](std::convert::Infallible) error - pub fn leave(&self, rooms: impl RoomParam) -> Result<(), A::Error> { - self.ns.adapter.del(self.id, rooms) + pub fn leave(&self, rooms: impl RoomParam) -> Result<(), AdapterError> { + self.ns + .adapter + .del(self.id, rooms.into_room_iter().collect()) } /// Leaves all rooms where the socket is connected. /// ## Errors /// When using a distributed adapter, it can return an [`Adapter::Error`] which is mostly related to network errors. /// For the default [`LocalAdapter`] it is always an [`Infallible`](std::convert::Infallible) error - pub fn leave_all(&self) -> Result<(), A::Error> { + pub fn leave_all(&self) -> Result<(), AdapterError> { self.ns.adapter.del_all(self.id) } @@ -437,7 +441,7 @@ impl Socket { /// ## Errors /// When using a distributed adapter, it can return an [`Adapter::Error`] which is mostly related to network errors. /// For the default [`LocalAdapter`] it is always an [`Infallible`](std::convert::Infallible) error - pub fn rooms(&self) -> Result, A::Error> { + pub fn rooms(&self) -> Result, AdapterError> { self.ns.adapter.socket_rooms(self.id) } @@ -471,7 +475,7 @@ impl Socket { /// .emit("test", data); /// }); /// }); - pub fn to(&self, rooms: impl RoomParam) -> BroadcastOperators { + pub fn to(&self, rooms: impl RoomParam) -> BroadcastOperators { BroadcastOperators::from_sock(self.ns.clone(), self.id).to(rooms) } @@ -495,7 +499,7 @@ impl Socket { /// .emit("test", data); /// }); /// }); - pub fn within(&self, rooms: impl RoomParam) -> BroadcastOperators { + pub fn within(&self, rooms: impl RoomParam) -> BroadcastOperators { BroadcastOperators::from_sock(self.ns.clone(), self.id).within(rooms) } @@ -519,7 +523,7 @@ impl Socket { /// socket.broadcast().except("room1").emit("test", data); /// }); /// }); - pub fn except(&self, rooms: impl RoomParam) -> BroadcastOperators { + pub fn except(&self, rooms: impl RoomParam) -> BroadcastOperators { BroadcastOperators::from_sock(self.ns.clone(), self.id).except(rooms) } @@ -537,7 +541,7 @@ impl Socket { /// socket.local().emit("test", data); /// }); /// }); - pub fn local(&self) -> BroadcastOperators { + pub fn local(&self) -> BroadcastOperators { BroadcastOperators::from_sock(self.ns.clone(), self.id).local() } @@ -576,7 +580,7 @@ impl Socket { /// }); /// }); /// - pub fn timeout(&self, timeout: Duration) -> ConfOperators<'_, A> { + pub fn timeout(&self, timeout: Duration) -> ConfOperators<'_> { ConfOperators::new(self).timeout(timeout) } @@ -593,7 +597,7 @@ impl Socket { /// socket.bin(bin).emit("test", data); /// }); /// }); - pub fn bin(&self, binary: impl IntoIterator>) -> ConfOperators<'_, A> { + pub fn bin(&self, binary: impl IntoIterator>) -> ConfOperators<'_> { ConfOperators::new(self).bin(binary) } @@ -610,7 +614,7 @@ impl Socket { /// socket.broadcast().emit("test", data); /// }); /// }); - pub fn broadcast(&self) -> BroadcastOperators { + pub fn broadcast(&self) -> BroadcastOperators { BroadcastOperators::from_sock(self.ns.clone(), self.id).broadcast() } @@ -798,7 +802,7 @@ impl Socket { } } -impl Debug for Socket { +impl Debug for Socket { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Socket") .field("ns", &self.ns()) @@ -808,16 +812,16 @@ impl Debug for Socket { .finish() } } -impl PartialEq for Socket { +impl PartialEq for Socket { fn eq(&self, other: &Self) -> bool { self.id == other.id } } #[cfg(any(test, socketioxide_test))] -impl Socket { +impl Socket { /// Creates a dummy socket for testing purposes - pub fn new_dummy(sid: Sid, ns: Arc>) -> Socket { + pub fn new_dummy(sid: Sid, ns: Arc) -> Socket { let close_fn = Box::new(move |_, _| ()); let s = Socket::new( sid, @@ -837,7 +841,7 @@ mod test { #[tokio::test] async fn send_with_ack_error() { let sid = Sid::new(); - let ns = Namespace::::new_dummy([sid]).into(); + let ns = Namespace::new_dummy([sid]).into(); let socket: Arc = Socket::new_dummy(sid, ns).into(); // Saturate the channel for _ in 0..1024 { diff --git a/socketioxide/tests/fixture.rs b/socketioxide/tests/fixture.rs index 04d23dc0..7d127a12 100644 --- a/socketioxide/tests/fixture.rs +++ b/socketioxide/tests/fixture.rs @@ -111,7 +111,7 @@ pub async fn create_server(port: u16) -> SocketIo { io } -async fn spawn_server(port: u16, svc: SocketIoService) { +async fn spawn_server(port: u16, svc: SocketIoService) { let addr = &SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port); let listener = TcpListener::bind(&addr).await.unwrap(); tokio::spawn(async move {