Skip to content

Commit

Permalink
Retain generic state parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
commonsensesoftware committed Dec 24, 2023
1 parent 3b1411f commit f0e1596
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 25 deletions.
2 changes: 1 addition & 1 deletion src/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "more-di-axum"
version = "0.1.0"
version = "0.2.0"
edition = "2021"
authors = ["Chris Martinez <[email protected]>"]
description = "Provides support dependency injection (DI) for Axum"
Expand Down
48 changes: 36 additions & 12 deletions src/inject.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,14 @@ fn unregistered_type<T: ?Sized>() -> String {
}

#[async_trait]
impl<T: ?Sized + 'static> FromRequestParts<()> for TryInject<T> {
impl<T, S> FromRequestParts<S> for TryInject<T>
where
T: ?Sized + 'static,
S: Send + Sync,
{
type Rejection = Infallible;

async fn from_request_parts(parts: &mut Parts, _state: &()) -> Result<Self, Self::Rejection> {
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if let Some(provider) = parts.extensions.get::<ServiceProvider>() {
Ok(Self(provider.get::<T>()))
} else {
Expand All @@ -50,10 +54,14 @@ impl<T: ?Sized + 'static> FromRequestParts<()> for TryInject<T> {
}

#[async_trait]
impl<T: ?Sized + 'static> FromRequestParts<()> for Inject<T> {
impl<T, S> FromRequestParts<S> for Inject<T>
where
T: ?Sized + 'static,
S: Send + Sync,
{
type Rejection = (StatusCode, String);

async fn from_request_parts(parts: &mut Parts, _state: &()) -> Result<Self, Self::Rejection> {
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if let Some(provider) = parts.extensions.get::<ServiceProvider>() {
if let Some(service) = provider.get::<T>() {
return Ok(Self(service));
Expand All @@ -65,10 +73,14 @@ impl<T: ?Sized + 'static> FromRequestParts<()> for Inject<T> {
}

#[async_trait]
impl<T: ?Sized + 'static> FromRequestParts<()> for TryInjectMut<T> {
impl<T, S> FromRequestParts<S> for TryInjectMut<T>
where
T: ?Sized + 'static,
S: Send + Sync,
{
type Rejection = Infallible;

async fn from_request_parts(parts: &mut Parts, _state: &()) -> Result<Self, Self::Rejection> {
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if let Some(provider) = parts.extensions.get::<ServiceProvider>() {
Ok(Self(provider.get_mut::<T>()))
} else {
Expand All @@ -78,10 +90,14 @@ impl<T: ?Sized + 'static> FromRequestParts<()> for TryInjectMut<T> {
}

#[async_trait]
impl<T: ?Sized + 'static> FromRequestParts<()> for InjectMut<T> {
impl<T, S> FromRequestParts<S> for InjectMut<T>
where
T: ?Sized + 'static,
S: Send + Sync,
{
type Rejection = (StatusCode, String);

async fn from_request_parts(parts: &mut Parts, _state: &()) -> Result<Self, Self::Rejection> {
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if let Some(provider) = parts.extensions.get::<ServiceProvider>() {
if let Some(service) = provider.get_mut::<T>() {
return Ok(Self(service));
Expand All @@ -93,10 +109,14 @@ impl<T: ?Sized + 'static> FromRequestParts<()> for InjectMut<T> {
}

#[async_trait]
impl<T: ?Sized + 'static> FromRequestParts<()> for InjectAll<T> {
impl<T, S> FromRequestParts<S> for InjectAll<T>
where
T: ?Sized + 'static,
S: Send + Sync,
{
type Rejection = Infallible;

async fn from_request_parts(parts: &mut Parts, _state: &()) -> Result<Self, Self::Rejection> {
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if let Some(provider) = parts.extensions.get::<ServiceProvider>() {
Ok(Self(provider.get_all::<T>().collect()))
} else {
Expand All @@ -106,10 +126,14 @@ impl<T: ?Sized + 'static> FromRequestParts<()> for InjectAll<T> {
}

#[async_trait]
impl<T: ?Sized + 'static> FromRequestParts<()> for InjectAllMut<T> {
impl<T, S> FromRequestParts<S> for InjectAllMut<T>
where
T: ?Sized + 'static,
S: Send + Sync,
{
type Rejection = Infallible;

async fn from_request_parts(parts: &mut Parts, _state: &()) -> Result<Self, Self::Rejection> {
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if let Some(provider) = parts.extensions.get::<ServiceProvider>() {
Ok(Self(provider.get_all_mut::<T>().collect()))
} else {
Expand Down
48 changes: 36 additions & 12 deletions src/inject_keyed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,14 @@ fn unregistered_type_with_key<TKey, TSvc: ?Sized>() -> String {
}

#[async_trait]
impl<TKey, TSvc: ?Sized + 'static> FromRequestParts<()> for TryInjectWithKey<TKey, TSvc> {
impl<TKey, TSvc, S> FromRequestParts<S> for TryInjectWithKey<TKey, TSvc>
where
TSvc: ?Sized + 'static,
S: Send + Sync,
{
type Rejection = Infallible;

async fn from_request_parts(parts: &mut Parts, _state: &()) -> Result<Self, Self::Rejection> {
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if let Some(provider) = parts.extensions.get::<ServiceProvider>() {
Ok(Self(provider.get_by_key::<TKey, TSvc>()))
} else {
Expand All @@ -51,10 +55,14 @@ impl<TKey, TSvc: ?Sized + 'static> FromRequestParts<()> for TryInjectWithKey<TKe
}

#[async_trait]
impl<TKey, TSvc: ?Sized + 'static> FromRequestParts<()> for InjectWithKey<TKey, TSvc> {
impl<TKey, TSvc, S> FromRequestParts<S> for InjectWithKey<TKey, TSvc>
where
TSvc: ?Sized + 'static,
S: Send + Sync,
{
type Rejection = (StatusCode, String);

async fn from_request_parts(parts: &mut Parts, _state: &()) -> Result<Self, Self::Rejection> {
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if let Some(provider) = parts.extensions.get::<ServiceProvider>() {
if let Some(service) = provider.get_by_key::<TKey, TSvc>() {
return Ok(Self(service));
Expand All @@ -69,10 +77,14 @@ impl<TKey, TSvc: ?Sized + 'static> FromRequestParts<()> for InjectWithKey<TKey,
}

#[async_trait]
impl<TKey, TSvc: ?Sized + 'static> FromRequestParts<()> for TryInjectWithKeyMut<TKey, TSvc> {
impl<TKey, TSvc, S> FromRequestParts<S> for TryInjectWithKeyMut<TKey, TSvc>
where
TSvc: ?Sized + 'static,
S: Send + Sync,
{
type Rejection = Infallible;

async fn from_request_parts(parts: &mut Parts, _state: &()) -> Result<Self, Self::Rejection> {
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if let Some(provider) = parts.extensions.get::<ServiceProvider>() {
Ok(Self(provider.get_by_key_mut::<TKey, TSvc>()))
} else {
Expand All @@ -82,10 +94,14 @@ impl<TKey, TSvc: ?Sized + 'static> FromRequestParts<()> for TryInjectWithKeyMut<
}

#[async_trait]
impl<TKey, TSvc: ?Sized + 'static> FromRequestParts<()> for InjectWithKeyMut<TKey, TSvc> {
impl<TKey, TSvc, S> FromRequestParts<S> for InjectWithKeyMut<TKey, TSvc>
where
TSvc: ?Sized + 'static,
S: Send + Sync,
{
type Rejection = (StatusCode, String);

async fn from_request_parts(parts: &mut Parts, _state: &()) -> Result<Self, Self::Rejection> {
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if let Some(provider) = parts.extensions.get::<ServiceProvider>() {
if let Some(service) = provider.get_by_key_mut::<TKey, TSvc>() {
return Ok(Self(service));
Expand All @@ -100,10 +116,14 @@ impl<TKey, TSvc: ?Sized + 'static> FromRequestParts<()> for InjectWithKeyMut<TKe
}

#[async_trait]
impl<TKey, TSvc: ?Sized + 'static> FromRequestParts<()> for InjectAllWithKey<TKey, TSvc> {
impl<TKey, TSvc, S> FromRequestParts<S> for InjectAllWithKey<TKey, TSvc>
where
TSvc: ?Sized + 'static,
S: Send + Sync,
{
type Rejection = Infallible;

async fn from_request_parts(parts: &mut Parts, _state: &()) -> Result<Self, Self::Rejection> {
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if let Some(provider) = parts.extensions.get::<ServiceProvider>() {
Ok(Self(provider.get_all_by_key::<TKey, TSvc>().collect()))
} else {
Expand All @@ -113,10 +133,14 @@ impl<TKey, TSvc: ?Sized + 'static> FromRequestParts<()> for InjectAllWithKey<TKe
}

#[async_trait]
impl<TKey, TSvc: ?Sized + 'static> FromRequestParts<()> for InjectAllWithKeyMut<TKey, TSvc> {
impl<TKey, TSvc, S> FromRequestParts<S> for InjectAllWithKeyMut<TKey, TSvc>
where
TSvc: ?Sized + 'static,
S: Send + Sync,
{
type Rejection = Infallible;

async fn from_request_parts(parts: &mut Parts, _state: &()) -> Result<Self, Self::Rejection> {
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if let Some(provider) = parts.extensions.get::<ServiceProvider>() {
Ok(Self(provider.get_all_by_key_mut::<TKey, TSvc>().collect()))
} else {
Expand Down

0 comments on commit f0e1596

Please sign in to comment.