Skip to content

Commit

Permalink
Improve provider errors in oauth service
Browse files Browse the repository at this point in the history
  • Loading branch information
flvndvd committed Jul 19, 2024
1 parent 74710f9 commit 199d7a0
Show file tree
Hide file tree
Showing 8 changed files with 185 additions and 71 deletions.
84 changes: 78 additions & 6 deletions core/src/oauth/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ use std::time::Duration;
use std::{env, fmt};
use tracing::{error, info};

use super::providers::utils::ProviderHttpRequestError;

// We hold the lock for at most 15s. In case of panic preventing the lock from being released, this
// is the maximum time the lock will be held.
static REDIS_LOCK_TTL_SECONDS: u64 = 15;
Expand All @@ -40,12 +42,15 @@ lazy_static! {
};
}

// API Error types.

// Define the ErrorKind enum with serde attributes for serialization
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ConnectionErrorCode {
// Finalize
ConnectionAlreadyFinalizedError,
ProviderInvalidToken,
ProviderFinalizationError,
// Refresh Access Token
ConnectionNotFinalizedError,
Expand Down Expand Up @@ -128,14 +133,33 @@ pub trait Provider {
connection: &Connection,
code: &str,
redirect_uri: &str,
) -> Result<FinalizeResult>;
) -> Result<FinalizeResult, ProviderError>;

async fn refresh(&self, connection: &Connection) -> Result<RefreshResult>;
async fn refresh(&self, connection: &Connection) -> Result<RefreshResult, ProviderError>;

// This method scrubs raw_json to remove information that should not exfill `oauth`, in
// particular the `refresh_token`. By convetion the `access_token` should be scrubbed as well
// to prevent users from relying in the raw_json to access it.
fn scrubbed_raw_json(&self, raw_json: &serde_json::Value) -> Result<serde_json::Value>;

// Default implementation for handling errors.
fn handle_provider_request_error(&self, error: ProviderHttpRequestError) -> ProviderError {
match error {
ProviderHttpRequestError::NetworkError(e) => ProviderError::UnknownError(e.to_string()),
ProviderHttpRequestError::Timeout => ProviderError::TimeoutError,
ProviderHttpRequestError::RequestFailed {
provider,
status,
message: _,
} => ProviderError::UnknownError(format!(
"Request failed for provider {}. Status: {}.",
provider, status
)),
ProviderHttpRequestError::InvalidResponse(e) => {
ProviderError::UnknownError(e.to_string())
}
}
}
}

pub fn provider(t: ConnectionProvider) -> Box<dyn Provider + Sync + Send> {
Expand All @@ -149,6 +173,49 @@ pub fn provider(t: ConnectionProvider) -> Box<dyn Provider + Sync + Send> {
}
}

// Internal Error types.

#[derive(Debug, thiserror::Error)]
pub enum ProviderError {
#[error("Action not supported: {0}.")]
ActionNotSupportedError(String),
#[error("Invalid token.")]
InvalidToken,
#[error("Token expired.")]
TokenExpired,
#[error("Timeout error.")]
TimeoutError,
#[error("Unknown error: {0}.")]
UnknownError(String),
#[error("Internal error: {0}.")]
InternalError(anyhow::Error),
}

impl From<anyhow::Error> for ProviderError {
fn from(error: anyhow::Error) -> Self {
ProviderError::InternalError(error)
}
}

impl From<&ProviderError> for ConnectionError {
fn from(err: &ProviderError) -> Self {
match err {
ProviderError::ActionNotSupportedError(_)
| ProviderError::InvalidToken
| ProviderError::TokenExpired
| ProviderError::TimeoutError
| ProviderError::UnknownError(_) => ConnectionError {
code: ConnectionErrorCode::ProviderFinalizationError,
message: err.to_string(),
},
ProviderError::InternalError(_) => ConnectionError {
code: ConnectionErrorCode::InternalError,
message: "Failed to finalize connection with provider".to_string(),
},
}
}
}

#[derive(Debug, Clone, Copy, Serialize, PartialEq, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ConnectionStatus {
Expand Down Expand Up @@ -558,10 +625,15 @@ impl Connection {
provider = ?self.provider,
"Failed to finalize connection",
);
Err(ConnectionError {
code: ConnectionErrorCode::ProviderFinalizationError,
message: "Failed to finalize connection with provider".to_string(),
})

if let Some(provider_error) = e.downcast_ref::<ProviderError>() {
Err(ConnectionError::from(provider_error))
} else {
Err(ConnectionError {
code: ConnectionErrorCode::InternalError,
message: "Failed to finalize connection with provider".to_string(),
})
}
}
}
}
Expand Down
14 changes: 9 additions & 5 deletions core/src/oauth/providers/confluence.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
oauth::{
connection::{
Connection, ConnectionProvider, FinalizeResult, Provider, RefreshResult,
Connection, ConnectionProvider, FinalizeResult, Provider, ProviderError, RefreshResult,
PROVIDER_TIMEOUT_SECONDS,
},
providers::utils::execute_request,
Expand Down Expand Up @@ -40,7 +40,7 @@ impl Provider for ConfluenceConnectionProvider {
_connection: &Connection,
code: &str,
redirect_uri: &str,
) -> Result<FinalizeResult> {
) -> Result<FinalizeResult, ProviderError> {
let body = json!({
"grant_type": "authorization_code",
"client_id": *OAUTH_CONFLUENCE_CLIENT_ID,
Expand All @@ -54,7 +54,9 @@ impl Provider for ConfluenceConnectionProvider {
.header("Content-Type", "application/json")
.json(&body);

let raw_json = execute_request(ConnectionProvider::Confluence, req).await?;
let raw_json = execute_request(ConnectionProvider::Confluence, req)
.await
.map_err(|e| self.handle_provider_request_error(e))?;

let access_token = match raw_json["access_token"].as_str() {
Some(token) => token,
Expand Down Expand Up @@ -92,7 +94,7 @@ impl Provider for ConfluenceConnectionProvider {
/// Note: Confluence hard expires refresh_tokens after 360 days.
/// Confluence expires access_tokens after 1 hour.
/// Confluence expires refresh_tokens after 30 days of inactivity.
async fn refresh(&self, connection: &Connection) -> Result<RefreshResult> {
async fn refresh(&self, connection: &Connection) -> Result<RefreshResult, ProviderError> {
let refresh_token = match connection.unseal_refresh_token() {
Ok(Some(token)) => token,
Ok(None) => Err(anyhow!("Missing `refresh_token` in Confluence connection"))?,
Expand All @@ -111,7 +113,9 @@ impl Provider for ConfluenceConnectionProvider {
.header("Content-Type", "application/json")
.json(&body);

let raw_json = execute_request(ConnectionProvider::Confluence, req).await?;
let raw_json = execute_request(ConnectionProvider::Confluence, req)
.await
.map_err(|e| self.handle_provider_request_error(e))?;

let access_token = match raw_json["access_token"].as_str() {
Some(token) => token,
Expand Down
10 changes: 7 additions & 3 deletions core/src/oauth/providers/github.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::{
oauth::{
connection::{Connection, ConnectionProvider, FinalizeResult, Provider, RefreshResult},
connection::{
Connection, ConnectionProvider, FinalizeResult, Provider, ProviderError, RefreshResult,
},
providers::utils::execute_request,
},
utils,
Expand Down Expand Up @@ -102,7 +104,7 @@ impl Provider for GithubConnectionProvider {
_connection: &Connection,
code: &str,
redirect_uri: &str,
) -> Result<FinalizeResult> {
) -> Result<FinalizeResult, ProviderError> {
// `code` is the installation_id returned by Github.
let (token, expiry, raw_json) = self.refresh_token(code).await?;

Expand All @@ -117,7 +119,7 @@ impl Provider for GithubConnectionProvider {
})
}

async fn refresh(&self, connection: &Connection) -> Result<RefreshResult> {
async fn refresh(&self, connection: &Connection) -> Result<RefreshResult, ProviderError> {
// `code` is the installation_id returned by Github.
let code = match connection.unseal_authorization_code()? {
Some(code) => code,
Expand Down Expand Up @@ -145,4 +147,6 @@ impl Provider for GithubConnectionProvider {
};
Ok(raw_json)
}

// TODO(2024-07-19 flav) Implement custom error handling for GitHub.
}
14 changes: 9 additions & 5 deletions core/src/oauth/providers/google_drive.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
oauth::{
connection::{
Connection, ConnectionProvider, FinalizeResult, Provider, RefreshResult,
Connection, ConnectionProvider, FinalizeResult, Provider, ProviderError, RefreshResult,
PROVIDER_TIMEOUT_SECONDS,
},
providers::utils::execute_request,
Expand Down Expand Up @@ -40,7 +40,7 @@ impl Provider for GoogleDriveConnectionProvider {
_connection: &Connection,
code: &str,
redirect_uri: &str,
) -> Result<FinalizeResult> {
) -> Result<FinalizeResult, ProviderError> {
let body = json!({
"grant_type": "authorization_code",
"client_id": *OAUTH_GOOGLE_DRIVE_CLIENT_ID,
Expand All @@ -54,7 +54,9 @@ impl Provider for GoogleDriveConnectionProvider {
.header("Content-Type", "application/json")
.json(&body);

let raw_json = execute_request(ConnectionProvider::GoogleDrive, req).await?;
let raw_json = execute_request(ConnectionProvider::GoogleDrive, req)
.await
.map_err(|e| self.handle_provider_request_error(e))?;

let access_token = match raw_json["access_token"].as_str() {
Some(token) => token,
Expand Down Expand Up @@ -97,7 +99,7 @@ impl Provider for GoogleDriveConnectionProvider {
// Google Drive does not automatically expire refresh tokens for published apps,
// unless they have been unused for six months.
// Acess tokens expire after 1 hour.
async fn refresh(&self, connection: &Connection) -> Result<RefreshResult> {
async fn refresh(&self, connection: &Connection) -> Result<RefreshResult, ProviderError> {
let refresh_token = match connection.unseal_refresh_token() {
Ok(Some(token)) => token,
Ok(None) => Err(anyhow!(
Expand All @@ -118,7 +120,9 @@ impl Provider for GoogleDriveConnectionProvider {
.header("Content-Type", "application/json")
.json(&body);

let raw_json = execute_request(ConnectionProvider::GoogleDrive, req).await?;
let raw_json = execute_request(ConnectionProvider::GoogleDrive, req)
.await
.map_err(|e| self.handle_provider_request_error(e))?;

let access_token = match raw_json["access_token"].as_str() {
Some(token) => token,
Expand Down
16 changes: 11 additions & 5 deletions core/src/oauth/providers/intercom.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use crate::oauth::{
connection::{Connection, ConnectionProvider, FinalizeResult, Provider, RefreshResult},
connection::{
Connection, ConnectionProvider, FinalizeResult, Provider, ProviderError, RefreshResult,
},
providers::utils::execute_request,
};
use anyhow::{anyhow, Result};
Expand Down Expand Up @@ -33,7 +35,7 @@ impl Provider for IntercomConnectionProvider {
_connection: &Connection,
code: &str,
redirect_uri: &str,
) -> Result<FinalizeResult> {
) -> Result<FinalizeResult, ProviderError> {
let body = json!({
"grant_type": "authorization_code",
"client_id": *OAUTH_INTERCOM_CLIENT_ID,
Expand All @@ -47,7 +49,9 @@ impl Provider for IntercomConnectionProvider {
.header("Content-Type", "application/json")
.json(&body);

let raw_json = execute_request(ConnectionProvider::Intercom, req).await?;
let raw_json = execute_request(ConnectionProvider::Intercom, req)
.await
.map_err(|e| self.handle_provider_request_error(e))?;

let access_token = match raw_json["access_token"].as_str() {
Some(token) => token,
Expand All @@ -64,8 +68,10 @@ impl Provider for IntercomConnectionProvider {
})
}

async fn refresh(&self, _connection: &Connection) -> Result<RefreshResult> {
Err(anyhow!("Intercom access tokens do not expire"))?
async fn refresh(&self, _connection: &Connection) -> Result<RefreshResult, ProviderError> {
Err(ProviderError::ActionNotSupportedError(
"Intercom access tokens do not expire".to_string(),
))?
}

fn scrubbed_raw_json(&self, raw_json: &serde_json::Value) -> Result<serde_json::Value> {
Expand Down
16 changes: 11 additions & 5 deletions core/src/oauth/providers/notion.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use crate::oauth::{
connection::{Connection, ConnectionProvider, FinalizeResult, Provider, RefreshResult},
connection::{
Connection, ConnectionProvider, FinalizeResult, Provider, ProviderError, RefreshResult,
},
providers::utils::execute_request,
};
use anyhow::{anyhow, Result};
Expand Down Expand Up @@ -40,7 +42,7 @@ impl Provider for NotionConnectionProvider {
_connection: &Connection,
code: &str,
redirect_uri: &str,
) -> Result<FinalizeResult> {
) -> Result<FinalizeResult, ProviderError> {
let body = json!({
"grant_type": "authorization_code",
"code": code,
Expand All @@ -54,7 +56,9 @@ impl Provider for NotionConnectionProvider {
.header("Authorization", format!("Basic {}", self.basic_auth()))
.json(&body);

let raw_json = execute_request(ConnectionProvider::Notion, req).await?;
let raw_json = execute_request(ConnectionProvider::Notion, req)
.await
.map_err(|e| self.handle_provider_request_error(e))?;

let access_token = match raw_json["access_token"].as_str() {
Some(token) => token,
Expand All @@ -71,8 +75,10 @@ impl Provider for NotionConnectionProvider {
})
}

async fn refresh(&self, _connection: &Connection) -> Result<RefreshResult> {
Err(anyhow!("Notion access tokens do not expire"))?
async fn refresh(&self, _connection: &Connection) -> Result<RefreshResult, ProviderError> {
Err(ProviderError::ActionNotSupportedError(
"Notion access tokens do not expire".to_string(),
))?
}

fn scrubbed_raw_json(&self, raw_json: &serde_json::Value) -> Result<serde_json::Value> {
Expand Down
Loading

0 comments on commit 199d7a0

Please sign in to comment.