Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 213 additions & 5 deletions attested-tls/src/attestation/pccs.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,34 @@
use std::{
collections::HashMap,
sync::{Arc, Weak},
sync::{
Arc, Weak,
atomic::{AtomicBool, AtomicUsize, Ordering},
},
time::{SystemTime, UNIX_EPOCH},
};

use anyhow::Context;
use dcap_qvl::{QuoteCollateralV3, collateral::get_collateral_for_fmspc, tcb_info::TcbInfo};
use time::{OffsetDateTime, format_description::well_known::Rfc3339};
use tokio::{
sync::RwLock,
task::JoinHandle,
runtime::Handle,
sync::{RwLock, Semaphore},
task::{JoinHandle, JoinSet},
time::{Duration, sleep},
};

use crate::attestation::dcap::{DcapVerificationError, PCS_URL};

const REFRESH_MARGIN_SECS: i64 = 300;
const REFRESH_RETRY_SECS: u64 = 60;
const STARTUP_PREWARM_CONCURRENCY: usize = 8;

/// PCCS collateral cache with proactive background refresh
#[derive(Clone)]
pub struct Pccs {
pccs_url: String,
cache: Arc<RwLock<HashMap<PccsInput, CacheEntry>>>,
prewarm_stats: Arc<PrewarmStats>,
}

impl std::fmt::Debug for Pccs {
Expand All @@ -36,10 +43,13 @@ impl std::fmt::Debug for Pccs {
impl Pccs {
/// Creates a new PCCS cache using the provided URL or Intel PCS default
pub fn new(pccs_url: Option<String>) -> Self {
Self {
let pccs = Self {
pccs_url: pccs_url.unwrap_or(PCS_URL.to_string()),
cache: RwLock::new(HashMap::new()).into(),
}
prewarm_stats: Arc::new(PrewarmStats::default()),
};
pccs.try_spawn_startup_prewarm();
pccs
}

/// Returns collateral from cache when valid, otherwise fetches and caches fresh collateral
Expand Down Expand Up @@ -128,6 +138,143 @@ impl Pccs {
refresh_loop(weak_cache, pccs_url, key).await;
}));
}

/// Attempts to spawn startup pre-provisioning without blocking constructor
fn try_spawn_startup_prewarm(&self) {
match Handle::try_current() {
Ok(handle) => {
let pccs = self.clone();
handle.spawn(async move {
pccs.startup_prewarm_all_tdx().await;
});
}
Err(_) => {
tracing::warn!("No Tokio runtime available, skipping PCCS startup pre-provision");
self.prewarm_stats.completed.store(true, Ordering::SeqCst);
}
}
}

/// Pre-provisions TDX collateral for discovered FMSPC values to reduce hot-path fetches
async fn startup_prewarm_all_tdx(&self) {
let fmspcs = match self.fetch_fmspcs().await {
Ok(fmspcs) => fmspcs,
Err(e) => {
tracing::warn!(error = %e, "Failed to fetch FMSPC list for startup pre-provision");
self.prewarm_stats.completed.store(true, Ordering::SeqCst);
return;
}
};
self.prewarm_stats
.discovered_fmspcs
.store(fmspcs.len(), Ordering::SeqCst);

if fmspcs.is_empty() {
tracing::warn!("No FMSPC entries returned during startup pre-provision");
self.prewarm_stats.completed.store(true, Ordering::SeqCst);
return;
}

let semaphore = Arc::new(Semaphore::new(STARTUP_PREWARM_CONCURRENCY));
let mut join_set = JoinSet::new();
for entry in fmspcs {
for ca in ["processor", "platform"] {
let permit = semaphore.clone().acquire_owned().await;
let Ok(permit) = permit else {
continue;
};
self.prewarm_stats.attempted.fetch_add(1, Ordering::SeqCst);
let pccs = self.clone();
let fmspc = entry.fmspc.clone();
join_set.spawn(async move {
let _permit = permit;
let now = unix_now()?;
let result = pccs.refresh_collateral(fmspc.clone(), ca, now).await;
Ok::<
(String, &'static str, Result<(), DcapVerificationError>),
DcapVerificationError,
>((fmspc, ca, result.map(|_| ())))
});
}
}

let mut successes = 0usize;
let mut failures = 0usize;
while let Some(task_result) = join_set.join_next().await {
match task_result {
Ok(Ok((_, _, Ok(())))) => {
successes += 1;
self.prewarm_stats.successes.fetch_add(1, Ordering::SeqCst);
}
Ok(Ok((fmspc, ca, Err(e)))) => {
failures += 1;
self.prewarm_stats.failures.fetch_add(1, Ordering::SeqCst);
tracing::debug!(
fmspc,
ca,
error = %e,
"Startup pre-provision failed for FMSPC/CA"
);
}
Ok(Err(e)) => {
failures += 1;
self.prewarm_stats.failures.fetch_add(1, Ordering::SeqCst);
tracing::debug!(error = %e, "Startup pre-provision task failed");
}
Err(e) => {
failures += 1;
self.prewarm_stats.failures.fetch_add(1, Ordering::SeqCst);
tracing::debug!(error = %e, "Startup pre-provision join error");
}
}
}
self.prewarm_stats.completed.store(true, Ordering::SeqCst);

tracing::info!(
discovered_fmspcs = self.prewarm_stats.discovered_fmspcs.load(Ordering::SeqCst),
attempted = self.prewarm_stats.attempted.load(Ordering::SeqCst),
successes,
failures,
"Completed PCCS startup pre-provisioning for TDX collateral"
);
}

/// Fetches available FMSPC entries from configured PCCS/PCS endpoint
async fn fetch_fmspcs(&self) -> Result<Vec<FmspcEntry>, DcapVerificationError> {
let url = format!(
"{}/sgx/certification/v4/fmspcs",
self.certification_base_url()
);
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(15))
.build()
.map_err(anyhow::Error::from)?;
let response = client
.get(&url)
.send()
.await
.map_err(anyhow::Error::from)?;
if !response.status().is_success() {
return Err(anyhow::anyhow!("Failed to fetch {url}: {}", response.status()).into());
}
let body = response
.text()
.await
.map_err(anyhow::Error::from)
.context("Failed to read FMSPC list response body")?;
let entries: Vec<FmspcEntry> =
serde_json::from_str(&body).context("Failed to decode FMSPC list response")?;
Ok(entries)
}

/// Returns PCCS/PCS base URL without certification suffixes
fn certification_base_url(&self) -> String {
self.pccs_url
.trim_end_matches('/')
.trim_end_matches("/sgx/certification/v4")
.trim_end_matches("/tdx/certification/v4")
.to_string()
}
}

/// Cache key for PCCS collateral entries
Expand Down Expand Up @@ -342,6 +489,22 @@ struct QeIdentityNextUpdate {
next_update: String,
}

#[derive(Debug, serde::Deserialize)]
struct FmspcEntry {
fmspc: String,
#[allow(dead_code)]
platform: String,
}

#[derive(Default)]
struct PrewarmStats {
discovered_fmspcs: AtomicUsize,
attempted: AtomicUsize,
successes: AtomicUsize,
failures: AtomicUsize,
completed: AtomicBool,
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -640,4 +803,49 @@ mod tests {
assert!(!is_fresh_again);
assert_eq!(mock.tcb_call_count(), before_check_calls);
}

#[tokio::test]
async fn test_startup_prewarm_populates_cache_from_intel_pcs() {
let pccs = Pccs::new(None);

let mut prewarm_completed = false;
for _ in 0..40 {
if pccs.prewarm_stats.completed.load(Ordering::SeqCst) {
prewarm_completed = true;
break;
}
tokio::time::sleep(Duration::from_millis(500)).await;
}

let cache_guard = pccs.cache.read().await;
let total_entries = cache_guard.len();
let unique_fmspcs: std::collections::HashSet<_> =
cache_guard.keys().map(|k| k.fmspc.clone()).collect();
println!(
"startup prewarm summary: completed={}, discovered_fmspcs={}, attempted={}, successes={}, failures={}, cache_entries_total={}, cache_unique_fmspcs={}",
prewarm_completed,
pccs.prewarm_stats.discovered_fmspcs.load(Ordering::SeqCst),
pccs.prewarm_stats.attempted.load(Ordering::SeqCst),
pccs.prewarm_stats.successes.load(Ordering::SeqCst),
pccs.prewarm_stats.failures.load(Ordering::SeqCst),
total_entries,
unique_fmspcs.len()
);
if pccs.prewarm_stats.discovered_fmspcs.load(Ordering::SeqCst) == 0 {
println!("startup prewarm made no discovery progress in test window, skipping cache assertions");
return;
}
assert!(total_entries > 0, "expected startup pre-provision to populate PCCS cache");

let (fmspc, ca) = cache_guard
.keys()
.next()
.map(|k| (k.fmspc.clone(), k.ca.clone()))
.expect("expected startup pre-provision to populate PCCS cache");
drop(cache_guard);
let ca_static = ca_as_static(&ca).expect("unexpected CA value in warmed cache entry");
let now = unix_now().unwrap();
let (_, is_fresh) = pccs.get_collateral(fmspc, ca_static, now).await.unwrap();
assert!(!is_fresh);
}
}
Loading