From 40f9b6f2c29e2c862ce7581b2ec5412701ad2cb4 Mon Sep 17 00:00:00 2001 From: Tobin Feldman-Fitzthum Date: Wed, 25 Sep 2024 09:54:43 -0500 Subject: [PATCH] policy: make policy error more general Since we reference this error enum from mod.rs, it should not be rego-specific. The error variants are not specific to OPA, so lift them into mod.rs. Now, someone writing an alternative policy engine can use the same errors. Signed-off-by: Tobin Feldman-Fitzthum --- attestation-service/src/policy_engine/mod.rs | 45 ++++++++-- .../src/policy_engine/opa/mod.rs | 82 ++++++------------- 2 files changed, 63 insertions(+), 64 deletions(-) diff --git a/attestation-service/src/policy_engine/mod.rs b/attestation-service/src/policy_engine/mod.rs index ca72d98e5..0d1f5c294 100644 --- a/attestation-service/src/policy_engine/mod.rs +++ b/attestation-service/src/policy_engine/mod.rs @@ -1,13 +1,48 @@ -use crate::policy_engine::opa::RegoError; use anyhow::Result; use async_trait::async_trait; use serde::Deserialize; use std::collections::HashMap; +use std::io; use std::path::Path; use strum::EnumString; +use thiserror::Error; pub mod opa; +#[derive(Error, Debug)] +pub enum PolicyError { + #[error("Failed to create policy directory: {0}")] + CreatePolicyDirFailed(#[source] io::Error), + #[error("Failed to convert policy directory path to string")] + PolicyDirPathToStringFailed, + #[error("Failed to write default policy: {0}")] + WriteDefaultPolicyFailed(#[source] io::Error), + #[error("Failed to read attestation service policy file: {0}")] + ReadPolicyFileFailed(#[source] io::Error), + #[error("Failed to write attestation service policy to file: {0}")] + WritePolicyFileFailed(#[source] io::Error), + #[error("Failed to load policy: {0}")] + LoadPolicyFailed(#[source] anyhow::Error), + #[error("Policy evaluation denied for {policy_id}")] + PolicyDenied { policy_id: String }, + #[error("Serde json error: {0}")] + SerdeJsonError(#[from] serde_json::Error), + #[error("IO error: {0}")] + IOError(#[from] std::io::Error), + #[error("Base64 decode attestation service policy string failed: {0}")] + Base64DecodeFailed(#[source] base64::DecodeError), + #[error("Illegal policy id. Only support alphabet, numeric, `-` or `_`")] + InvalidPolicyId, + #[error("Failed to load reference data: {0}")] + LoadReferenceDataFailed(#[source] anyhow::Error), + #[error("Failed to set input data: {0}")] + SetInputDataFailed(#[source] anyhow::Error), + #[error("Failed to evaluate policy: {0}")] + EvalPolicyFailed(#[source] anyhow::Error), + #[error("json serialization failed: {0}")] + JsonSerializationFailed(#[source] anyhow::Error), +} + #[derive(Debug, EnumString, Deserialize)] #[strum(ascii_case_insensitive)] pub enum PolicyEngineType { @@ -38,13 +73,13 @@ pub trait PolicyEngine { reference_data_map: HashMap>, input: String, policy_ids: Vec, - ) -> Result, RegoError>; + ) -> Result, PolicyError>; - async fn set_policy(&mut self, policy_id: String, policy: String) -> Result<(), RegoError>; + async fn set_policy(&mut self, policy_id: String, policy: String) -> Result<(), PolicyError>; /// The result is a map. The key is the policy id, and the /// value is the digest of the policy (using **Sha384**). - async fn list_policies(&self) -> Result, RegoError>; + async fn list_policies(&self) -> Result, PolicyError>; - async fn get_policy(&self, policy_id: String) -> Result; + async fn get_policy(&self, policy_id: String) -> Result; } diff --git a/attestation-service/src/policy_engine/opa/mod.rs b/attestation-service/src/policy_engine/opa/mod.rs index 40f66d9d8..4a83c0c93 100644 --- a/attestation-service/src/policy_engine/opa/mod.rs +++ b/attestation-service/src/policy_engine/opa/mod.rs @@ -8,69 +8,33 @@ use base64::Engine; use sha2::{Digest, Sha384}; use std::collections::HashMap; use std::fs; -use std::io; use std::path::PathBuf; -use thiserror::Error; -use super::{PolicyDigest, PolicyEngine}; +use super::{PolicyDigest, PolicyEngine, PolicyError}; #[derive(Debug, Clone)] pub struct OPA { policy_dir_path: PathBuf, } -#[derive(Error, Debug)] -pub enum RegoError { - #[error("Failed to create policy directory: {0}")] - CreatePolicyDirFailed(#[source] io::Error), - #[error("Failed to convert policy directory path to string")] - PolicyDirPathToStringFailed, - #[error("Failed to write default policy: {0}")] - WriteDefaultPolicyFailed(#[source] io::Error), - #[error("Failed to read OPA policy file: {0}")] - ReadPolicyFileFailed(#[source] io::Error), - #[error("Failed to write OPA policy to file: {0}")] - WritePolicyFileFailed(#[source] io::Error), - #[error("Failed to load policy: {0}")] - LoadPolicyFailed(#[source] anyhow::Error), - #[error("Policy evaluation denied for {policy_id}")] - PolicyDenied { policy_id: String }, - #[error("Serde json error: {0}")] - SerdeJsonError(#[from] serde_json::Error), - #[error("IO error: {0}")] - IOError(#[from] std::io::Error), - #[error("Base64 decode OPA policy string failed: {0}")] - Base64DecodeFailed(#[source] base64::DecodeError), - #[error("Illegal policy id. Only support alphabet, numeric, `-` or `_`")] - InvalidPolicyId, - #[error("Failed to load reference data: {0}")] - LoadReferenceDataFailed(#[source] anyhow::Error), - #[error("Failed to set input data: {0}")] - SetInputDataFailed(#[source] anyhow::Error), - #[error("Failed to evaluate policy: {0}")] - EvalPolicyFailed(#[source] anyhow::Error), - #[error("json serialization failed: {0}")] - JsonSerializationFailed(#[source] anyhow::Error), -} - impl OPA { - pub fn new(work_dir: PathBuf) -> Result { + pub fn new(work_dir: PathBuf) -> Result { let mut policy_dir_path = work_dir; policy_dir_path.push("opa"); if !policy_dir_path.as_path().exists() { - fs::create_dir_all(&policy_dir_path).map_err(RegoError::CreatePolicyDirFailed)?; + fs::create_dir_all(&policy_dir_path).map_err(PolicyError::CreatePolicyDirFailed)?; } let mut default_policy_path = PathBuf::from( &policy_dir_path .to_str() - .ok_or_else(|| RegoError::PolicyDirPathToStringFailed)?, + .ok_or_else(|| PolicyError::PolicyDirPathToStringFailed)?, ); default_policy_path.push("default.rego"); if !default_policy_path.as_path().exists() { let policy = std::include_str!("default_policy.rego").to_string(); - fs::write(&default_policy_path, policy).map_err(RegoError::WriteDefaultPolicyFailed)?; + fs::write(&default_policy_path, policy).map_err(PolicyError::WriteDefaultPolicyFailed)?; } Ok(Self { policy_dir_path }) @@ -90,13 +54,13 @@ impl PolicyEngine for OPA { reference_data_map: HashMap>, input: String, policy_ids: Vec, - ) -> Result, RegoError> { + ) -> Result, PolicyError> { let mut res = HashMap::new(); let policy_dir_path = self .policy_dir_path .to_str() - .ok_or_else(|| RegoError::PolicyDirPathToStringFailed)?; + .ok_or_else(|| PolicyError::PolicyDirPathToStringFailed)?; for policy_id in &policy_ids { let input = input.clone(); @@ -104,7 +68,7 @@ impl PolicyEngine for OPA { let policy = tokio::fs::read_to_string(policy_file_path.clone()) .await - .map_err(RegoError::ReadPolicyFileFailed)?; + .map_err(PolicyError::ReadPolicyFileFailed)?; let mut engine = regorus::Engine::new(); @@ -119,27 +83,27 @@ impl PolicyEngine for OPA { // Add policy as data engine .add_policy(policy_id.clone(), policy) - .map_err(RegoError::LoadPolicyFailed)?; + .map_err(PolicyError::LoadPolicyFailed)?; let reference_data_map = serde_json::to_string(&reference_data_map)?; let reference_data_map = regorus::Value::from_json_str(&format!("{{\"reference\":{reference_data_map}}}")) - .map_err(RegoError::JsonSerializationFailed)?; + .map_err(PolicyError::JsonSerializationFailed)?; engine .add_data(reference_data_map) - .map_err(RegoError::LoadReferenceDataFailed)?; + .map_err(PolicyError::LoadReferenceDataFailed)?; // Add TCB claims as input engine .set_input_json(&input) .context("set input") - .map_err(RegoError::SetInputDataFailed)?; + .map_err(PolicyError::SetInputDataFailed)?; let allow = engine .eval_bool_query("data.policy.allow".to_string(), false) - .map_err(RegoError::EvalPolicyFailed)?; + .map_err(PolicyError::EvalPolicyFailed)?; if !allow { - return Err(RegoError::PolicyDenied { + return Err(PolicyError::PolicyDenied { policy_id: policy_id.clone(), }); } @@ -150,30 +114,30 @@ impl PolicyEngine for OPA { Ok(res) } - async fn set_policy(&mut self, policy_id: String, policy: String) -> Result<(), RegoError> { + async fn set_policy(&mut self, policy_id: String, policy: String) -> Result<(), PolicyError> { let policy_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD .decode(policy) - .map_err(RegoError::Base64DecodeFailed)?; + .map_err(PolicyError::Base64DecodeFailed)?; if !Self::is_valid_policy_id(&policy_id) { - return Err(RegoError::InvalidPolicyId); + return Err(PolicyError::InvalidPolicyId); } let mut policy_file_path = PathBuf::from( &self .policy_dir_path .to_str() - .ok_or_else(|| RegoError::PolicyDirPathToStringFailed)?, + .ok_or_else(|| PolicyError::PolicyDirPathToStringFailed)?, ); policy_file_path.push(format!("{}.rego", policy_id)); tokio::fs::write(&policy_file_path, policy_bytes) .await - .map_err(RegoError::WritePolicyFileFailed) + .map_err(PolicyError::WritePolicyFileFailed) } - async fn list_policies(&self) -> Result, RegoError> { + async fn list_policies(&self) -> Result, PolicyError> { let mut policy_ids = Vec::new(); let mut entries = tokio::fs::read_dir(&self.policy_dir_path).await?; while let Some(entry) = entries.next_entry().await? { @@ -193,7 +157,7 @@ impl PolicyEngine for OPA { let policy_file_path = self.policy_dir_path.join(format!("{id}.rego")); let policy = tokio::fs::read(policy_file_path) .await - .map_err(RegoError::ReadPolicyFileFailed)?; + .map_err(PolicyError::ReadPolicyFileFailed)?; let mut hasher = Sha384::new(); hasher.update(policy); @@ -207,11 +171,11 @@ impl PolicyEngine for OPA { Ok(policy_list) } - async fn get_policy(&self, policy_id: String) -> Result { + async fn get_policy(&self, policy_id: String) -> Result { let policy_file_path = self.policy_dir_path.join(format!("{policy_id}.rego")); let policy = tokio::fs::read(policy_file_path) .await - .map_err(RegoError::ReadPolicyFileFailed)?; + .map_err(PolicyError::ReadPolicyFileFailed)?; let base64_policy = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(policy); Ok(base64_policy) }