Skip to content

Commit

Permalink
Removed unwraps in middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
SamTV12345 committed Dec 31, 2024
1 parent f7736b9 commit 1612399
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 92 deletions.
11 changes: 6 additions & 5 deletions src/adapters/api/controllers/device_controller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ pub async fn post_device(
) -> Result<HttpResponse, CustomError> {
match opt_flag {
Some(flag) => {
let username = query.clone().0;
let deviceid = query.clone().1;
if flag.username != username {
let username = &query.0;
let deviceid = &query.1;
if &flag.username != username {
return Err(CustomError::Forbidden);
}

Expand All @@ -46,10 +46,11 @@ pub async fn get_devices_of_user(
) -> Result<HttpResponse, CustomError> {
match opt_flag {
Some(flag) => {
if flag.username != query.clone() {
let user_query = query.into_inner();
if flag.username != user_query {
return Err(CustomError::Forbidden);
}
let devices = DeviceService::query_by_username(query.clone())?;
let devices = DeviceService::query_by_username(&user_query)?;

let dtos = devices
.iter()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ impl DeviceRepository for DeviceRepositoryImpl {
.map(|device_entity: DeviceEntity| device_entity.into()))
}

fn get_devices_of_user(username_to_find: String) -> Result<Vec<Device>, CustomError> {
fn get_devices_of_user(username_to_find: &str) -> Result<Vec<Device>, CustomError> {
devices
.filter(username.eq(username_to_find))
.load::<DeviceEntity>(&mut get_connection())
Expand Down
2 changes: 1 addition & 1 deletion src/application/repositories/device_repository.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ use crate::utils::error::CustomError;

pub trait DeviceRepository {
fn create(device: Device) -> Result<Device, CustomError>;
fn get_devices_of_user(username: String) -> Result<Vec<Device>, CustomError>;
fn get_devices_of_user(username: &str) -> Result<Vec<Device>, CustomError>;
fn delete_by_username(username: &str) -> Result<(), CustomError>;
}
2 changes: 1 addition & 1 deletion src/application/services/device/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ impl CreateUseCase for DeviceService {
}

impl QueryUseCase for DeviceService {
fn query_by_username(username: String) -> Result<Vec<Device>, CustomError> {
fn query_by_username(username: &str) -> Result<Vec<Device>, CustomError> {
DeviceRepositoryImpl::get_devices_of_user(username)
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/application/usecases/devices/query_use_case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ use crate::domain::models::device::model::Device;
use crate::utils::error::CustomError;

pub trait QueryUseCase {
fn query_by_username(username: String) -> Result<Vec<Device>, CustomError>;
fn query_by_username(username: &str) -> Result<Vec<Device>, CustomError>;
}
208 changes: 127 additions & 81 deletions src/auth_middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use log::info;
use serde_json::Value;
use sha256::digest;
use crate::utils::error::CustomError;

pub struct AuthFilter {}

Expand Down Expand Up @@ -97,33 +98,44 @@ where
match opt_auth_header {
Some(header) => match header.to_str() {
Ok(auth) => {
let (username, password) = AuthFilter::extract_basic_auth(auth);
let result_of_check = AuthFilter::extract_basic_auth(auth);
if result_of_check.is_err() {
return Box::pin(ok(req
.error_response(ErrorUnauthorized("Unauthorized"))
.map_into_right_body()));
}

let (username, password) = result_of_check.expect("Error extracting basic auth");
let found_user = User::find_by_username(username.as_str());

if found_user.is_err() {
return Box::pin(ok(req
.error_response(ErrorUnauthorized("Unauthorized"))
.map_into_right_body()));
}
let unwrapped_user = found_user.unwrap();
let unwrapped_user = found_user.expect("Error unwrapping user");

if let Some(admin_username) = ENVIRONMENT_SERVICE.username.clone() {
if unwrapped_user.username.clone() == admin_username {
return match ENVIRONMENT_SERVICE.password.is_some()
&& digest(password) == ENVIRONMENT_SERVICE.password.clone().unwrap()
{
true => {
req.extensions_mut().insert(unwrapped_user);
let service = Rc::clone(&self.service);
async move {
service.call(req).await.map(|res| res.map_into_left_body())
}
.boxed_local()
return if let Some(password) = &ENVIRONMENT_SERVICE.password {
if &digest(password) == password {
req.extensions_mut().insert(unwrapped_user);
let service = Rc::clone(&self.service);
async move {
service.call(req).await.map(|res| res.map_into_left_body())
}
.boxed_local()
}
else {
Box::pin(ok(req
.error_response(ErrorUnauthorized("Unauthorized"))
.map_into_right_body()))
}
false => Box::pin(ok(req
.error_response(ErrorUnauthorized("Unauthorized"))
.map_into_right_body())),
};
} else {
Box::pin(ok(req
.error_response(ErrorUnauthorized("Unauthorized"))
.map_into_right_body()))
}
}
}

Expand All @@ -149,76 +161,100 @@ where
}

fn handle_oidc_auth(&self, req: ServiceRequest) -> MyFuture<B, Error> {
let token_res = req.headers().get("Authorization").unwrap().to_str();
let token_res = match req.headers().get("Authorization") {
Some(token) => Ok(token.to_str()),
None => Err(ErrorUnauthorized("Unauthorized")),
};

if token_res.is_err() {
return Box::pin(ok(req
.error_response(ErrorUnauthorized("Unauthorized"))
.map_into_right_body()));
}
let token = token_res.unwrap().replace("Bearer ", "");

let jwk = req.app_data::<web::Data<Option<Jwk>>>().cloned().unwrap();

// Create a DecodingKey from a PEM-encoded RSA string

let key = DecodingKey::from_jwk(&jwk.as_ref().clone().unwrap()).unwrap();
let mut validation = Validation::new(Algorithm::RS256);
validation.aud = Some(
req.app_data::<web::Data<HashSet<String>>>()
.unwrap()
.clone()
.into_inner()
.deref()
.clone(),
);

match decode::<Value>(&token, &key, &validation) {
Ok(decoded) => {
let username = decoded
.claims
.get("preferred_username")

if let Ok(Ok(token)) = token_res {
let token = token.replace("Bearer ", "");
let jwk = req.app_data::<web::Data<Option<Jwk>>>().cloned().unwrap();
let key = DecodingKey::from_jwk(&jwk.as_ref().clone().unwrap()).unwrap();
let mut validation = Validation::new(Algorithm::RS256);
validation.aud = Some(
req.app_data::<web::Data<HashSet<String>>>()
.unwrap()
.as_str()
.unwrap();
let found_user = User::find_by_username(username);
let service = Rc::clone(&self.service);

match found_user {
Ok(user) => {
req.extensions_mut().insert(user);
async move { service.call(req).await.map(|res| res.map_into_left_body()) }
.boxed_local()
}
Err(_) => {
// User is authenticated so we can onboard him if he is new
let user = User::insert_user(&mut User {
id: 0,
username: decoded
.clone()
.into_inner()
.deref()
.clone(),
);
match decode::<Value>(&token, &key, &validation) {
Ok(decoded) => {
let username = decoded
.claims
.get("preferred_username")
.unwrap()
.as_str()
.unwrap();
let found_user = User::find_by_username(username);
let service = Rc::clone(&self.service);

match found_user {
Ok(user) => {
req.extensions_mut().insert(user);
async move { service.call(req).await.map(|res| res.map_into_left_body()) }
.boxed_local()
}
Err(_) => {
let preferred_username_claim = decoded
.claims
.get("preferred_username")
.unwrap()
.as_str()
.unwrap()
.to_string(),
role: "user".to_string(),
password: None,
explicit_consent: false,
created_at: chrono::Utc::now().naive_utc(),
api_key: None,
})
.expect("Error inserting user");
req.extensions_mut().insert(user);
async move { service.call(req).await.map(|res| res.map_into_left_body()) }
.boxed_local()
.get("preferred_username");

if preferred_username_claim.is_none() {
return Box::pin(ok(req
.error_response(ErrorForbidden("Forbidden"))
.map_into_right_body()));
}

let content = preferred_username_claim.expect("Preferred username \
claim is \
none").as_str();

if content.is_none() {
return Box::pin(ok(req
.error_response(ErrorForbidden("Forbidden"))
.map_into_right_body()));
}

let preferred_username = content.expect("Preferred username is none");

// User is authenticated so we can onboard him if he is new
let user = User::insert_user(&mut User {
id: 0,
username: preferred_username.to_string(),
role: "user".to_string(),
password: None,
explicit_consent: false,
created_at: chrono::Utc::now().naive_utc(),
api_key: None,
})
.expect("Error inserting user");
req.extensions_mut().insert(user);
async move { service.call(req).await.map(|res| res.map_into_left_body()) }
.boxed_local()
}
}
}
}
Err(e) => {
info!("Error decoding token: {:?}", e);
Err(e)=>{
info!("Error decoding token: {:?}", e);
Box::pin(ok(req
.error_response(ErrorForbidden("Forbidden"))
.map_into_right_body()))
}
}
} else {
// Create a DecodingKey from a PEM-encoded RSA string
info!("Error decoding token");
Box::pin(ok(req
.error_response(ErrorForbidden("Forbidden"))
.map_into_right_body()))
}
}
}

Expand All @@ -230,7 +266,17 @@ where
}

fn handle_proxy_auth(&self, req: ServiceRequest) -> MyFuture<B, Error> {
let config = ENVIRONMENT_SERVICE.reverse_proxy_config.clone().unwrap();
let config = &ENVIRONMENT_SERVICE.reverse_proxy_config;

if config.is_none() {
info!("Reverse proxy is enabled but no config is provided");
return Box::pin(ok(req
.error_response(ErrorForbidden("Forbidden"))
.map_into_right_body()));
}

let config = config.clone().expect("Reverse proxy config is not set");


let header_val = req.headers().get(config.header_name);

Expand Down Expand Up @@ -287,21 +333,21 @@ where
}

impl AuthFilter {
pub fn extract_basic_auth(auth: &str) -> (String, String) {
pub fn extract_basic_auth(auth: &str) -> Result<(String, String), CustomError> {
let auth = auth.to_string();
let auth = auth.split(' ').collect::<Vec<&str>>();
let auth = auth[1];
let auth = general_purpose::STANDARD.decode(auth).unwrap();
let auth = general_purpose::STANDARD.decode(auth).map_err(|_| CustomError::Forbidden)?;
let auth = String::from_utf8(auth).unwrap();
let auth = auth.split(':').collect::<Vec<&str>>();
let username = auth[0];
let password = auth[1];
(username.to_string(), password.to_string())
Ok((username.to_string(), password.to_string()))
}

pub fn basic_auth_login(rq: String) -> (String, String) {
let (u, p) = Self::extract_basic_auth(rq.as_str());
pub fn basic_auth_login(rq: String) -> Result<(String, String), CustomError> {
let (u, p) = Self::extract_basic_auth(rq.as_str())?;

(u.to_string(), p.to_string())
Ok((u.to_string(), p.to_string()))
}
}
4 changes: 2 additions & 2 deletions src/gpodder/auth/authentication.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub async fn login(
rq: HttpRequest,
) -> Result<HttpResponse, CustomError> {
// If cookie is already set, return it
if let Some(cookie) = rq.clone().cookie("sessionid") {
if let Some(cookie) = rq.cookie("sessionid") {
let session = cookie.value();
let opt_session = Session::find_by_session_id(session);
if let Ok(unwrapped_session) = opt_session {
Expand Down Expand Up @@ -87,7 +87,7 @@ fn handle_gpodder_basic_auth(
let authorization = opt_authorization.unwrap().to_str().unwrap();

let unwrapped_username = username.into_inner();
let (username_basic, password) = AuthFilter::basic_auth_login(authorization.to_string());
let (username_basic, password) = AuthFilter::basic_auth_login(authorization.to_string())?;
if username_basic != unwrapped_username {
return Err(CustomError::Forbidden);
}
Expand Down

0 comments on commit 1612399

Please sign in to comment.