From b0c3c7031ca04a3235239587ced76a9bbfd45c05 Mon Sep 17 00:00:00 2001 From: Flavien David Date: Mon, 22 Jul 2024 08:47:57 +0200 Subject: [PATCH] Improve oauth service provider errors (#6364) * Improve provider errors in oauth service * :scissors: --- core/src/oauth/connection.rs | 78 ++++++++++++++++++++++-- core/src/oauth/providers/confluence.rs | 14 +++-- core/src/oauth/providers/github.rs | 10 ++- core/src/oauth/providers/google_drive.rs | 14 +++-- core/src/oauth/providers/intercom.rs | 16 +++-- core/src/oauth/providers/notion.rs | 16 +++-- core/src/oauth/providers/slack.rs | 37 +++++------ core/src/oauth/providers/utils.rs | 65 ++++++++++++-------- 8 files changed, 179 insertions(+), 71 deletions(-) diff --git a/core/src/oauth/connection.rs b/core/src/oauth/connection.rs index e7e271586f76..04f3a972661f 100644 --- a/core/src/oauth/connection.rs +++ b/core/src/oauth/connection.rs @@ -23,6 +23,8 @@ use std::time::Duration; use std::{env, fmt}; use tracing::{error, info}; +use super::providers::utils::ProviderHttpRequestError; + // We hold the lock for at most 15s. In case of panic preventing the lock from being released, this // is the maximum time the lock will be held. static REDIS_LOCK_TTL_SECONDS: u64 = 15; @@ -40,6 +42,8 @@ lazy_static! { }; } +// API Error types. + // Define the ErrorKind enum with serde attributes for serialization #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] @@ -128,14 +132,33 @@ pub trait Provider { connection: &Connection, code: &str, redirect_uri: &str, - ) -> Result; + ) -> Result; - async fn refresh(&self, connection: &Connection) -> Result; + async fn refresh(&self, connection: &Connection) -> Result; // This method scrubs raw_json to remove information that should not exfill `oauth`, in // particular the `refresh_token`. By convetion the `access_token` should be scrubbed as well // to prevent users from relying in the raw_json to access it. fn scrubbed_raw_json(&self, raw_json: &serde_json::Value) -> Result; + + // Default implementation for handling errors. + fn handle_provider_request_error(&self, error: ProviderHttpRequestError) -> ProviderError { + match error { + ProviderHttpRequestError::NetworkError(e) => ProviderError::UnknownError(e.to_string()), + ProviderHttpRequestError::Timeout => ProviderError::TimeoutError, + ProviderHttpRequestError::RequestFailed { + provider, + status, + message: _, + } => ProviderError::UnknownError(format!( + "Request failed for provider {}. Status: {}.", + provider, status + )), + ProviderHttpRequestError::InvalidResponse(e) => { + ProviderError::UnknownError(e.to_string()) + } + } + } } pub fn provider(t: ConnectionProvider) -> Box { @@ -149,6 +172,44 @@ pub fn provider(t: ConnectionProvider) -> Box { } } +// Internal Error types. + +#[derive(Debug, thiserror::Error)] +pub enum ProviderError { + #[error("Action not supported: {0}.")] + ActionNotSupportedError(String), + // TODO(2024-07-19 flav) Implement InvalidToken. + #[error("Timeout error.")] + TimeoutError, + #[error("Unknown error: {0}.")] + UnknownError(String), + #[error("Internal error: {0}.")] + InternalError(anyhow::Error), +} + +impl From for ProviderError { + fn from(error: anyhow::Error) -> Self { + ProviderError::InternalError(error) + } +} + +impl From<&ProviderError> for ConnectionError { + fn from(err: &ProviderError) -> Self { + match err { + ProviderError::ActionNotSupportedError(_) + | ProviderError::TimeoutError + | ProviderError::UnknownError(_) => ConnectionError { + code: ConnectionErrorCode::ProviderFinalizationError, + message: err.to_string(), + }, + ProviderError::InternalError(_) => ConnectionError { + code: ConnectionErrorCode::InternalError, + message: "Failed to finalize connection with provider".to_string(), + }, + } + } +} + #[derive(Debug, Clone, Copy, Serialize, PartialEq, Deserialize)] #[serde(rename_all = "snake_case")] pub enum ConnectionStatus { @@ -558,10 +619,15 @@ impl Connection { provider = ?self.provider, "Failed to finalize connection", ); - Err(ConnectionError { - code: ConnectionErrorCode::ProviderFinalizationError, - message: "Failed to finalize connection with provider".to_string(), - }) + + if let Some(provider_error) = e.downcast_ref::() { + Err(ConnectionError::from(provider_error)) + } else { + Err(ConnectionError { + code: ConnectionErrorCode::InternalError, + message: "Failed to finalize connection with provider".to_string(), + }) + } } } } diff --git a/core/src/oauth/providers/confluence.rs b/core/src/oauth/providers/confluence.rs index 8529a593359b..7bba80f7fbaa 100644 --- a/core/src/oauth/providers/confluence.rs +++ b/core/src/oauth/providers/confluence.rs @@ -1,7 +1,7 @@ use crate::{ oauth::{ connection::{ - Connection, ConnectionProvider, FinalizeResult, Provider, RefreshResult, + Connection, ConnectionProvider, FinalizeResult, Provider, ProviderError, RefreshResult, PROVIDER_TIMEOUT_SECONDS, }, providers::utils::execute_request, @@ -40,7 +40,7 @@ impl Provider for ConfluenceConnectionProvider { _connection: &Connection, code: &str, redirect_uri: &str, - ) -> Result { + ) -> Result { let body = json!({ "grant_type": "authorization_code", "client_id": *OAUTH_CONFLUENCE_CLIENT_ID, @@ -54,7 +54,9 @@ impl Provider for ConfluenceConnectionProvider { .header("Content-Type", "application/json") .json(&body); - let raw_json = execute_request(ConnectionProvider::Confluence, req).await?; + let raw_json = execute_request(ConnectionProvider::Confluence, req) + .await + .map_err(|e| self.handle_provider_request_error(e))?; let access_token = match raw_json["access_token"].as_str() { Some(token) => token, @@ -92,7 +94,7 @@ impl Provider for ConfluenceConnectionProvider { /// Note: Confluence hard expires refresh_tokens after 360 days. /// Confluence expires access_tokens after 1 hour. /// Confluence expires refresh_tokens after 30 days of inactivity. - async fn refresh(&self, connection: &Connection) -> Result { + async fn refresh(&self, connection: &Connection) -> Result { let refresh_token = match connection.unseal_refresh_token() { Ok(Some(token)) => token, Ok(None) => Err(anyhow!("Missing `refresh_token` in Confluence connection"))?, @@ -111,7 +113,9 @@ impl Provider for ConfluenceConnectionProvider { .header("Content-Type", "application/json") .json(&body); - let raw_json = execute_request(ConnectionProvider::Confluence, req).await?; + let raw_json = execute_request(ConnectionProvider::Confluence, req) + .await + .map_err(|e| self.handle_provider_request_error(e))?; let access_token = match raw_json["access_token"].as_str() { Some(token) => token, diff --git a/core/src/oauth/providers/github.rs b/core/src/oauth/providers/github.rs index 3e5d538b669f..853df7b2586b 100644 --- a/core/src/oauth/providers/github.rs +++ b/core/src/oauth/providers/github.rs @@ -1,6 +1,8 @@ use crate::{ oauth::{ - connection::{Connection, ConnectionProvider, FinalizeResult, Provider, RefreshResult}, + connection::{ + Connection, ConnectionProvider, FinalizeResult, Provider, ProviderError, RefreshResult, + }, providers::utils::execute_request, }, utils, @@ -102,7 +104,7 @@ impl Provider for GithubConnectionProvider { _connection: &Connection, code: &str, redirect_uri: &str, - ) -> Result { + ) -> Result { // `code` is the installation_id returned by Github. let (token, expiry, raw_json) = self.refresh_token(code).await?; @@ -117,7 +119,7 @@ impl Provider for GithubConnectionProvider { }) } - async fn refresh(&self, connection: &Connection) -> Result { + async fn refresh(&self, connection: &Connection) -> Result { // `code` is the installation_id returned by Github. let code = match connection.unseal_authorization_code()? { Some(code) => code, @@ -145,4 +147,6 @@ impl Provider for GithubConnectionProvider { }; Ok(raw_json) } + + // TODO(2024-07-19 flav) Implement custom error handling for GitHub. } diff --git a/core/src/oauth/providers/google_drive.rs b/core/src/oauth/providers/google_drive.rs index 2c4e821b9131..359a38ff90dc 100644 --- a/core/src/oauth/providers/google_drive.rs +++ b/core/src/oauth/providers/google_drive.rs @@ -1,7 +1,7 @@ use crate::{ oauth::{ connection::{ - Connection, ConnectionProvider, FinalizeResult, Provider, RefreshResult, + Connection, ConnectionProvider, FinalizeResult, Provider, ProviderError, RefreshResult, PROVIDER_TIMEOUT_SECONDS, }, providers::utils::execute_request, @@ -40,7 +40,7 @@ impl Provider for GoogleDriveConnectionProvider { _connection: &Connection, code: &str, redirect_uri: &str, - ) -> Result { + ) -> Result { let body = json!({ "grant_type": "authorization_code", "client_id": *OAUTH_GOOGLE_DRIVE_CLIENT_ID, @@ -54,7 +54,9 @@ impl Provider for GoogleDriveConnectionProvider { .header("Content-Type", "application/json") .json(&body); - let raw_json = execute_request(ConnectionProvider::GoogleDrive, req).await?; + let raw_json = execute_request(ConnectionProvider::GoogleDrive, req) + .await + .map_err(|e| self.handle_provider_request_error(e))?; let access_token = match raw_json["access_token"].as_str() { Some(token) => token, @@ -97,7 +99,7 @@ impl Provider for GoogleDriveConnectionProvider { // Google Drive does not automatically expire refresh tokens for published apps, // unless they have been unused for six months. // Acess tokens expire after 1 hour. - async fn refresh(&self, connection: &Connection) -> Result { + async fn refresh(&self, connection: &Connection) -> Result { let refresh_token = match connection.unseal_refresh_token() { Ok(Some(token)) => token, Ok(None) => Err(anyhow!( @@ -118,7 +120,9 @@ impl Provider for GoogleDriveConnectionProvider { .header("Content-Type", "application/json") .json(&body); - let raw_json = execute_request(ConnectionProvider::GoogleDrive, req).await?; + let raw_json = execute_request(ConnectionProvider::GoogleDrive, req) + .await + .map_err(|e| self.handle_provider_request_error(e))?; let access_token = match raw_json["access_token"].as_str() { Some(token) => token, diff --git a/core/src/oauth/providers/intercom.rs b/core/src/oauth/providers/intercom.rs index 5fdf05c2a341..4ee8c748fc0a 100644 --- a/core/src/oauth/providers/intercom.rs +++ b/core/src/oauth/providers/intercom.rs @@ -1,5 +1,7 @@ use crate::oauth::{ - connection::{Connection, ConnectionProvider, FinalizeResult, Provider, RefreshResult}, + connection::{ + Connection, ConnectionProvider, FinalizeResult, Provider, ProviderError, RefreshResult, + }, providers::utils::execute_request, }; use anyhow::{anyhow, Result}; @@ -33,7 +35,7 @@ impl Provider for IntercomConnectionProvider { _connection: &Connection, code: &str, redirect_uri: &str, - ) -> Result { + ) -> Result { let body = json!({ "grant_type": "authorization_code", "client_id": *OAUTH_INTERCOM_CLIENT_ID, @@ -47,7 +49,9 @@ impl Provider for IntercomConnectionProvider { .header("Content-Type", "application/json") .json(&body); - let raw_json = execute_request(ConnectionProvider::Intercom, req).await?; + let raw_json = execute_request(ConnectionProvider::Intercom, req) + .await + .map_err(|e| self.handle_provider_request_error(e))?; let access_token = match raw_json["access_token"].as_str() { Some(token) => token, @@ -64,8 +68,10 @@ impl Provider for IntercomConnectionProvider { }) } - async fn refresh(&self, _connection: &Connection) -> Result { - Err(anyhow!("Intercom access tokens do not expire"))? + async fn refresh(&self, _connection: &Connection) -> Result { + Err(ProviderError::ActionNotSupportedError( + "Intercom access tokens do not expire".to_string(), + ))? } fn scrubbed_raw_json(&self, raw_json: &serde_json::Value) -> Result { diff --git a/core/src/oauth/providers/notion.rs b/core/src/oauth/providers/notion.rs index 786150aaaa2f..42955b26ac69 100644 --- a/core/src/oauth/providers/notion.rs +++ b/core/src/oauth/providers/notion.rs @@ -1,5 +1,7 @@ use crate::oauth::{ - connection::{Connection, ConnectionProvider, FinalizeResult, Provider, RefreshResult}, + connection::{ + Connection, ConnectionProvider, FinalizeResult, Provider, ProviderError, RefreshResult, + }, providers::utils::execute_request, }; use anyhow::{anyhow, Result}; @@ -40,7 +42,7 @@ impl Provider for NotionConnectionProvider { _connection: &Connection, code: &str, redirect_uri: &str, - ) -> Result { + ) -> Result { let body = json!({ "grant_type": "authorization_code", "code": code, @@ -54,7 +56,9 @@ impl Provider for NotionConnectionProvider { .header("Authorization", format!("Basic {}", self.basic_auth())) .json(&body); - let raw_json = execute_request(ConnectionProvider::Notion, req).await?; + let raw_json = execute_request(ConnectionProvider::Notion, req) + .await + .map_err(|e| self.handle_provider_request_error(e))?; let access_token = match raw_json["access_token"].as_str() { Some(token) => token, @@ -71,8 +75,10 @@ impl Provider for NotionConnectionProvider { }) } - async fn refresh(&self, _connection: &Connection) -> Result { - Err(anyhow!("Notion access tokens do not expire"))? + async fn refresh(&self, _connection: &Connection) -> Result { + Err(ProviderError::ActionNotSupportedError( + "Notion access tokens do not expire".to_string(), + ))? } fn scrubbed_raw_json(&self, raw_json: &serde_json::Value) -> Result { diff --git a/core/src/oauth/providers/slack.rs b/core/src/oauth/providers/slack.rs index b2bed5bce5c4..5377ffdfbb54 100644 --- a/core/src/oauth/providers/slack.rs +++ b/core/src/oauth/providers/slack.rs @@ -1,16 +1,13 @@ -use crate::{ - oauth::{ - connection::{ - Connection, - ConnectionProvider, - FinalizeResult, - Provider, - RefreshResult, - // PROVIDER_TIMEOUT_SECONDS, - }, - providers::utils::execute_request, +use crate::oauth::{ + connection::{ + Connection, + ConnectionProvider, + FinalizeResult, + Provider, + ProviderError, + RefreshResult, // PROVIDER_TIMEOUT_SECONDS, }, - // utils, + providers::utils::execute_request, }; use anyhow::{anyhow, Result}; use async_trait::async_trait; @@ -49,7 +46,7 @@ impl Provider for SlackConnectionProvider { _connection: &Connection, code: &str, redirect_uri: &str, - ) -> Result { + ) -> Result { let req = reqwest::Client::new() .post("https://slack.com/api/oauth.v2.access") .header("Content-Type", "application/json; charset=utf-8") @@ -57,13 +54,15 @@ impl Provider for SlackConnectionProvider { // Very important, this will *not* work with JSON body. .form(&[("code", code), ("redirect_uri", redirect_uri)]); - let raw_json = execute_request(ConnectionProvider::Slack, req).await?; + let raw_json = execute_request(ConnectionProvider::Slack, req) + .await + .map_err(|e| self.handle_provider_request_error(e))?; if !raw_json["ok"].as_bool().unwrap_or(false) { - return Err(anyhow!( + return Err(ProviderError::UnknownError(format!( "Slack OAuth error: {}", raw_json["error"].as_str().unwrap_or("Unknown error") - )); + ))); } let access_token = raw_json["access_token"] @@ -85,8 +84,10 @@ impl Provider for SlackConnectionProvider { }) } - async fn refresh(&self, _connection: &Connection) -> Result { - Err(anyhow!("Slack token rotation not implemented."))? + async fn refresh(&self, _connection: &Connection) -> Result { + Err(ProviderError::ActionNotSupportedError( + "Slack access tokens do not expire.".to_string(), + ))? // let refresh_token = connection // .unseal_refresh_token()? // .ok_or_else(|| anyhow!("Missing `refresh_token` in Slack connection"))?; diff --git a/core/src/oauth/providers/utils.rs b/core/src/oauth/providers/utils.rs index f5fbfac76365..55e8210e94ae 100644 --- a/core/src/oauth/providers/utils.rs +++ b/core/src/oauth/providers/utils.rs @@ -2,49 +2,66 @@ use crate::{ oauth::connection::{ConnectionProvider, PROVIDER_TIMEOUT_SECONDS}, utils, }; -use anyhow::{anyhow, Result}; +use anyhow::Result; use hyper::body::Buf; use reqwest::RequestBuilder; use std::io::prelude::*; use std::time::Duration; use tokio::time::timeout; +#[derive(Debug, thiserror::Error)] +pub enum ProviderHttpRequestError { + #[error("Network error: {0}")] + NetworkError(reqwest::Error), + #[error("Timeout error")] + Timeout, + #[error("Request failed for provider {provider}. Status: {status}. {message}")] + RequestFailed { + provider: ConnectionProvider, + status: u16, + message: String, + }, + #[error("Invalid response: {0}")] + InvalidResponse(anyhow::Error), +} + pub async fn execute_request( provider: ConnectionProvider, req: RequestBuilder, -) -> Result { +) -> Result { let now = utils::now_secs(); - let res = match timeout(Duration::new(PROVIDER_TIMEOUT_SECONDS, 0), req.send()).await { - Ok(Ok(res)) => res, - Ok(Err(e)) => Err(e)?, - Err(_) => Err(anyhow!("Timeout sending request: provider={}", provider))?, - }; + let res = timeout(Duration::from_secs(PROVIDER_TIMEOUT_SECONDS), req.send()) + .await + .map_err(|_| ProviderHttpRequestError::Timeout)? + .map_err(|e| ProviderHttpRequestError::NetworkError(e))?; if !res.status().is_success() { - Err(anyhow!( - "Error generating access token: provider={} status={}", - provider, - res.status().as_u16(), - ))?; + let status = res.status(); + let body = res + .text() + .await + .unwrap_or_else(|_| String::from("Unable to read response body")); + + return Err(ProviderHttpRequestError::RequestFailed { + provider: provider, + status: status.as_u16(), + message: body, + }); } - let body = match timeout( - Duration::new(PROVIDER_TIMEOUT_SECONDS - (utils::now_secs() - now), 0), + let body = timeout( + Duration::from_secs(PROVIDER_TIMEOUT_SECONDS - (utils::now_secs() - now)), res.bytes(), ) .await - { - Ok(Ok(body)) => body, - Ok(Err(e)) => Err(e)?, - Err(_) => Err(anyhow!("Timeout reading response from Confluence"))?, - }; + .map_err(|_| ProviderHttpRequestError::Timeout)? + .map_err(|e| ProviderHttpRequestError::NetworkError(e))?; let mut b: Vec = vec![]; - body.reader().read_to_end(&mut b)?; - let c: &[u8] = &b; - - let raw_json: serde_json::Value = serde_json::from_slice(c)?; + body.reader() + .read_to_end(&mut b) + .map_err(|e| ProviderHttpRequestError::InvalidResponse(e.into()))?; - Ok(raw_json) + serde_json::from_slice(&b).map_err(|e| ProviderHttpRequestError::InvalidResponse(e.into())) }