From f0e15963b74f3ddcfd6249aa1e76895824d7fdef Mon Sep 17 00:00:00 2001 From: Chris Martinez Date: Sun, 24 Dec 2023 11:06:00 -0800 Subject: [PATCH] Retain generic state parameter --- src/Cargo.toml | 2 +- src/inject.rs | 48 +++++++++++++++++++++++++++++++++------------ src/inject_keyed.rs | 48 +++++++++++++++++++++++++++++++++------------ 3 files changed, 73 insertions(+), 25 deletions(-) diff --git a/src/Cargo.toml b/src/Cargo.toml index d90d99c..6318842 100644 --- a/src/Cargo.toml +++ b/src/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "more-di-axum" -version = "0.1.0" +version = "0.2.0" edition = "2021" authors = ["Chris Martinez "] description = "Provides support dependency injection (DI) for Axum" diff --git a/src/inject.rs b/src/inject.rs index d92f5bf..1879029 100644 --- a/src/inject.rs +++ b/src/inject.rs @@ -37,10 +37,14 @@ fn unregistered_type() -> String { } #[async_trait] -impl FromRequestParts<()> for TryInject { +impl FromRequestParts for TryInject +where + T: ?Sized + 'static, + S: Send + Sync, +{ type Rejection = Infallible; - async fn from_request_parts(parts: &mut Parts, _state: &()) -> Result { + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { if let Some(provider) = parts.extensions.get::() { Ok(Self(provider.get::())) } else { @@ -50,10 +54,14 @@ impl FromRequestParts<()> for TryInject { } #[async_trait] -impl FromRequestParts<()> for Inject { +impl FromRequestParts for Inject +where + T: ?Sized + 'static, + S: Send + Sync, +{ type Rejection = (StatusCode, String); - async fn from_request_parts(parts: &mut Parts, _state: &()) -> Result { + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { if let Some(provider) = parts.extensions.get::() { if let Some(service) = provider.get::() { return Ok(Self(service)); @@ -65,10 +73,14 @@ impl FromRequestParts<()> for Inject { } #[async_trait] -impl FromRequestParts<()> for TryInjectMut { +impl FromRequestParts for TryInjectMut +where + T: ?Sized + 'static, + S: Send + Sync, +{ type Rejection = Infallible; - async fn from_request_parts(parts: &mut Parts, _state: &()) -> Result { + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { if let Some(provider) = parts.extensions.get::() { Ok(Self(provider.get_mut::())) } else { @@ -78,10 +90,14 @@ impl FromRequestParts<()> for TryInjectMut { } #[async_trait] -impl FromRequestParts<()> for InjectMut { +impl FromRequestParts for InjectMut +where + T: ?Sized + 'static, + S: Send + Sync, +{ type Rejection = (StatusCode, String); - async fn from_request_parts(parts: &mut Parts, _state: &()) -> Result { + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { if let Some(provider) = parts.extensions.get::() { if let Some(service) = provider.get_mut::() { return Ok(Self(service)); @@ -93,10 +109,14 @@ impl FromRequestParts<()> for InjectMut { } #[async_trait] -impl FromRequestParts<()> for InjectAll { +impl FromRequestParts for InjectAll +where + T: ?Sized + 'static, + S: Send + Sync, +{ type Rejection = Infallible; - async fn from_request_parts(parts: &mut Parts, _state: &()) -> Result { + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { if let Some(provider) = parts.extensions.get::() { Ok(Self(provider.get_all::().collect())) } else { @@ -106,10 +126,14 @@ impl FromRequestParts<()> for InjectAll { } #[async_trait] -impl FromRequestParts<()> for InjectAllMut { +impl FromRequestParts for InjectAllMut +where + T: ?Sized + 'static, + S: Send + Sync, +{ type Rejection = Infallible; - async fn from_request_parts(parts: &mut Parts, _state: &()) -> Result { + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { if let Some(provider) = parts.extensions.get::() { Ok(Self(provider.get_all_mut::().collect())) } else { diff --git a/src/inject_keyed.rs b/src/inject_keyed.rs index 214e23a..2401a31 100644 --- a/src/inject_keyed.rs +++ b/src/inject_keyed.rs @@ -38,10 +38,14 @@ fn unregistered_type_with_key() -> String { } #[async_trait] -impl FromRequestParts<()> for TryInjectWithKey { +impl FromRequestParts for TryInjectWithKey +where + TSvc: ?Sized + 'static, + S: Send + Sync, +{ type Rejection = Infallible; - async fn from_request_parts(parts: &mut Parts, _state: &()) -> Result { + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { if let Some(provider) = parts.extensions.get::() { Ok(Self(provider.get_by_key::())) } else { @@ -51,10 +55,14 @@ impl FromRequestParts<()> for TryInjectWithKey FromRequestParts<()> for InjectWithKey { +impl FromRequestParts for InjectWithKey +where + TSvc: ?Sized + 'static, + S: Send + Sync, +{ type Rejection = (StatusCode, String); - async fn from_request_parts(parts: &mut Parts, _state: &()) -> Result { + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { if let Some(provider) = parts.extensions.get::() { if let Some(service) = provider.get_by_key::() { return Ok(Self(service)); @@ -69,10 +77,14 @@ impl FromRequestParts<()> for InjectWithKey FromRequestParts<()> for TryInjectWithKeyMut { +impl FromRequestParts for TryInjectWithKeyMut +where + TSvc: ?Sized + 'static, + S: Send + Sync, +{ type Rejection = Infallible; - async fn from_request_parts(parts: &mut Parts, _state: &()) -> Result { + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { if let Some(provider) = parts.extensions.get::() { Ok(Self(provider.get_by_key_mut::())) } else { @@ -82,10 +94,14 @@ impl FromRequestParts<()> for TryInjectWithKeyMut< } #[async_trait] -impl FromRequestParts<()> for InjectWithKeyMut { +impl FromRequestParts for InjectWithKeyMut +where + TSvc: ?Sized + 'static, + S: Send + Sync, +{ type Rejection = (StatusCode, String); - async fn from_request_parts(parts: &mut Parts, _state: &()) -> Result { + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { if let Some(provider) = parts.extensions.get::() { if let Some(service) = provider.get_by_key_mut::() { return Ok(Self(service)); @@ -100,10 +116,14 @@ impl FromRequestParts<()> for InjectWithKeyMut FromRequestParts<()> for InjectAllWithKey { +impl FromRequestParts for InjectAllWithKey +where + TSvc: ?Sized + 'static, + S: Send + Sync, +{ type Rejection = Infallible; - async fn from_request_parts(parts: &mut Parts, _state: &()) -> Result { + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { if let Some(provider) = parts.extensions.get::() { Ok(Self(provider.get_all_by_key::().collect())) } else { @@ -113,10 +133,14 @@ impl FromRequestParts<()> for InjectAllWithKey FromRequestParts<()> for InjectAllWithKeyMut { +impl FromRequestParts for InjectAllWithKeyMut +where + TSvc: ?Sized + 'static, + S: Send + Sync, +{ type Rejection = Infallible; - async fn from_request_parts(parts: &mut Parts, _state: &()) -> Result { + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { if let Some(provider) = parts.extensions.get::() { Ok(Self(provider.get_all_by_key_mut::().collect())) } else {