diff --git a/attested-tls/src/attestation/pccs.rs b/attested-tls/src/attestation/pccs.rs
index be2837d..5086bdf 100644
--- a/attested-tls/src/attestation/pccs.rs
+++ b/attested-tls/src/attestation/pccs.rs
@@ -1,14 +1,19 @@
 use std::{
     collections::HashMap,
-    sync::{Arc, Weak},
+    sync::{
+        Arc, Weak,
+        atomic::{AtomicBool, AtomicUsize, Ordering},
+    },
     time::{SystemTime, UNIX_EPOCH},
 };
 
+use anyhow::Context;
 use dcap_qvl::{QuoteCollateralV3, collateral::get_collateral_for_fmspc, tcb_info::TcbInfo};
 use time::{OffsetDateTime, format_description::well_known::Rfc3339};
 use tokio::{
-    sync::RwLock,
-    task::JoinHandle,
+    runtime::Handle,
+    sync::{RwLock, Semaphore},
+    task::{JoinHandle, JoinSet},
     time::{Duration, sleep},
 };
 
@@ -16,12 +21,16 @@ use crate::attestation::dcap::{DcapVerificationError, PCS_URL};
 
 const REFRESH_MARGIN_SECS: i64 = 300;
 const REFRESH_RETRY_SECS: u64 = 60;
+/// Maximum number of collateral refreshes kept in flight during startup pre-provisioning
+const STARTUP_PREWARM_CONCURRENCY: usize = 8;
 
 /// PCCS collateral cache with proactive background refresh
 #[derive(Clone)]
 pub struct Pccs {
     pccs_url: String,
     cache: Arc>>,
+    /// Counters and completion flag of the startup pre-provisioning task
+    prewarm_stats: Arc<PrewarmStats>,
 }
 
 impl std::fmt::Debug for Pccs {
@@ -36,10 +45,14 @@ impl std::fmt::Debug for Pccs {
 impl Pccs {
     /// Creates a new PCCS cache using the provided URL or Intel PCS default
     pub fn new(pccs_url: Option<String>) -> Self {
-        Self {
+        let pccs = Self {
             pccs_url: pccs_url.unwrap_or(PCS_URL.to_string()),
             cache: RwLock::new(HashMap::new()).into(),
-        }
+            prewarm_stats: Arc::new(PrewarmStats::default()),
+        };
+        // Best effort: only warms the cache when a Tokio runtime is available
+        pccs.try_spawn_startup_prewarm();
+        pccs
     }
 
     /// Returns collateral from cache when valid, otherwise fetches and caches fresh collateral
@@ -128,6 +141,162 @@ impl Pccs {
             refresh_loop(weak_cache, pccs_url, key).await;
         }));
     }
+
+    /// Spawns startup pre-provisioning on the current Tokio runtime without blocking the caller
+    ///
+    /// When no runtime is available, pre-provisioning is skipped and `prewarm_stats.completed`
+    /// is set immediately
+    fn try_spawn_startup_prewarm(&self) {
+        match Handle::try_current() {
+            Ok(handle) => {
+                let pccs = self.clone();
+                // The returned join handle is dropped, so the task runs detached
+                handle.spawn(async move {
+                    pccs.startup_prewarm_all_tdx().await;
+                });
+            }
+            Err(_) => {
+                tracing::warn!("No Tokio runtime available, skipping PCCS startup pre-provision");
+                self.prewarm_stats.completed.store(true, Ordering::SeqCst);
+            }
+        }
+    }
+
+    /// Pre-provisions TDX collateral for discovered FMSPC values to reduce hot-path fetches
+    ///
+    /// Refreshes collateral for every discovered FMSPC with both the `processor` and
+    /// `platform` CA, keeping at most `STARTUP_PREWARM_CONCURRENCY` refreshes in flight.
+    /// Failures are logged and counted, never propagated; `prewarm_stats.completed` is set on
+    /// every return path
+    async fn startup_prewarm_all_tdx(&self) {
+        let fmspcs = match self.fetch_fmspcs().await {
+            Ok(fmspcs) => fmspcs,
+            Err(e) => {
+                tracing::warn!(error = %e, "Failed to fetch FMSPC list for startup pre-provision");
+                self.prewarm_stats.completed.store(true, Ordering::SeqCst);
+                return;
+            }
+        };
+        self.prewarm_stats
+            .discovered_fmspcs
+            .store(fmspcs.len(), Ordering::SeqCst);
+
+        if fmspcs.is_empty() {
+            tracing::warn!("No FMSPC entries returned during startup pre-provision");
+            self.prewarm_stats.completed.store(true, Ordering::SeqCst);
+            return;
+        }
+
+        let semaphore = Arc::new(Semaphore::new(STARTUP_PREWARM_CONCURRENCY));
+        let mut join_set = JoinSet::new();
+        for entry in fmspcs {
+            for ca in ["processor", "platform"] {
+                // Acquire before spawning to bound the number of tasks in flight; this only
+                // fails if the semaphore is closed, which never happens in this function
+                let Ok(permit) = semaphore.clone().acquire_owned().await else {
+                    continue;
+                };
+                self.prewarm_stats.attempted.fetch_add(1, Ordering::SeqCst);
+                let pccs = self.clone();
+                let fmspc = entry.fmspc.clone();
+                join_set.spawn(async move {
+                    // Held until the task ends so the permit is released only then
+                    let _permit = permit;
+                    let now = unix_now()?;
+                    let result = pccs.refresh_collateral(fmspc.clone(), ca, now).await;
+                    Ok::<
+                        (String, &'static str, Result<(), DcapVerificationError>),
+                        DcapVerificationError,
+                    >((fmspc, ca, result.map(|_| ())))
+                });
+            }
+        }
+
+        let mut successes = 0usize;
+        let mut failures = 0usize;
+        while let Some(task_result) = join_set.join_next().await {
+            match task_result {
+                Ok(Ok((_, _, Ok(())))) => {
+                    successes += 1;
+                    self.prewarm_stats.successes.fetch_add(1, Ordering::SeqCst);
+                }
+                Ok(Ok((fmspc, ca, Err(e)))) => {
+                    failures += 1;
+                    self.prewarm_stats.failures.fetch_add(1, Ordering::SeqCst);
+                    tracing::debug!(
+                        fmspc,
+                        ca,
+                        error = %e,
+                        "Startup pre-provision failed for FMSPC/CA"
+                    );
+                }
+                Ok(Err(e)) => {
+                    failures += 1;
+                    self.prewarm_stats.failures.fetch_add(1, Ordering::SeqCst);
+                    tracing::debug!(error = %e, "Startup pre-provision task failed");
+                }
+                Err(e) => {
+                    failures += 1;
+                    self.prewarm_stats.failures.fetch_add(1, Ordering::SeqCst);
+                    tracing::debug!(error = %e, "Startup pre-provision join error");
+                }
+            }
+        }
+        self.prewarm_stats.completed.store(true, Ordering::SeqCst);
+
+        tracing::info!(
+            discovered_fmspcs = self.prewarm_stats.discovered_fmspcs.load(Ordering::SeqCst),
+            attempted = self.prewarm_stats.attempted.load(Ordering::SeqCst),
+            successes,
+            failures,
+            "Completed PCCS startup pre-provisioning for TDX collateral"
+        );
+    }
+
+    /// Fetches available FMSPC entries from configured PCCS/PCS endpoint
+    ///
+    /// Issues a GET to `{base}/sgx/certification/v4/fmspcs` with a 15 second timeout and
+    /// decodes the JSON body; a non-success status or an undecodable body is an error
+    async fn fetch_fmspcs(&self) -> Result<Vec<FmspcEntry>, DcapVerificationError> {
+        let url = format!(
+            "{}/sgx/certification/v4/fmspcs",
+            self.certification_base_url()
+        );
+        let client = reqwest::Client::builder()
+            .timeout(Duration::from_secs(15))
+            .build()
+            .map_err(anyhow::Error::from)?;
+        let response = client
+            .get(&url)
+            .send()
+            .await
+            .map_err(anyhow::Error::from)?;
+        if !response.status().is_success() {
+            return Err(anyhow::anyhow!("Failed to fetch {url}: {}", response.status()).into());
+        }
+        let body = response
+            .text()
+            .await
+            .map_err(anyhow::Error::from)
+            .context("Failed to read FMSPC list response body")?;
+        let entries: Vec<FmspcEntry> =
+            serde_json::from_str(&body).context("Failed to decode FMSPC list response")?;
+        Ok(entries)
+    }
+
+    /// Returns PCCS/PCS base URL without certification suffixes
+    ///
+    /// Strips trailing slashes and at most one `/sgx/certification/v4` or
+    /// `/tdx/certification/v4` suffix, then trims again so that no trailing slash is left to
+    /// double up when a path is appended
+    fn certification_base_url(&self) -> String {
+        let trimmed = self.pccs_url.trim_end_matches('/');
+        let base = trimmed
+            .strip_suffix("/sgx/certification/v4")
+            .or_else(|| trimmed.strip_suffix("/tdx/certification/v4"))
+            .unwrap_or(trimmed);
+        base.trim_end_matches('/').to_string()
+    }
 }
 
 /// Cache key for PCCS collateral entries
@@ -342,6 +511,30 @@ struct QeIdentityNextUpdate {
     next_update: String,
 }
 
+/// Single entry of the FMSPC list returned by the `fmspcs` endpoint
+#[derive(Debug, serde::Deserialize)]
+struct FmspcEntry {
+    fmspc: String,
+    // Deserialized but never read, only `fmspc` drives pre-provisioning
+    #[allow(dead_code)]
+    platform: String,
+}
+
+/// Counters describing progress and outcome of startup pre-provisioning
+#[derive(Debug, Default)]
+struct PrewarmStats {
+    /// Number of FMSPC entries returned by the discovery request
+    discovered_fmspcs: AtomicUsize,
+    /// Number of FMSPC/CA refresh tasks spawned
+    attempted: AtomicUsize,
+    /// Number of refresh tasks that succeeded
+    successes: AtomicUsize,
+    /// Number of refresh tasks that failed, including task and join errors
+    failures: AtomicUsize,
+    /// Set once pre-provisioning finished or was skipped
+    completed: AtomicBool,
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -640,4 +833,58 @@ mod tests {
         assert!(!is_fresh_again);
         assert_eq!(mock.tcb_call_count(), before_check_calls);
     }
+
+    /// Checks that startup pre-provisioning populates the cache from the default Intel PCS
+    ///
+    /// Needs network access. The cache assertions are skipped when discovery made no progress,
+    /// or when the cache is still empty because pre-provisioning did not finish in the polling
+    /// window
+    #[tokio::test]
+    async fn test_startup_prewarm_populates_cache_from_intel_pcs() {
+        let pccs = Pccs::new(None);
+
+        let mut prewarm_completed = false;
+        for _ in 0..40 {
+            if pccs.prewarm_stats.completed.load(Ordering::SeqCst) {
+                prewarm_completed = true;
+                break;
+            }
+            tokio::time::sleep(Duration::from_millis(500)).await;
+        }
+
+        let cache_guard = pccs.cache.read().await;
+        let total_entries = cache_guard.len();
+        let unique_fmspcs: std::collections::HashSet<_> =
+            cache_guard.keys().map(|k| k.fmspc.clone()).collect();
+        println!(
+            "startup prewarm summary: completed={}, discovered_fmspcs={}, attempted={}, successes={}, failures={}, cache_entries_total={}, cache_unique_fmspcs={}",
+            prewarm_completed,
+            pccs.prewarm_stats.discovered_fmspcs.load(Ordering::SeqCst),
+            pccs.prewarm_stats.attempted.load(Ordering::SeqCst),
+            pccs.prewarm_stats.successes.load(Ordering::SeqCst),
+            pccs.prewarm_stats.failures.load(Ordering::SeqCst),
+            total_entries,
+            unique_fmspcs.len()
+        );
+        if pccs.prewarm_stats.discovered_fmspcs.load(Ordering::SeqCst) == 0 {
+            println!("startup prewarm made no discovery progress in test window, skipping cache assertions");
+            return;
+        }
+        if !prewarm_completed && total_entries == 0 {
+            println!("startup prewarm did not finish in test window and cache is empty, skipping cache assertions");
+            return;
+        }
+        assert!(total_entries > 0, "expected startup pre-provision to populate PCCS cache");
+
+        let (fmspc, ca) = cache_guard
+            .keys()
+            .next()
+            .map(|k| (k.fmspc.clone(), k.ca.clone()))
+            .expect("expected startup pre-provision to populate PCCS cache");
+        drop(cache_guard);
+        let ca_static = ca_as_static(&ca).expect("unexpected CA value in warmed cache entry");
+        let now = unix_now().unwrap();
+        let (_, is_fresh) = pccs.get_collateral(fmspc, ca_static, now).await.unwrap();
+        assert!(!is_fresh);
+    }
 }