Skip to content

Commit

Permalink
policy: make policy error more general
Browse files Browse the repository at this point in the history
Since we reference this error enum from mod.rs, it should not
be rego-specific. The error variants are not specific to OPA,
so lift them into mod.rs.

Now, someone writing an alternative policy engine
can use the same errors.

Signed-off-by: Tobin Feldman-Fitzthum <[email protected]>
  • Loading branch information
fitzthum committed Sep 25, 2024
1 parent cdf6cb3 commit 40f9b6f
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 64 deletions.
45 changes: 40 additions & 5 deletions attestation-service/src/policy_engine/mod.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,48 @@
use crate::policy_engine::opa::RegoError;
use anyhow::Result;
use async_trait::async_trait;
use serde::Deserialize;
use std::collections::HashMap;
use std::io;
use std::path::Path;
use strum::EnumString;
use thiserror::Error;

pub mod opa;

#[derive(Error, Debug)]
pub enum PolicyError {
#[error("Failed to create policy directory: {0}")]
CreatePolicyDirFailed(#[source] io::Error),
#[error("Failed to convert policy directory path to string")]
PolicyDirPathToStringFailed,
#[error("Failed to write default policy: {0}")]
WriteDefaultPolicyFailed(#[source] io::Error),
#[error("Failed to read attestation service policy file: {0}")]
ReadPolicyFileFailed(#[source] io::Error),
#[error("Failed to write attestation service policy to file: {0}")]
WritePolicyFileFailed(#[source] io::Error),
#[error("Failed to load policy: {0}")]
LoadPolicyFailed(#[source] anyhow::Error),
#[error("Policy evaluation denied for {policy_id}")]
PolicyDenied { policy_id: String },
#[error("Serde json error: {0}")]
SerdeJsonError(#[from] serde_json::Error),
#[error("IO error: {0}")]
IOError(#[from] std::io::Error),
#[error("Base64 decode attestation service policy string failed: {0}")]
Base64DecodeFailed(#[source] base64::DecodeError),
#[error("Illegal policy id. Only support alphabet, numeric, `-` or `_`")]
InvalidPolicyId,
#[error("Failed to load reference data: {0}")]
LoadReferenceDataFailed(#[source] anyhow::Error),
#[error("Failed to set input data: {0}")]
SetInputDataFailed(#[source] anyhow::Error),
#[error("Failed to evaluate policy: {0}")]
EvalPolicyFailed(#[source] anyhow::Error),
#[error("json serialization failed: {0}")]
JsonSerializationFailed(#[source] anyhow::Error),
}

#[derive(Debug, EnumString, Deserialize)]
#[strum(ascii_case_insensitive)]
pub enum PolicyEngineType {
Expand Down Expand Up @@ -38,13 +73,13 @@ pub trait PolicyEngine {
reference_data_map: HashMap<String, Vec<String>>,
input: String,
policy_ids: Vec<String>,
) -> Result<HashMap<String, PolicyDigest>, RegoError>;
) -> Result<HashMap<String, PolicyDigest>, PolicyError>;

async fn set_policy(&mut self, policy_id: String, policy: String) -> Result<(), RegoError>;
async fn set_policy(&mut self, policy_id: String, policy: String) -> Result<(), PolicyError>;

/// The result is a map. The key is the policy id, and the
/// value is the digest of the policy (using **Sha384**).
async fn list_policies(&self) -> Result<HashMap<String, PolicyDigest>, RegoError>;
async fn list_policies(&self) -> Result<HashMap<String, PolicyDigest>, PolicyError>;

async fn get_policy(&self, policy_id: String) -> Result<String, RegoError>;
async fn get_policy(&self, policy_id: String) -> Result<String, PolicyError>;
}
82 changes: 23 additions & 59 deletions attestation-service/src/policy_engine/opa/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,69 +8,33 @@ use base64::Engine;
use sha2::{Digest, Sha384};
use std::collections::HashMap;
use std::fs;
use std::io;
use std::path::PathBuf;
use thiserror::Error;

use super::{PolicyDigest, PolicyEngine};
use super::{PolicyDigest, PolicyEngine, PolicyError};

#[derive(Debug, Clone)]
pub struct OPA {
policy_dir_path: PathBuf,
}

#[derive(Error, Debug)]
pub enum RegoError {
#[error("Failed to create policy directory: {0}")]
CreatePolicyDirFailed(#[source] io::Error),
#[error("Failed to convert policy directory path to string")]
PolicyDirPathToStringFailed,
#[error("Failed to write default policy: {0}")]
WriteDefaultPolicyFailed(#[source] io::Error),
#[error("Failed to read OPA policy file: {0}")]
ReadPolicyFileFailed(#[source] io::Error),
#[error("Failed to write OPA policy to file: {0}")]
WritePolicyFileFailed(#[source] io::Error),
#[error("Failed to load policy: {0}")]
LoadPolicyFailed(#[source] anyhow::Error),
#[error("Policy evaluation denied for {policy_id}")]
PolicyDenied { policy_id: String },
#[error("Serde json error: {0}")]
SerdeJsonError(#[from] serde_json::Error),
#[error("IO error: {0}")]
IOError(#[from] std::io::Error),
#[error("Base64 decode OPA policy string failed: {0}")]
Base64DecodeFailed(#[source] base64::DecodeError),
#[error("Illegal policy id. Only support alphabet, numeric, `-` or `_`")]
InvalidPolicyId,
#[error("Failed to load reference data: {0}")]
LoadReferenceDataFailed(#[source] anyhow::Error),
#[error("Failed to set input data: {0}")]
SetInputDataFailed(#[source] anyhow::Error),
#[error("Failed to evaluate policy: {0}")]
EvalPolicyFailed(#[source] anyhow::Error),
#[error("json serialization failed: {0}")]
JsonSerializationFailed(#[source] anyhow::Error),
}

impl OPA {
pub fn new(work_dir: PathBuf) -> Result<Self, RegoError> {
pub fn new(work_dir: PathBuf) -> Result<Self, PolicyError> {
let mut policy_dir_path = work_dir;

policy_dir_path.push("opa");
if !policy_dir_path.as_path().exists() {
fs::create_dir_all(&policy_dir_path).map_err(RegoError::CreatePolicyDirFailed)?;
fs::create_dir_all(&policy_dir_path).map_err(PolicyError::CreatePolicyDirFailed)?;
}

let mut default_policy_path = PathBuf::from(
&policy_dir_path
.to_str()
.ok_or_else(|| RegoError::PolicyDirPathToStringFailed)?,
.ok_or_else(|| PolicyError::PolicyDirPathToStringFailed)?,
);
default_policy_path.push("default.rego");
if !default_policy_path.as_path().exists() {
let policy = std::include_str!("default_policy.rego").to_string();
fs::write(&default_policy_path, policy).map_err(RegoError::WriteDefaultPolicyFailed)?;
fs::write(&default_policy_path, policy).map_err(PolicyError::WriteDefaultPolicyFailed)?;
}

Ok(Self { policy_dir_path })
Expand All @@ -90,21 +54,21 @@ impl PolicyEngine for OPA {
reference_data_map: HashMap<String, Vec<String>>,
input: String,
policy_ids: Vec<String>,
) -> Result<HashMap<String, PolicyDigest>, RegoError> {
) -> Result<HashMap<String, PolicyDigest>, PolicyError> {
let mut res = HashMap::new();

let policy_dir_path = self
.policy_dir_path
.to_str()
.ok_or_else(|| RegoError::PolicyDirPathToStringFailed)?;
.ok_or_else(|| PolicyError::PolicyDirPathToStringFailed)?;

for policy_id in &policy_ids {
let input = input.clone();
let policy_file_path = format!("{policy_dir_path}/{policy_id}.rego");

let policy = tokio::fs::read_to_string(policy_file_path.clone())
.await
.map_err(RegoError::ReadPolicyFileFailed)?;
.map_err(PolicyError::ReadPolicyFileFailed)?;

let mut engine = regorus::Engine::new();

Expand All @@ -119,27 +83,27 @@ impl PolicyEngine for OPA {
// Add policy as data
engine
.add_policy(policy_id.clone(), policy)
.map_err(RegoError::LoadPolicyFailed)?;
.map_err(PolicyError::LoadPolicyFailed)?;

let reference_data_map = serde_json::to_string(&reference_data_map)?;
let reference_data_map =
regorus::Value::from_json_str(&format!("{{\"reference\":{reference_data_map}}}"))
.map_err(RegoError::JsonSerializationFailed)?;
.map_err(PolicyError::JsonSerializationFailed)?;
engine
.add_data(reference_data_map)
.map_err(RegoError::LoadReferenceDataFailed)?;
.map_err(PolicyError::LoadReferenceDataFailed)?;

// Add TCB claims as input
engine
.set_input_json(&input)
.context("set input")
.map_err(RegoError::SetInputDataFailed)?;
.map_err(PolicyError::SetInputDataFailed)?;

let allow = engine
.eval_bool_query("data.policy.allow".to_string(), false)
.map_err(RegoError::EvalPolicyFailed)?;
.map_err(PolicyError::EvalPolicyFailed)?;
if !allow {
return Err(RegoError::PolicyDenied {
return Err(PolicyError::PolicyDenied {
policy_id: policy_id.clone(),
});
}
Expand All @@ -150,30 +114,30 @@ impl PolicyEngine for OPA {
Ok(res)
}

async fn set_policy(&mut self, policy_id: String, policy: String) -> Result<(), RegoError> {
async fn set_policy(&mut self, policy_id: String, policy: String) -> Result<(), PolicyError> {
let policy_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
.decode(policy)
.map_err(RegoError::Base64DecodeFailed)?;
.map_err(PolicyError::Base64DecodeFailed)?;

if !Self::is_valid_policy_id(&policy_id) {
return Err(RegoError::InvalidPolicyId);
return Err(PolicyError::InvalidPolicyId);
}

let mut policy_file_path = PathBuf::from(
&self
.policy_dir_path
.to_str()
.ok_or_else(|| RegoError::PolicyDirPathToStringFailed)?,
.ok_or_else(|| PolicyError::PolicyDirPathToStringFailed)?,
);

policy_file_path.push(format!("{}.rego", policy_id));

tokio::fs::write(&policy_file_path, policy_bytes)
.await
.map_err(RegoError::WritePolicyFileFailed)
.map_err(PolicyError::WritePolicyFileFailed)
}

async fn list_policies(&self) -> Result<HashMap<String, PolicyDigest>, RegoError> {
async fn list_policies(&self) -> Result<HashMap<String, PolicyDigest>, PolicyError> {
let mut policy_ids = Vec::new();
let mut entries = tokio::fs::read_dir(&self.policy_dir_path).await?;
while let Some(entry) = entries.next_entry().await? {
Expand All @@ -193,7 +157,7 @@ impl PolicyEngine for OPA {
let policy_file_path = self.policy_dir_path.join(format!("{id}.rego"));
let policy = tokio::fs::read(policy_file_path)
.await
.map_err(RegoError::ReadPolicyFileFailed)?;
.map_err(PolicyError::ReadPolicyFileFailed)?;

let mut hasher = Sha384::new();
hasher.update(policy);
Expand All @@ -207,11 +171,11 @@ impl PolicyEngine for OPA {
Ok(policy_list)
}

async fn get_policy(&self, policy_id: String) -> Result<String, RegoError> {
async fn get_policy(&self, policy_id: String) -> Result<String, PolicyError> {
let policy_file_path = self.policy_dir_path.join(format!("{policy_id}.rego"));
let policy = tokio::fs::read(policy_file_path)
.await
.map_err(RegoError::ReadPolicyFileFailed)?;
.map_err(PolicyError::ReadPolicyFileFailed)?;
let base64_policy = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(policy);
Ok(base64_policy)
}
Expand Down

0 comments on commit 40f9b6f

Please sign in to comment.