Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/OIDC #392

Merged
merged 2 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 16 additions & 54 deletions src/auth_middleware.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use std::collections::HashSet;
use std::future::Future;
use std::ops::Deref;
use std::pin::Pin;
use std::rc::Rc;
use std::sync::Mutex;
use std::time::{SystemTime, UNIX_EPOCH};


use actix::fut::{ok};
use futures_util::FutureExt;
use actix_web::{dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform}, Error, HttpMessage, web};
Expand All @@ -13,16 +15,14 @@ use base64::engine::general_purpose;
use futures_util::future::{LocalBoxFuture, Ready};
use dotenv::var;
use jsonwebtoken::{Algorithm, decode, DecodingKey, Validation};
use jsonwebtoken::jwk::Jwk;
use jsonwebtoken::jwk::{Jwk};
use log::info;
use serde_json::{from_str, Value};
use serde_json::{Value};
use crate::constants::inner_constants::{BASIC_AUTH, OIDC_AUTH, PASSWORD, USERNAME};
use crate::{DbPool};
use crate::models::user::User;
use sha256::digest;
use crate::models::oidc_model::{CustomJwk, CustomJwkSet};
use crate::mutex::LockResultExt;
use crate::service::jwkservice::JWKService;

use crate::utils::environment_variables::is_env_var_present_and_true;

pub struct AuthFilter {
Expand Down Expand Up @@ -154,50 +154,17 @@ impl<S, B> AuthFilterMiddleware<S> where B: 'static + MessageBody, S: 'static +
}
let token = token_res.unwrap().replace("Bearer ", "");

let start = SystemTime::now();
let since_the_epoch = start
.duration_since(UNIX_EPOCH)
.expect("Time went backwards").as_secs();

let response:CustomJwkSet;
let binding = req.app_data::<web::Data<Mutex<JWKService>>>().cloned().unwrap();
let mut jwk_service = binding.lock()
.ignore_poison();
match jwk_service.jwk.clone() {
Some(jwk)=>{
if since_the_epoch-jwk_service.timestamp>3600{
//refetch and update timestamp
info!("Renewing jwk set");
response = AuthFilter::get_jwk();
jwk_service.jwk = Some(response.clone());
jwk_service.timestamp = since_the_epoch
}
else{
info!("Using cached jwk set");
response = jwk;
}
}
None=>{
// Fetch on cold start
response = AuthFilter::get_jwk();
jwk_service.jwk = Some(response.clone());
jwk_service.timestamp = since_the_epoch
}
}
let jwk = req.app_data::<web::Data<Option<Jwk>>>().cloned().unwrap();

// Create a DecodingKey from a PEM-encoded RSA string

// Filter out all unknown algorithms
let response = response.clone().keys.into_iter().filter(|x| {
x.alg.eq(&"RS256")
}).collect::<Vec<CustomJwk>>();

let jwk = response.clone();
let custom_jwk = jwk.get(0).expect("Your jwk set needs to have RS256");
let key = DecodingKey::from_jwk(&jwk.as_ref().clone().unwrap()).unwrap();
let mut validation = Validation::new(Algorithm::RS256);
validation.aud = Some(req.app_data::<web::Data<HashSet<String>>>().unwrap().clone().into_inner()
.deref().clone());

let jwk_string = serde_json::to_string(&custom_jwk).unwrap();

let jwk = from_str::<Jwk>(&jwk_string).unwrap();
let key = DecodingKey::from_jwk(&jwk).unwrap();
let validation = Validation::new(Algorithm::RS256);
return match decode::<Value>(&token, &key, &validation) {
Ok(decoded) => {
let username = decoded.claims.get("preferred_username").unwrap().as_str().unwrap();
Expand Down Expand Up @@ -237,7 +204,8 @@ impl<S, B> AuthFilterMiddleware<S> where B: 'static + MessageBody, S: 'static +
}
}
},
_ => {
Err(e) => {
info!("Error decoding token: {:?}", e);
Box::pin(ok(req.error_response(ErrorForbidden("Forbidden"))
.map_into_right_body()))
}
Expand Down Expand Up @@ -271,12 +239,6 @@ impl AuthFilter{
(username.to_string(), password.to_string())
}

pub fn get_jwk() -> CustomJwkSet {
let jwk_uri = var("OIDC_JWKS").expect("OIDC_JWKS must be set");
reqwest::blocking::get(jwk_uri).unwrap()
.json::<CustomJwkSet>().unwrap()
}

pub fn basic_auth_login(rq: String) -> (String, String) {
let (u,p) = Self::extract_basic_auth(rq.as_str());

Expand Down
4 changes: 3 additions & 1 deletion src/constants/inner_constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,4 +135,6 @@ pub const MAX_FILE_TREE_DEPTH:i32 = 4;


pub const COMMON_USER_AGENT: &str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 \
(KHTML, like Gecko) Chrome/116.0.0.0 Safari/537.36";
(KHTML, like Gecko) Chrome/116.0.0.0 Safari/537.36";

pub const OIDC_JWKS: &str = "OIDC_JWKS";
71 changes: 70 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,17 @@ use clokwerk::{Scheduler, TimeUnits};
use std::sync::{Mutex};
use std::time::Duration;
use std::{env, thread};
use std::collections::HashSet;
use std::env::{args, var};
use std::io::Read;

use std::process::exit;
use actix_web::body::{BoxBody, EitherBody};
use log::{info};
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
use diesel::r2d2::{ConnectionManager};
use jsonwebtoken::jwk::{AlgorithmParameters, CommonParameters, Jwk, KeyAlgorithm, RSAKeyParameters, RSAKeyType};
use r2d2::{Pool};
use regex::Regex;
use tokio::task::spawn_blocking;
Expand All @@ -31,7 +34,7 @@ mod controllers;
#[cfg(sqlite)]
use crate::config::dbconfig::{ConnectionOptions};
use crate::config::dbconfig::{establish_connection, get_database_url};
use crate::constants::inner_constants::{BASIC_AUTH, CSS, JS, OIDC_AUTH, TELEGRAM_API_ENABLED, TELEGRAM_BOT_CHAT_ID, TELEGRAM_BOT_TOKEN};
use crate::constants::inner_constants::{BASIC_AUTH, CSS, JS, OIDC_AUTH, OIDC_CLIENT_ID, OIDC_JWKS, TELEGRAM_API_ENABLED, TELEGRAM_BOT_CHAT_ID, TELEGRAM_BOT_TOKEN};
use crate::controllers::api_doc::ApiDoc;
use crate::controllers::notification_controller::{
dismiss_notifications, get_unread_notifications,
Expand Down Expand Up @@ -60,6 +63,7 @@ mod models;
mod service;
use crate::gpodder::parametrization::get_client_parametrization;
use crate::gpodder::routes::get_gpodder_api;
use crate::models::oidc_model::{CustomJwk, CustomJwkSet};
use crate::models::podcasts::Podcast;
use crate::models::session::Session;
use crate::models::settings::Setting;
Expand Down Expand Up @@ -222,8 +226,71 @@ async fn main() -> std::io::Result<()> {
thread::sleep(Duration::from_millis(1000));
}
});

let key_param: Option<RSAKeyParameters>;
let mut hash = HashSet::new();
let jwk: Option<Jwk>;

match var(OIDC_JWKS) {
Ok(jwk_uri)=>{
let resp = reqwest::get(&jwk_uri).await.unwrap()
.json::<CustomJwkSet>().await;

match resp {
Ok(res) => {
let oidc = res
.clone()
.keys
.into_iter()
.filter(|x| x.alg.eq(&"RS256"))
.collect::<Vec<CustomJwk>>()
.first().cloned();

if oidc.is_none() {
panic!("No RS256 key found in JWKS")
}

key_param = Some(RSAKeyParameters {
e: oidc.clone().unwrap().e,
n: oidc.unwrap().n.clone(),
key_type: RSAKeyType::RSA,
});

jwk = Some(Jwk{
common: CommonParameters{
public_key_use: None,
key_id: None,
x509_url: None,
x509_chain: None,
x509_sha1_fingerprint: None,
key_operations: None,
key_algorithm: Some(KeyAlgorithm::RS256),
x509_sha256_fingerprint: None,
},
algorithm: AlgorithmParameters::RSA(key_param.clone().unwrap()),
});
},
Err(_) => {
panic!("Error downloading OIDC")
}
}
}
_ => {
key_param = None;
jwk = None;
}
}

if let Ok(client_id) = var(OIDC_CLIENT_ID) {
hash.insert(client_id);
}


HttpServer::new(move || {
App::new()
.app_data(Data::new(key_param.clone()))
.app_data(Data::new(jwk.clone()))
.app_data(Data::new(hash.clone()))
.service(redirect("/", var("SUB_DIRECTORY").unwrap()+"/ui/"))
.service(get_gpodder_api(environment_service.clone()))
.service(get_global_scope())
Expand Down Expand Up @@ -284,6 +351,8 @@ pub fn get_global_scope() -> Scope {

fn get_private_api() -> Scope<impl ServiceFactory<ServiceRequest, Config = (), Response = ServiceResponse<EitherBody<BoxBody>>, Error = actix_web::Error, InitError = ()>> {
let middleware = AuthFilter::new();


web::scope("")
.wrap(middleware)
.service(delete_playlist_item)
Expand Down
Loading