Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AA: kbs: Add supported hash algorithms to Request #712

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions attestation-agent/attestation-agent/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ required-features = ["bin", "ttrpc"]
anyhow.workspace = true
async-trait.workspace = true
attester = { path = "../attester", default-features = false }
crypto = { path = "../deps/crypto" }
base64.workspace = true
clap = { workspace = true, features = ["derive"], optional = true }
config.workspace = true
Expand Down
31 changes: 1 addition & 30 deletions attestation-agent/attestation-agent/src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
//

use anyhow::Result;
use crypto::HashAlgorithm;
use serde::Deserialize;
use sha2::{Digest, Sha256, Sha384, Sha512};

/// Default PCR index used by AA. `17` is selected for its usage of dynamic root of trust for measurement.
/// - [Linux TPM PCR Registry](https://uapi-group.org/specifications/specs/linux_tpm_pcr_registry/)
Expand All @@ -24,35 +24,6 @@ pub const DEFAULT_AA_CONFIG_PATH: &str = "/etc/attestation-agent.conf";

pub const DEFAULT_EVENTLOG_HASH: &str = "sha384";

/// Hash algorithms used to calculate runtime/init data binding
#[derive(Deserialize, Clone, Debug, Copy)]
#[serde(rename_all = "lowercase")]
pub enum HashAlgorithm {
Sha256,
Sha384,
Sha512,
}

impl Default for HashAlgorithm {
fn default() -> Self {
Self::Sha384
}
}

fn hash_reportdata<D: Digest>(material: &[u8]) -> Vec<u8> {
D::new().chain_update(material).finalize().to_vec()
}

impl HashAlgorithm {
pub fn digest(&self, material: &[u8]) -> Vec<u8> {
match self {
HashAlgorithm::Sha256 => hash_reportdata::<Sha256>(material),
HashAlgorithm::Sha384 => hash_reportdata::<Sha384>(material),
HashAlgorithm::Sha512 => hash_reportdata::<Sha512>(material),
}
}
}

#[derive(Clone, Debug, Deserialize)]
pub struct Config {
/// configs about token
Expand Down
3 changes: 1 addition & 2 deletions attestation-agent/attestation-agent/src/eventlog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::{fmt::Display, fs::File, io::Write};
use anyhow::{bail, Context, Result};
use const_format::concatcp;

use crate::config::HashAlgorithm;
use crypto::HashAlgorithm;

/// AA's eventlog will be put into this parent directory
pub const EVENTLOG_PARENT_DIR_PATH: &str = "/run/attestation-agent";
Expand Down Expand Up @@ -109,7 +109,6 @@ impl<'a> Display for LogEntry<'a> {
#[cfg(test)]
mod tests {
use super::*;
use crate::config::HashAlgorithm;
use rstest::rstest;
use std::sync::{Arc, Mutex};

Expand Down
75 changes: 75 additions & 0 deletions attestation-agent/deps/crypto/src/algorithms.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright (c) 2024 Alibaba Cloud
// Copyright (c) 2024 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0
//

use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256, Sha384, Sha512};
use std::fmt;
use std::str::FromStr;

/// Hash algorithms used to calculate runtime/init data binding
#[derive(Serialize, Deserialize, Clone, Debug, Display, Copy)]
#[serde(rename_all = "lowercase")]
pub enum HashAlgorithm {
Sha256,
Sha384,
Sha512,
}

impl Default for HashAlgorithm {
fn default() -> Self {
Self::Sha384
}
}

fn hash_reportdata<D: Digest>(material: &[u8]) -> Vec<u8> {
D::new().chain_update(material).finalize().to_vec()
}

impl HashAlgorithm {
pub fn digest(&self, material: &[u8]) -> Vec<u8> {
match self {
HashAlgorithm::Sha256 => hash_reportdata::<Sha256>(material),
HashAlgorithm::Sha384 => hash_reportdata::<Sha384>(material),
HashAlgorithm::Sha512 => hash_reportdata::<Sha512>(material),
}
}

/// Return a list of all supported hash algorithms.
pub fn list_all() -> Vec<Self> {
vec![
HashAlgorithm::Sha256,
HashAlgorithm::Sha384,
HashAlgorithm::Sha512,
]
}
}

#[derive(Debug, PartialEq, Eq)]
pub struct ParseHashAlgorithmError;

// XXX: Required to allow conversion to a std::error::Error by `anyhow!()`.
impl fmt::Display for ParseHashAlgorithmError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "ParseHashAlgorithmError")
}
}

impl FromStr for HashAlgorithm {
type Err = ParseHashAlgorithmError;

fn from_str(s: &str) -> Result<Self, Self::Err> {
let cleaned = s.replace('-', "").to_lowercase();

let result = match cleaned.as_str() {
"sha256" => HashAlgorithm::Sha256,
"sha384" => HashAlgorithm::Sha384,
"sha512" => HashAlgorithm::Sha512,
_ => return Err(ParseHashAlgorithmError),
};

Ok(result)
}
}
3 changes: 3 additions & 0 deletions attestation-agent/deps/crypto/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,6 @@ pub use symmetric::*;

mod asymmetric;
pub use asymmetric::*;

mod algorithms;
pub use algorithms::*;
154 changes: 137 additions & 17 deletions attestation-agent/kbs_protocol/src/client/rcar_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ use std::time::Duration;

use anyhow::{bail, Context};
use async_trait::async_trait;
use crypto::HashAlgorithm;
use kbs_types::{Attestation, Challenge, ErrorInformation, Request, Response, Tee};
use log::{debug, warn};
use resource_uri::ResourceUri;
use serde::Deserialize;
use serde_json::json;
use sha2::{Digest, Sha384};

use crate::{
api::KbsClientCapabilities,
Expand All @@ -32,12 +32,43 @@ const RCAR_MAX_ATTEMPT: i32 = 5;
/// The interval (seconds) between RCAR handshake retries.
const RCAR_RETRY_TIMEOUT_SECOND: u64 = 1;

/// JSON object added to a 'Request's extra parameters.
const SUPPORTED_HASH_ALGORITHMS_JSON_KEY: &str = "supported-hash-algorithms";

/// JSON object returned in the Challenge whose value is based on
/// SUPPORTED_HASH_ALGORITHMS_JSON_KEY and the TEE.
const SELECTED_HASH_ALGORITHM_JSON_KEY: &str = "selected-hash-algorithm";

/// Hash algorithm to use by default.
const DEFAULT_HASH_ALGORITHM: HashAlgorithm = HashAlgorithm::Sha384;

#[derive(Deserialize, Debug, Clone)]
struct AttestationResponseData {
// Attestation token in JWT format
token: String,
}

async fn get_request_extra_params() -> serde_json::Value {
let supported_hash_algorithms = HashAlgorithm::list_all();

let extra_params = json!({SUPPORTED_HASH_ALGORITHMS_JSON_KEY: supported_hash_algorithms});

extra_params
}

async fn build_request(tee: Tee) -> Request {
let extra_params = get_request_extra_params().await;

// Note that the Request includes the list of supported hash algorithms.
// The Challenge response will return which TEE-specific algorithm should
// be used for future communications.
Request {
version: String::from(KBS_PROTOCOL_VERSION),
tee,
extra_params,
}
}

impl KbsClient<Box<dyn EvidenceProvider>> {
/// Get a [`TeeKeyPair`] and a [`Token`] that certifies the [`TeeKeyPair`].
/// If the client does not already have token or the token is invalid,
Expand Down Expand Up @@ -101,13 +132,9 @@ impl KbsClient<Box<dyn EvidenceProvider>> {
ClientTee::_Initializated(tee) => *tee,
};

let request = Request {
version: String::from(KBS_PROTOCOL_VERSION),
tee,
extra_params: serde_json::Value::String(String::new()),
};
let request = build_request(tee).await;

debug!("send auth request to {auth_endpoint}");
debug!("send auth request {request:?} to {auth_endpoint}");

let resp = self
.http_client
Expand Down Expand Up @@ -138,6 +165,23 @@ impl KbsClient<Box<dyn EvidenceProvider>> {

let challenge = resp.json::<Challenge>().await?;
debug!("get challenge: {challenge:#?}");

let extra_params = challenge.extra_params;

let algorithm = match extra_params.get(SELECTED_HASH_ALGORITHM_JSON_KEY) {
Some(selected_hash_algorithm) => {
// Note the blank string which will be handled as an error when parsed.
let name = selected_hash_algorithm
.as_str()
.unwrap_or("")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: i'd use .ok_or(Error::NotAString(selected_hash_algorithm))? or something, otherwise we might swallow a malformed extra_params body

.to_lowercase();

name.parse::<HashAlgorithm>()
.map_err(|_| Error::InvalidHashAlgorithm(name))?
}
None => DEFAULT_HASH_ALGORITHM,
};

let tee_pubkey = self.tee_key.export_pubkey()?;
let runtime_data = json!({
"tee-pubkey": tee_pubkey,
Expand All @@ -146,7 +190,7 @@ impl KbsClient<Box<dyn EvidenceProvider>> {
let runtime_data =
serde_json::to_string(&runtime_data).context("serialize runtime data failed")?;
let evidence = self
.generate_evidence(tee, runtime_data, challenge.nonce)
.generate_evidence(tee, runtime_data, challenge.nonce, algorithm)
.await?;
debug!("get evidence with challenge: {evidence}");

Expand Down Expand Up @@ -186,25 +230,42 @@ impl KbsClient<Box<dyn EvidenceProvider>> {
Ok(())
}

async fn generate_evidence(
/// Convert the runtime data and the nonce into a hashed representation using the
/// specified hash algorithm.
async fn hash_runtime_data(
&self,
tee: Tee,
runtime_data: String,
nonce: String,
) -> Result<String> {
debug!("Challenge nonce: {nonce}");
let mut hasher = Sha384::new();
hasher.update(runtime_data);
tee: Tee,
algorithm: HashAlgorithm,
) -> Result<Vec<u8>> {
debug!("Hashing {tee:?} runtime data using nonce {nonce} and algorithm {algorithm:?}");

let ehd = match tee {
let hashed_data = match tee {
// IBM SE uses nonce as runtime_data to pass attestation_request
Tee::Se => nonce.into_bytes(),
_ => hasher.finalize().to_vec(),
_ => algorithm.digest(runtime_data.as_bytes()),
};

Ok(hashed_data)
}

async fn generate_evidence(
&self,
tee: Tee,
runtime_data: String,
nonce: String,
algorithm: HashAlgorithm,
) -> Result<String> {
debug!("Challenge nonce: {nonce}, algorithm: {algorithm:?}");

let hashed_data = self
.hash_runtime_data(runtime_data, nonce, tee, algorithm)
.await?;

let tee_evidence = self
.provider
.get_evidence(ehd)
.get_evidence(hashed_data)
.await
.context("Get TEE evidence failed")
.map_err(|e| Error::GetEvidence(e.to_string()))?;
Expand Down Expand Up @@ -288,6 +349,7 @@ impl KbsClientCapabilities for KbsClient<Box<dyn EvidenceProvider>> {

#[cfg(test)]
mod test {
use crypto::HashAlgorithm;
use std::{env, path::PathBuf, time::Duration};
use testcontainers::{clients, images::generic::GenericImage};
use tokio::fs;
Expand All @@ -296,6 +358,12 @@ mod test {
evidence_provider::NativeEvidenceProvider, KbsClientBuilder, KbsClientCapabilities,
};

use crate::client::rcar_client::{
build_request, get_request_extra_params, KBS_PROTOCOL_VERSION,
SUPPORTED_HASH_ALGORITHMS_JSON_KEY,
};
use kbs_types::Tee;

const CONTENT: &[u8] = b"test content";

#[tokio::test]
Expand Down Expand Up @@ -382,4 +450,56 @@ mod test {
println!("Get token : {token:?}");
println!("Get key: {key:?}");
}

#[tokio::test]
#[serial_test::serial]
async fn test_get_request_extra_params() {
let extra_params = get_request_extra_params().await;

assert!(extra_params.is_object());

let algos_json = extra_params
.get(SUPPORTED_HASH_ALGORITHMS_JSON_KEY)
.unwrap();
assert!(algos_json.is_array());

let algos = algos_json.as_array().unwrap();

let expected_algos = HashAlgorithm::list_all();
let expected_length: usize = expected_algos.len();

assert!(expected_length > 0);

for algo in algos {
let result = algos.contains(algo);
assert!(result);
}
}

#[tokio::test]
#[serial_test::serial]
async fn test_build_request() {
let tees = vec![
Tee::AzSnpVtpm,
Tee::AzTdxVtpm,
Tee::Cca,
Tee::Csv,
Tee::Se,
Tee::Sev,
Tee::Sgx,
Tee::Snp,
Tee::Tdx,
];

let expected_version = String::from(KBS_PROTOCOL_VERSION);
let expected_extra_params = get_request_extra_params().await;

for tee in tees {
let request = build_request(tee).await;

assert_eq!(request.version, expected_version);
assert_eq!(request.tee, tee);
assert_eq!(request.extra_params, expected_extra_params);
}
}
}
Loading
Loading