Skip to content

Commit

Permalink
Merge pull request #11 from cert-orangecyberdefense/ci/bug/GH-10
Browse files Browse the repository at this point in the history
Handle expired tokens
  • Loading branch information
Hugo-C authored Nov 7, 2022
2 parents dbb9850 + 435bd1f commit 5b4a8c9
Show file tree
Hide file tree
Showing 13 changed files with 395 additions and 45 deletions.
1 change: 1 addition & 0 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ jobs:
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
- run: rustup toolchain install stable --profile minimal
- id: rustcache
uses: Swatinem/rust-cache@v2
with:
Expand Down
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "ocd_datalake_rs"
version = "0.1.2"
version = "0.2.0"
edition = "2021"
authors = ["Orange Cyberdefense CERT"]
description = "Library wrapper around Orange Cyberdefense's Datalake API"
Expand All @@ -18,6 +18,7 @@ serde_json = "1.0.48"
config = "0.13.1"
strum = "0.24"
strum_macros = "0.24"
log = "0.4"

[dev-dependencies]
mockito = "0.31.0"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Check [open issues](https://github.com/cert-orangecyberdefense/ocd-datalake-rs/i
put in Cargo.toml:
```
[dependencies]
ocd_datalake_rs = "0.1.2"
ocd_datalake_rs = "0.2.0"
```

## Usage
Expand Down
1 change: 1 addition & 0 deletions conf/conf.prod.ron
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
base_url: "https://datalake.cert.orangecyberdefense.com/api/v2",
routes: RoutesSetting(
authentication: "{base_url}/auth/token/",
refresh_token: "{base_url}/auth/refresh-token/",
atom_values_extract: "{base_url}/mrti/threats/atom-values-extract/",
bulk_lookup: "{base_url}/mrti/threats/bulk-lookup/",
bulk_search: "{base_url}/mrti/bulk-search/",
Expand Down
1 change: 1 addition & 0 deletions examples/custom_config.ron
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
base_url: "https://custom_host",
routes: RoutesSetting(
authentication: "{base_url}/auth/token/",
refresh_token: "{base_url}/auth/refresh-token/",
atom_values_extract: "{base_url}/mrti/threats/atom-values-extract/",
bulk_lookup: "value not tested !",
bulk_search: "value not tested !",
Expand Down
2 changes: 1 addition & 1 deletion examples/custom_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ fn main() {
"password".to_string(),
DatalakeSetting::new(contents.as_str()),
);
let result = dtl.get_token();
let result = dtl.get_access_token();
let err = result.expect_err("Error expected");
println!("{}", err.to_string()); // print "HTTP Error Could not fetch API for url https://custom_host/auth/token/"
}
15 changes: 7 additions & 8 deletions src/bulk_search/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ pub fn create_bulk_search_task(dtl: &mut Datalake, query_hash: String, query_fie
body.insert("query_fields".to_string(), query_fields_serialized);

let request = dtl.client.post(&url)
.header("Authorization", dtl.get_token()?)
.header("Accept", "text/csv");
let resp = request.json(&body).send()?;
.header("Accept", "text/csv")
.json(&body);
let resp = dtl.run_with_authorization_token(&request)?;
let status_code = resp.status();
let json_response = resp.json::<Value>()?;

Expand Down Expand Up @@ -85,9 +85,9 @@ pub fn get_bulk_search_task(dtl: &mut Datalake, uuid: TaskUuid) -> Result<BulkSe
body.insert("task_uuid".to_string(), Value::String(uuid));

let request = dtl.client.post(&url)
.header("Authorization", dtl.get_token()?)
.header("Accept", "application/json");
let resp = request.json(&body).send()?;
.header("Accept", "application/json")
.json(&body);
let resp = dtl.run_with_authorization_token(&request)?;

// Prepare fields for error message
let status_code = resp.status();
Expand Down Expand Up @@ -120,9 +120,8 @@ pub fn download_bulk_search(dtl: &mut Datalake, uuid: TaskUuid) -> Result<String
let url = dtl.settings.routes().bulk_search_download.replace("{task_uuid}", &uuid);
let request = dtl.client.get(&url)
.timeout(Duration::from_secs(BULK_SEARCH_DOWNLOAD_TIMEOUT))
.header("Authorization", dtl.get_token()?)
.header("Accept", "text/csv");
let resp = request.send()?;
let resp = dtl.run_with_authorization_token(&request)?;
let status_code = resp.status();

if status_code == 202 {
Expand Down
4 changes: 3 additions & 1 deletion src/error/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::error::Error;
use std::fmt;
use reqwest::StatusCode;
use crate::DatalakeError::TimeoutError;
use crate::error::DatalakeError::{ApiError, AuthenticationError, HttpError, ParseError};
use crate::error::DatalakeError::{ApiError, AuthenticationError, HttpError, ParseError, UnexpectedLibError};

#[derive(Debug, PartialEq, Eq)]
pub struct DetailedError {
Expand Down Expand Up @@ -30,6 +30,7 @@ pub enum DatalakeError {
ApiError(DetailedError),
TimeoutError(DetailedError),
ParseError(DetailedError),
UnexpectedLibError(DetailedError),
}


Expand All @@ -47,6 +48,7 @@ impl fmt::Display for DatalakeError {
TimeoutError(err) => write!(f, "Timeout Error {}", err),
ApiError(err) => write!(f, "API Error {}", err),
ParseError(err) => write!(f, "Parse Error {}", err),
UnexpectedLibError(err) => write!(f, "Unexpected Library Error {}", err),
}
}
}
Expand Down
169 changes: 141 additions & 28 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,31 @@ pub mod bulk_search;
use std::collections::{BTreeMap, HashMap};
use std::thread;
use std::time::{Duration, Instant};
use reqwest::blocking::Client;
use log::info;
use reqwest::blocking::{Client, RequestBuilder, Response};
use reqwest::header::AUTHORIZATION;
use serde_json::{json, Map, Value};
use crate::bulk_search::{create_bulk_search_task, download_bulk_search, get_bulk_search_task, State};
use crate::error::{DatalakeError, DetailedError};
use crate::DatalakeError::{ApiError, AuthenticationError, TimeoutError};
use crate::error::DatalakeError::UnexpectedLibError;
pub use crate::setting::{DatalakeSetting, RoutesSetting};

pub const ATOM_VALUE_QUERY_FIELD: &str = "atom_value";

#[derive(Clone, Debug)]
struct Tokens { // Tokens are saved with the "Token " prefix
access: String,
refresh: String,
}

#[derive(Clone, Debug)]
pub struct Datalake {
settings: DatalakeSetting,
username: String,
password: String,
client: Client,
access_token: Option<String>,
tokens: Option<Tokens>,
}

impl Datalake {
Expand All @@ -32,13 +41,10 @@ impl Datalake {
username,
password,
client: Client::new(),
access_token: None,
tokens: None,
}
}
// TODO handle expired access / refresh token
fn retrieve_api_token(&self) -> Result<String, DatalakeError> {
let mut token = "Token ".to_string();

fn retrieve_api_tokens(&self) -> Result<Tokens, DatalakeError> {
let url = &self.settings.routes().authentication;
let auth_request = self.client.post(url);
let mut json_body = HashMap::new();
Expand All @@ -47,8 +53,56 @@ impl Datalake {
let resp = auth_request.json(&json_body).send()?;
let status_code = resp.status();
let json_resp = resp.json::<Value>()?;
let raw_token = json_resp["access_token"].as_str();
let op_token = match raw_token {
let raw_access_token = json_resp["access_token"].as_str();
let raw_refresh_token = json_resp["refresh_token"].as_str();
if raw_access_token.is_none() || raw_refresh_token.is_none() {
let err = DetailedError {
summary: "Invalid credentials".to_string(),
api_url: Some(url.to_string()),
api_response: Some(json_resp.to_string()),
api_status_code: Some(status_code),
};
return Err(AuthenticationError(err));
} // Else access and refresh token are guaranteed to be there
let access_token = format!("Token {}", raw_access_token.unwrap());
let refresh_token = format!("Token {}", raw_refresh_token.unwrap());

Ok(Tokens {
access: access_token,
refresh: refresh_token,
})
}

/// Cached version of retrieve_api_token that return a new token only if needed
pub fn get_access_token(&mut self) -> Result<String, DatalakeError> {
if self.tokens.is_none() {
self.tokens = Some(self.retrieve_api_tokens()?);
}
let access_token = self.tokens.as_ref().unwrap().clone().access;
Ok(access_token)
}

/// Return valid tokens, first by using the refresh token, then by using user credentials
fn refresh_tokens(&self) -> Result<Tokens, DatalakeError> {
info!("Refreshing the access token");
let url = &self.settings.routes().refresh_token;
let refresh_token = if let Some(tokens) = &self.tokens {
tokens.clone().refresh
} else {
let error_message = "Refresh tokens called despite no token set".to_string();
return Err(UnexpectedLibError(DetailedError::new(error_message)));
};
let request = self.client.post(url)
.header("Authorization", refresh_token.clone());

let resp = request.send()?;
let status_code = resp.status();
if status_code == 401 {
info!("Refresh token is expired, authenticating from the start");
return self.retrieve_api_tokens();
}
let json_resp = resp.json::<Value>()?;
let access_token = match json_resp["access_token"].as_str() {
None => {
let err = DetailedError {
summary: "Invalid credentials".to_string(),
Expand All @@ -58,26 +112,19 @@ impl Datalake {
};
return Err(AuthenticationError(err));
}
Some(op_token) => { op_token }
Some(raw_access_token) => format!("Token {}", raw_access_token)
};
token.push_str(op_token);
Ok(token)
}

/// Cached version of retrieve_api_token that return a new token only if needed
pub fn get_token(&mut self) -> Result<String, DatalakeError> {
if self.access_token.is_none() {
self.access_token = Some(self.retrieve_api_token()?);
}
let token = self.access_token.as_ref().unwrap().clone();
Ok(token)
Ok(Tokens {
access: access_token,
refresh: refresh_token,
})
}

/// Return the atom types based on the given atom_values
pub fn extract_atom_type(&mut self, atom_values: &[String]) -> Result<BTreeMap<String, String>, DatalakeError> {
let url = self.settings.routes().atom_values_extract.clone();
let mut request = self.client.post(&url);
request = request.header("Authorization", self.get_token()?);
let mut joined_atom_values = String::from(&atom_values[0]);
for value in atom_values.iter().skip(1) {
joined_atom_values.push(' ');
Expand All @@ -86,7 +133,8 @@ impl Datalake {
let json_body = json!({
"content": joined_atom_values,
});
let resp = request.json(&json_body).send()?;
request = request.json(&json_body);
let resp = self.run_with_authorization_token(&request)?;
let status_code = resp.status();
let json_resp = resp.json::<Value>()?;
let extracted_atom_types = Self::parse_extract_atom_type_result(&json_resp);
Expand All @@ -103,6 +151,39 @@ impl Datalake {
}
}

/// Run a request by injecting authorization token. Automatically retry once if the token is expired
fn run_with_authorization_token(&mut self, request: &RequestBuilder) -> Result<Response, DatalakeError> {
let Some(mut cloned_request) = request.try_clone() else {
return Err(UnexpectedLibError(DetailedError::new("Can't clone given request".to_string())))
};
cloned_request = cloned_request.header(AUTHORIZATION, self.get_access_token()?);
let mut response = cloned_request.send()?;
let mut status_code = response.status();
if status_code != 401 {
return Ok(response);
}

// Else retry
self.tokens = Some(self.refresh_tokens()?);
let refreshed_token = self.get_access_token()?;
let Some(mut retry_request) = request.try_clone() else {
return Err(UnexpectedLibError(DetailedError::new("Can't clone given request".to_string())))
};
retry_request = retry_request.header(AUTHORIZATION, refreshed_token);
response = retry_request.send()?;
status_code = response.status();
if status_code == 401 {
Err(AuthenticationError(DetailedError {
summary: "401 response despite refreshed token".to_string(),
api_url: Some(response.url().to_string()),
api_response: response.text().ok(),
api_status_code: Some(status_code),
}))
} else {
Ok(response) // Refreshing the token was enough to yield a correct response
}
}

fn parse_extract_atom_type_result(json_resp: &Value) -> Option<BTreeMap<String, String>> {
let results_value = json_resp.get("results")?;
let results = results_value.as_object()?;
Expand Down Expand Up @@ -150,8 +231,8 @@ impl Datalake {
}

/// Bulk lookup a chunk of atom_values
fn bulk_lookup_chunk(&mut self, atom_values: &[String]) -> Result<String, DatalakeError> {
// Construct the body by identifying the atom types
fn bulk_lookup_chunk(&mut self, atom_values: &[String]) -> Result<String, DatalakeError> {
// Construct the body by identifying the atom types
let extracted = self.extract_atom_type(atom_values)?;
let mut body = Map::new();
body.insert("hashkey_only".to_string(), Value::Bool(false));
Expand All @@ -168,11 +249,11 @@ impl Datalake {
}

let request = self.client.post(&self.settings.routes().bulk_lookup)
.header("Authorization", self.get_token()?)
.header("Accept", "text/csv");
let csv_resp = request.json(&body).send()?.text()?;
.header("Accept", "text/csv")
.json(&body);
let csv_resp = self.run_with_authorization_token(&request)?.text()?;
Ok(csv_resp)
}
}

/// Retrieve all the results of a query using its query_hash.
///
Expand Down Expand Up @@ -207,6 +288,8 @@ impl Datalake {
#[cfg(test)]
mod tests {
use crate::{Datalake, DatalakeSetting};
use crate::error::DatalakeError::UnexpectedLibError;
use crate::error::DetailedError;

#[test]
fn test_create_datalake_with_prod_config() {
Expand All @@ -231,4 +314,34 @@ mod tests {

assert_eq!(dtl.settings.routes().authentication, "https://ti.extranet.mrti-center.com/api/v2/auth/token/");
}

#[test]
fn test_run_with_authorization_token_fail_on_unclonable_request() {
let preprod_setting = DatalakeSetting::preprod();
let mut dtl = Datalake::new(
"username".to_string(),
"password".to_string(),
preprod_setting,
);
// Create a random request
let mut request = dtl.client.post(&dtl.settings.routes().authentication);
// Set a streaming body that can't be cloned
request = request.body(reqwest::blocking::Body::new(std::io::empty()));
let err = dtl.run_with_authorization_token(&request).err().unwrap();
let expected_error_message = "Can't clone given request".to_string();
assert_eq!(err, UnexpectedLibError(DetailedError::new(expected_error_message)));
}

#[test]
fn test_refresh_tokens_with_no_existing_tokens() {
let preprod_setting = DatalakeSetting::preprod();
let dtl = Datalake::new(
"username".to_string(),
"password".to_string(),
preprod_setting,
);
let err = dtl.refresh_tokens().err().unwrap();
let expected_error_message = "Refresh tokens called despite no token set".to_string();
assert_eq!(err, UnexpectedLibError(DetailedError::new(expected_error_message)));
}
}
2 changes: 2 additions & 0 deletions src/setting/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use serde::Deserialize;
#[derive(Deserialize, Clone, Debug)]
pub struct RoutesSetting {
pub authentication: String,
pub refresh_token: String,
pub atom_values_extract: String,
pub bulk_lookup: String,
pub bulk_search: String,
Expand Down Expand Up @@ -40,6 +41,7 @@ impl DatalakeSetting {
fn replace_base_url(&mut self) {
self.formatted_routes = Some(RoutesSetting {
authentication: self.routes.authentication.replace("{base_url}", &self.base_url),
refresh_token: self.routes.refresh_token.replace("{base_url}", &self.base_url),
atom_values_extract: self.routes.atom_values_extract.replace("{base_url}", &self.base_url),
bulk_lookup: self.routes.bulk_lookup.replace("{base_url}", &self.base_url),
bulk_search: self.routes.bulk_search.replace("{base_url}", &self.base_url),
Expand Down
Loading

0 comments on commit 5b4a8c9

Please sign in to comment.