From 4ffba7546428334d12b957e796120ca63f8d0c2d Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Mon, 8 Dec 2025 20:58:33 +0100 Subject: [PATCH 01/31] feat(virtq): add packed virtual queue implementation Signed-off-by: Tomasz Andrzejak --- Cargo.lock | 18 + Cargo.toml | 1 + src/hyperlight_common/Cargo.toml | 12 + src/hyperlight_common/benches/buffer_pool.rs | 176 +++ src/hyperlight_common/src/virtq/consumer.rs | 633 +++++++++ src/hyperlight_common/src/virtq/mod.rs | 964 ++++++++++++- src/hyperlight_common/src/virtq/pool.rs | 1334 ++++++++++++++++++ src/hyperlight_common/src/virtq/producer.rs | 790 +++++++++++ src/hyperlight_guest/src/error.rs | 7 +- 9 files changed, 3925 insertions(+), 10 deletions(-) create mode 100644 src/hyperlight_common/benches/buffer_pool.rs create mode 100644 src/hyperlight_common/src/virtq/consumer.rs create mode 100644 src/hyperlight_common/src/virtq/pool.rs create mode 100644 src/hyperlight_common/src/virtq/producer.rs diff --git a/Cargo.lock b/Cargo.lock index 40790f9f1..d5d98755b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -181,6 +181,12 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "atomic_refcell" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21e4227379beff4205943696e6c3e0cd809bacdf3f0edd6e3dd153e2269571a4" + [[package]] name = "autocfg" version = "1.5.0" @@ -925,6 +931,12 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" +[[package]] +name = "fixedbitset" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" + [[package]] name = "flatbuffers" version = "25.12.19" @@ -1483,10 +1495,16 @@ version = "0.15.0" dependencies = [ 
"anyhow", "arbitrary", + "atomic_refcell", "bitflags 2.11.1", "bytemuck", + "bytes", + "criterion", + "fixedbitset", "flatbuffers", + "hyperlight-testing", "log", + "loom", "quickcheck", "rand 0.9.2", "smallvec", diff --git a/Cargo.toml b/Cargo.toml index e693368d0..84b38bb62 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,6 +50,7 @@ hyperlight-component-macro = { path = "src/hyperlight_component_macro", version [workspace.lints.rust] unsafe_op_in_unsafe_fn = "deny" +unexpected_cfgs = { level = "warn", check-cfg = [ 'cfg(loom)' ] } # this will generate symbols for release builds # so is handy for debugging issues in release builds diff --git a/src/hyperlight_common/Cargo.toml b/src/hyperlight_common/Cargo.toml index ad9558bc5..32a78026f 100644 --- a/src/hyperlight_common/Cargo.toml +++ b/src/hyperlight_common/Cargo.toml @@ -17,8 +17,11 @@ workspace = true [dependencies] arbitrary = {version = "1.4.2", optional = true, features = ["derive"]} anyhow = { version = "1.0.102", default-features = false } +atomic_refcell = "0.1.13" bitflags = "2.10.0" bytemuck = { version = "1.24", features = ["derive"] } +bytes = { version = "1", default-features = false } +fixedbitset = { version = "0.5.7", default-features = false } flatbuffers = { version = "25.12.19", default-features = false } log = "0.4.29" smallvec = "1.15.1" @@ -39,9 +42,18 @@ nanvix-unstable = ["i686-guest"] guest-counter = [] [dev-dependencies] +criterion = "0.8.1" +hyperlight-testing = { workspace = true } quickcheck = "1.0.3" rand = "0.9.2" +[target.'cfg(loom)'.dev-dependencies] +loom = "0.7" + [lib] bench = false # see https://bheisler.github.io/criterion.rs/book/faq.html#cargo-bench-gives-unrecognized-option-errors-for-valid-command-line-options doctest = false # reduce noise in test output + +[[bench]] +name = "buffer_pool" +harness = false diff --git a/src/hyperlight_common/benches/buffer_pool.rs b/src/hyperlight_common/benches/buffer_pool.rs new file mode 100644 index 000000000..614f160b0 --- /dev/null 
+++ b/src/hyperlight_common/benches/buffer_pool.rs @@ -0,0 +1,176 @@ +use std::hint::black_box; + +use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main}; +use hyperlight_common::virtq::{BufferPool, BufferProvider}; + +// Helper to create a pool for benchmarking +fn make_pool(size: usize) -> BufferPool { + let base = 0x10000; + BufferPool::::new(base, size).unwrap() +} + +// Single allocation performance +fn bench_alloc_single(c: &mut Criterion) { + let mut group = c.benchmark_group("alloc_single"); + + for size in [64, 128, 256, 512, 1024, 1500, 4096].iter() { + group.throughput(Throughput::Elements(1)); + group.bench_with_input(BenchmarkId::from_parameter(size), size, |b, &size| { + let pool = make_pool::<256, 4096>(4 * 1024 * 1024); + b.iter(|| { + let alloc = pool.alloc(black_box(size)).unwrap(); + pool.dealloc(alloc).unwrap(); + }); + }); + } + group.finish(); +} + +// LIFO recycling +fn bench_alloc_lifo(c: &mut Criterion) { + let mut group = c.benchmark_group("alloc_lifo"); + + for size in [256, 1500, 4096].iter() { + group.throughput(Throughput::Elements(100)); + group.bench_with_input(BenchmarkId::from_parameter(size), size, |b, &size| { + let pool = make_pool::<256, 4096>(4 * 1024 * 1024); + b.iter(|| { + for _ in 0..100 { + let alloc = pool.alloc(black_box(size)).unwrap(); + pool.dealloc(alloc).unwrap(); + } + }); + }); + } + group.finish(); +} + +// Fragmented allocation worst case +fn bench_alloc_fragmented(c: &mut Criterion) { + let mut group = c.benchmark_group("alloc_fragmented"); + + group.bench_function("fragmented_256", |b| { + let pool = make_pool::<256, 4096>(4 * 1024 * 1024); + + // Create fragmentation pattern: allocate many, free every other + let mut allocations = Vec::new(); + for _ in 0..100 { + allocations.push(pool.alloc(128).unwrap()); + } + for i in (0..100).step_by(2) { + pool.dealloc(allocations[i]).unwrap(); + } + + b.iter(|| { + let alloc = pool.alloc(black_box(256)).unwrap(); + 
pool.dealloc(alloc).unwrap(); + }); + }); + + group.finish(); +} + +// Realloc operations +fn bench_realloc(c: &mut Criterion) { + let mut group = c.benchmark_group("realloc"); + + // In-place grow (same tier) + group.bench_function("grow_inplace", |b| { + let pool = make_pool::<256, 4096>(4 * 1024 * 1024); + b.iter(|| { + let alloc = pool.alloc(256).unwrap(); + let grown = pool.resize(alloc, black_box(512)).unwrap(); + pool.dealloc(grown).unwrap(); + }); + }); + + // Relocate grow (cross tier) + group.bench_function("grow_relocate", |b| { + let pool = make_pool::<256, 4096>(4 * 1024 * 1024); + b.iter(|| { + let alloc = pool.alloc(128).unwrap(); + // Block in-place growth + let blocker = pool.alloc(256).unwrap(); + let grown = pool.resize(alloc, black_box(1500)).unwrap(); + pool.dealloc(grown).unwrap(); + pool.dealloc(blocker).unwrap(); + }); + }); + + // Shrink + group.bench_function("shrink", |b| { + let pool = make_pool::<256, 4096>(4 * 1024 * 1024); + b.iter(|| { + let alloc = pool.alloc(1500).unwrap(); + let shrunk = pool.resize(alloc, black_box(256)).unwrap(); + pool.dealloc(shrunk).unwrap(); + }); + }); + + group.finish(); +} + +// Free performance +fn bench_free(c: &mut Criterion) { + let mut group = c.benchmark_group("free"); + + for size in [256, 1500, 4096].iter() { + group.bench_with_input(BenchmarkId::from_parameter(size), size, |b, &size| { + let pool = make_pool::<256, 4096>(4 * 1024 * 1024); + b.iter(|| { + let alloc = pool.alloc(size).unwrap(); + pool.dealloc(black_box(alloc)).unwrap(); + }); + }); + } + + group.finish(); +} + +// Cursor optimization +fn bench_last_free_run(c: &mut Criterion) { + let mut group = c.benchmark_group("last_free_run"); + + // With cursor optimization (LIFO) + group.bench_function("lifo_pattern", |b| { + let pool = make_pool::<256, 4096>(4 * 1024 * 1024); + b.iter(|| { + let alloc = pool.alloc(256).unwrap(); + pool.dealloc(alloc).unwrap(); + let alloc2 = pool.alloc(black_box(256)).unwrap(); + 
pool.dealloc(alloc2).unwrap(); + }); + }); + + // Without cursor benefit (FIFO-like) + group.bench_function("fifo_pattern", |b| { + let pool = make_pool::<256, 4096>(4 * 1024 * 1024); + let mut queue = Vec::new(); + + // Pre-fill queue + for _ in 0..10 { + queue.push(pool.alloc(256).unwrap()); + } + + b.iter(|| { + // FIFO: free oldest, allocate new + let old = queue.remove(0); + pool.dealloc(old).unwrap(); + queue.push(pool.alloc(black_box(256)).unwrap()); + }); + }); + + group.finish(); +} + +criterion_group!( + benches, + bench_alloc_single, + bench_alloc_lifo, + bench_alloc_fragmented, + bench_realloc, + bench_free, + bench_last_free_run, +); + +criterion_main!(benches); diff --git a/src/hyperlight_common/src/virtq/consumer.rs b/src/hyperlight_common/src/virtq/consumer.rs new file mode 100644 index 000000000..4c7bbc9ba --- /dev/null +++ b/src/hyperlight_common/src/virtq/consumer.rs @@ -0,0 +1,633 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +use alloc::vec; +use alloc::vec::Vec; + +use bytes::Bytes; + +use super::*; + +/// In-flight entry tracking. +/// +/// Stored per descriptor ID while the entry is being processed. +/// Tracks that a descriptor slot is occupied. +#[derive(Debug, Clone, Copy)] +pub(crate) struct Inflight; + +/// Data received from the producer, safely copied out of shared memory. +/// +/// Created by [`VirtqConsumer::poll`]. 
The entry data is eagerly copied +/// from shared memory during poll using [`MemOps::read`] (volatile on +/// the host side), so accessing it requires no unsafe code and no +/// references into shared memory. +#[derive(Debug, Clone)] +pub struct RecvEntry { + token: Token, + data: Bytes, +} + +impl RecvEntry { + /// The token identifying this entry. + pub fn token(&self) -> Token { + self.token + } + + /// The entry payload, copied from shared memory. + /// + /// Returns empty [`Bytes`] when the chain has no readable buffers. + pub fn data(&self) -> &Bytes { + &self.data + } + + /// Consume the entry, taking ownership of the data. + pub fn into_data(self) -> Bytes { + self.data + } +} + +/// A pending completion, either writable or ack-only. +/// +/// Created by [`VirtqConsumer::poll`]. Must be submitted back via +/// [`VirtqConsumer::complete`] to release the descriptor. +#[must_use = "dropping without completing leaks the descriptor"] +pub enum SendCompletion { + /// Completion with a writable buffer (for chains with a completion buffer). + /// Use the `write*` methods on [`WritableCompletion`] to fill the + /// response buffer. + Writable(WritableCompletion), + /// Ack-only completion (for chains with only entry buffers). No response buffer. + /// Just pass back to [`VirtqConsumer::complete`] to acknowledge. + Ack(AckCompletion), +} + +impl SendCompletion { + /// The token identifying this completion. + pub fn token(&self) -> Token { + match self { + SendCompletion::Writable(wc) => wc.token(), + SendCompletion::Ack(ack) => ack.token(), + } + } + + /// Number of bytes written (0 for Ack). + pub fn written(&self) -> usize { + match self { + SendCompletion::Writable(wc) => wc.written(), + SendCompletion::Ack(_) => 0, + } + } + + fn id(&self) -> u16 { + match self { + SendCompletion::Writable(wc) => wc.id, + SendCompletion::Ack(ack) => ack.id, + } + } +} + +/// A completion with a writable buffer for response data. 
+/// +/// # Example +/// +/// ```ignore +/// if let SendCompletion::Writable(mut wc) = completion { +/// wc.write_all(b"response data")?; +/// consumer.complete(wc.into())?; +/// } +/// ``` +#[must_use = "dropping without completing leaks the descriptor"] +pub struct WritableCompletion { + mem: M, + id: u16, + token: Token, + elem: BufferElement, + written: usize, +} + +impl WritableCompletion { + fn new(mem: M, id: u16, token: Token, elem: BufferElement) -> Self { + Self { + mem, + id, + token, + elem, + written: 0, + } + } + + /// The token identifying this completion. + pub fn token(&self) -> Token { + self.token + } + + /// Total capacity of the completion buffer in bytes. + pub fn capacity(&self) -> usize { + self.elem.len as usize + } + + /// Number of bytes written so far. + pub fn written(&self) -> usize { + self.written + } + + /// Remaining writable capacity. + pub fn remaining(&self) -> usize { + self.capacity() - self.written + } + + /// Write bytes into the completion buffer, returning how many were written. + /// + /// Appends at the current write position. If `buf` is larger than the + /// remaining capacity, writes as many bytes as will fit (partial write). + /// + /// Returns the number of bytes actually written. + /// + /// # Errors + /// + /// - [`VirtqError::MemoryWriteError`] - underlying MemOps write failed + pub fn write(&mut self, buf: &[u8]) -> Result { + let to_write = buf.len().min(self.remaining()); + if to_write == 0 { + return Ok(0); + } + + let addr = self.elem.addr + self.written as u64; + self.mem + .write(addr, &buf[..to_write]) + .map_err(|_| VirtqError::MemoryWriteError)?; + + self.written += to_write; + Ok(to_write) + } + + /// Write the entire buffer or return an error. 
+ /// + /// # Errors + /// + /// - [`VirtqError::CqeTooLarge`] - buf exceeds remaining capacity + /// - [`VirtqError::MemoryWriteError`] - underlying MemOps write failed + pub fn write_all(&mut self, buf: &[u8]) -> Result<(), VirtqError> { + if buf.len() > self.remaining() { + return Err(VirtqError::CqeTooLarge); + } + + let addr = self.elem.addr + self.written as u64; + self.mem + .write(addr, buf) + .map_err(|_| VirtqError::MemoryWriteError)?; + + self.written += buf.len(); + Ok(()) + } + + /// Reset the write cursor to the beginning. + /// + /// Previously written bytes in shared memory are not zeroed; the + /// `written` count is simply reset to 0. + pub fn reset(&mut self) { + self.written = 0; + } +} + +/// An ack-only completion for chains with no writable buffers. +/// +/// No response buffer - just pass back to [`VirtqConsumer::complete`] +/// to acknowledge processing and release the descriptor. +#[must_use = "dropping without completing leaks the descriptor"] +pub struct AckCompletion { + id: u16, + token: Token, +} + +impl AckCompletion { + fn new(id: u16, token: Token) -> Self { + Self { id, token } + } + + pub fn token(&self) -> Token { + self.token + } +} + +/// A high-level virtqueue consumer (device side). +/// +/// The consumer receives entries from the producer (driver), processes them, +/// and sends back completions. This is typically used on the device/host side. +/// +/// # Example +/// +/// ```ignore +/// let mut consumer = VirtqConsumer::new(layout, mem, notifier); +/// +/// // Poll and process +/// while let Some((entry, completion)) = consumer.poll(MAX_ENTRY_SIZE)? 
{ +/// let data = entry.data(); +/// match completion { +/// SendCompletion::Writable(mut wc) => { +/// let response = handle_request(data); +/// wc.write_all(&response)?; +/// consumer.complete(wc.into())?; +/// } +/// SendCompletion::Ack(ack) => { +/// consumer.complete(ack.into())?; +/// } +/// } +/// } +/// +/// // Or defer completions +/// let mut pending = Vec::new(); +/// while let Some((entry, completion)) = consumer.poll(MAX_ENTRY_SIZE)? { +/// pending.push((process(entry), completion)); +/// } +/// for (result, completion) in pending { +/// // ... complete later ... +/// consumer.complete(completion)?; +/// } +/// ``` +pub struct VirtqConsumer { + inner: RingConsumer, + notifier: N, + inflight: Vec>, +} + +impl VirtqConsumer { + /// Create a new virtqueue consumer. + /// + /// # Arguments + /// + /// * `layout` - Ring memory layout (descriptor table and event suppression addresses) + /// * `mem` - Memory operations implementation for reading/writing to shared memory + /// * `notifier` - Callback for notifying the driver (producer) about completions + pub fn new(layout: Layout, mem: M, notifier: N) -> Self { + let inner = RingConsumer::new(layout, mem); + let inflight = vec![None; inner.len()]; + + Self { + inner, + notifier, + inflight, + } + } + + /// Poll for a single incoming entry from the driver. + /// + /// Returns a [`RecvEntry`] (data copied from shared memory) and a + /// [`SendCompletion`] (writable handle or ack token). Both are + /// independent owned values with no borrow on the consumer. + /// + /// # Arguments + /// + /// * `max_entry` - Maximum entry size to accept. Entries larger than + /// this will return [`VirtqError::EntryTooLarge`]. 
+ /// + /// # Errors + /// + /// - [`VirtqError::EntryTooLarge`] - Entry data exceeds `max_entry` bytes + /// - [`VirtqError::BadChain`] - Descriptor chain format not recognized + /// - [`VirtqError::InvalidState`] - Descriptor ID collision (driver bug) + /// - [`VirtqError::MemoryReadError`] - Failed to read entry from shared memory + pub fn poll( + &mut self, + max_entry: usize, + ) -> Result)>, VirtqError> { + let (id, chain) = match self.inner.poll_available() { + Ok(x) => x, + Err(RingError::WouldBlock) => return Ok(None), + Err(e) => return Err(e.into()), + }; + + let (entry_elem, cqe_elem) = parse_chain(&chain)?; + + // Validate entry size + if let Some(ref elem) = entry_elem + && elem.len as usize > max_entry + { + return Err(VirtqError::EntryTooLarge); + } + + // Reserve the inflight slot + let slot = self + .inflight + .get_mut(id as usize) + .ok_or(VirtqError::InvalidState)?; + + if slot.is_some() { + return Err(VirtqError::InvalidState); + } + + *slot = Some(Inflight); + let token = Token(id); + + // Copy entry data from shared memory + let data = entry_elem + .map(|elem| self.read_element(&elem)) + .transpose()? + .unwrap_or_default(); + + let entry = RecvEntry { token, data }; + + // Build the appropriate completion handle + let completion = if let Some(elem) = cqe_elem { + let mem = self.inner.mem().clone(); + let cqe = WritableCompletion::new(mem, id, token, elem); + SendCompletion::Writable(cqe) + } else { + let ack = AckCompletion::new(id, token); + SendCompletion::Ack(ack) + }; + + Ok(Some((entry, completion))) + } + + /// Submit a completed entry back to the ring. + /// + /// Accepts both [`WritableCompletion`] (with written byte count) and + /// [`AckCompletion`] (zero-length) via the [`SendCompletion`] enum. + /// Clears the inflight slot and notifies the producer if event + /// suppression allows. 
+ pub fn complete(&mut self, completion: SendCompletion) -> Result<(), VirtqError> { + let id = completion.id(); + let written = completion.written() as u32; + + let slot = self + .inflight + .get_mut(id as usize) + .ok_or(VirtqError::InvalidState)?; + + if slot.is_none() { + return Err(VirtqError::InvalidState); + } + + *slot = None; + + if self.inner.submit_used_with_notify(id, written)? { + self.notifier.notify(QueueStats { + num_free: self.inner.num_free(), + num_inflight: self.inner.num_inflight(), + }); + } + + Ok(()) + } + + /// Get the current available cursor position. + /// + /// Returns the position where the next available descriptor will be + /// consumed. Useful for setting up descriptor-based event suppression. + #[inline] + pub fn avail_cursor(&self) -> RingCursor { + self.inner.avail_cursor() + } + + /// Get the current used cursor position. + /// + /// Returns the position where the next used descriptor will be written. + /// Useful for setting up descriptor-based event suppression. + #[inline] + pub fn used_cursor(&self) -> RingCursor { + self.inner.used_cursor() + } + + /// Configure event suppression for available buffer notifications. + /// + /// This controls when the driver (producer) signals us about new buffers: + /// + /// - [`SuppressionKind::Enable`] - Always signal (default) - good for latency + /// - [`SuppressionKind::Disable`] - Never signal - caller must poll + /// - [`SuppressionKind::Descriptor`] - Signal only at specific cursor position + /// + /// # Example: Polling Mode + /// ```ignore + /// consumer.set_avail_suppression(SuppressionKind::Disable)?; + /// loop { + /// while let Some((entry, completion)) = consumer.poll(1024)? { + /// process(entry, completion); + /// } + /// // ... do other work ... 
+ /// } + /// ``` + pub fn set_avail_suppression(&mut self, kind: SuppressionKind) -> Result<(), VirtqError> { + match kind { + SuppressionKind::Enable => self.inner.enable_avail_notifications()?, + SuppressionKind::Disable => self.inner.disable_avail_notifications()?, + SuppressionKind::Descriptor(cursor) => self + .inner + .enable_avail_notifications_desc(cursor.head(), cursor.wrap())?, + } + Ok(()) + } + + /// Read a buffer element from shared memory into `Bytes`. + fn read_element(&self, elem: &BufferElement) -> Result { + let mut buf = vec![0u8; elem.len as usize]; + self.inner + .mem() + .read(elem.addr, &mut buf) + .map_err(|_| VirtqError::MemoryReadError)?; + + Ok(Bytes::from(buf)) + } +} + +/// Parse a descriptor chain into entry/completion buffer elements. +/// +/// Returns `(entry_element, completion_element)`. +fn parse_chain( + chain: &BufferChain, +) -> Result<(Option, Option), VirtqError> { + let r = chain.readables(); + let w = chain.writables(); + + match (r.len(), w.len()) { + (1, 1) => Ok((Some(r[0]), Some(w[0]))), + (0, 1) => Ok((None, Some(w[0]))), + (1, 0) => Ok((Some(r[0]), None)), + _ => Err(VirtqError::BadChain), + } +} + +impl From> for SendCompletion { + fn from(wc: WritableCompletion) -> Self { + SendCompletion::Writable(wc) + } +} + +impl From for SendCompletion { + fn from(ack: AckCompletion) -> Self { + SendCompletion::Ack(ack) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::virtq::ring::tests::make_ring; + use crate::virtq::test_utils::*; + + #[test] + fn test_write_only_entry_is_empty() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let se = producer.chain().completion(16).build().unwrap(); + producer.submit(se).unwrap(); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert!(entry.data().is_empty()); + assert!(matches!(completion, SendCompletion::Writable(_))); + + if let SendCompletion::Writable(mut wc) = completion { + 
wc.write_all(b"response").unwrap(); + consumer.complete(wc.into()).unwrap(); + } + } + + #[test] + fn test_read_only_ack_completion() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let mut se = producer.chain().entry(16).build().unwrap(); + se.write_all(b"hello").unwrap(); + producer.submit(se).unwrap(); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.data().as_ref(), b"hello"); + assert!(matches!(completion, SendCompletion::Ack(_))); + + consumer.complete(completion).unwrap(); + } + + #[test] + fn test_readwrite_round_trip() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let mut se = producer.chain().entry(32).completion(64).build().unwrap(); + se.write_all(b"hello world").unwrap(); + producer.submit(se).unwrap(); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.data().as_ref(), b"hello world"); + + if let SendCompletion::Writable(mut wc) = completion { + assert_eq!(wc.capacity(), 64); + assert_eq!(wc.written(), 0); + assert_eq!(wc.remaining(), 64); + wc.write_all(b"response").unwrap(); + assert_eq!(wc.written(), 8); + assert_eq!(wc.remaining(), 56); + consumer.complete(wc.into()).unwrap(); + } else { + panic!("expected Writable completion for entry+completion chain"); + } + } + + #[test] + fn test_writable_partial_write() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let se = producer.chain().completion(8).build().unwrap(); + producer.submit(se).unwrap(); + + let (_entry, completion) = consumer.poll(1024).unwrap().unwrap(); + + if let SendCompletion::Writable(mut wc) = completion { + let n = wc.write(b"hello world!").unwrap(); + assert_eq!(n, 8); + assert_eq!(wc.remaining(), 0); + consumer.complete(wc.into()).unwrap(); + } else { + panic!("expected Writable"); + } + } + + #[test] + fn 
test_writable_write_all_too_large() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let se = producer.chain().completion(4).build().unwrap(); + producer.submit(se).unwrap(); + let (_entry, completion) = consumer.poll(1024).unwrap().unwrap(); + + if let SendCompletion::Writable(mut wc) = completion { + let err = wc.write_all(b"too long").unwrap_err(); + assert!(matches!(err, VirtqError::CqeTooLarge)); + } else { + panic!("expected Writable"); + } + } + + #[test] + fn test_writable_reset() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let se = producer.chain().completion(16).build().unwrap(); + producer.submit(se).unwrap(); + + let (_entry, completion) = consumer.poll(1024).unwrap().unwrap(); + + if let SendCompletion::Writable(mut wc) = completion { + wc.write_all(b"first").unwrap(); + assert_eq!(wc.written(), 5); + wc.reset(); + assert_eq!(wc.written(), 0); + assert_eq!(wc.remaining(), 16); + wc.write_all(b"second").unwrap(); + assert_eq!(wc.written(), 6); + consumer.complete(wc.into()).unwrap(); + } else { + panic!("expected Writable"); + } + } + + #[test] + fn test_multiple_pending_completions() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let se1 = producer.chain().completion(16).build().unwrap(); + producer.submit(se1).unwrap(); + let se2 = producer.chain().completion(16).build().unwrap(); + producer.submit(se2).unwrap(); + + let (_e1, c1) = consumer.poll(1024).unwrap().unwrap(); + let (_e2, c2) = consumer.poll(1024).unwrap().unwrap(); + + // Complete in reverse order + consumer.complete(c2).unwrap(); + consumer.complete(c1).unwrap(); + } + + #[test] + fn test_entry_into_data() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let mut se = producer.chain().entry(16).build().unwrap(); + se.write_all(b"abc").unwrap(); + 
producer.submit(se).unwrap(); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + let data = entry.into_data(); + assert_eq!(data.as_ref(), b"abc"); + consumer.complete(completion).unwrap(); + } +} diff --git a/src/hyperlight_common/src/virtq/mod.rs b/src/hyperlight_common/src/virtq/mod.rs index 326aac933..dd648e3dc 100644 --- a/src/hyperlight_common/src/virtq/mod.rs +++ b/src/hyperlight_common/src/virtq/mod.rs @@ -14,14 +14,20 @@ See the License for the specific language governing permissions and limitations under the License. */ -//! Packed Virtqueue - Ring Primitives +//! Packed Virtqueue Implementation //! -//! This module provides low-level ring primitives for virtio packed virtqueues, -//! implementing the VIRTIO 1.1+ packed ring format with proper memory ordering -//! and event suppression support. +//! This module provides a high-level API for virtio packed virtqueues, built on top of +//! the lower-level ring primitives. It implements the VIRTIO 1.1+ packed ring format +//! with proper memory ordering and event suppression support. //! //! # Architecture //! +//! The implementation is split into layers: +//! +//! - **High-level API** ([`VirtqProducer`], [`VirtqConsumer`]): Manages buffer allocation, +//! entry/completion lifecycle, and notification decisions. This is the recommended API +//! for most use cases. +//! //! - **Ring primitives** ([`RingProducer`], [`RingConsumer`]): Low-level descriptor ring //! operations with explicit buffer chain management. Use this when you need full control //! over buffer layouts or custom allocation strategies. @@ -29,11 +35,108 @@ limitations under the License. //! - **Descriptor and event types** ([`Descriptor`], [`EventSuppression`]): Raw virtio //! data structures for direct memory manipulation. //! -//! - **Memory access** ([`MemOps`]): Trait abstracting memory read/write operations, -//! allowing the ring to work with different memory backends (host vs guest). +//! # Quick Start +//! +//! 
## Single Entry/Completion +//! +//! ```ignore +//! // Producer (driver) side - build entry, submit, get completion +//! let mut entry = producer.chain() +//! .entry(64) +//! .completion(128) +//! .build()?; +//! entry.write_all(b"entry data")?; +//! let token = producer.submit(entry)?; +//! // ... wait for notification ... +//! if let Some(completion) = producer.poll()? { +//! process(completion.data); +//! } +//! +//! // Consumer (device) side - receive entry, send completion +//! if let Some((entry, completion)) = consumer.poll(max_request_size)? { +//! let request = entry.data(); +//! match completion { +//! SendCompletion::Writable(mut wc) => { +//! let response = handle(request); +//! wc.write_all(&response)?; +//! consumer.complete(wc.into())?; +//! } +//! SendCompletion::Ack(ack) => { +//! consumer.complete(ack.into())?; +//! } +//! } +//! } +//! +//! // Multiple pending completions (no borrow on consumer) +//! let mut pending = Vec::new(); +//! while let Some((entry, completion)) = consumer.poll(max_request_size)? { +//! pending.push((process(entry), completion)); +//! } +//! for (result, completion) in pending { +//! consumer.complete(completion)?; +//! } +//! ``` +//! +//! ## Multiple Entries +//! +//! Each submit checks event suppression and notifies independently: +//! +//! ```ignore +//! for data in entries { +//! let mut se = producer.chain() +//! .entry(data.len()) +//! .completion(64) +//! .build()?; +//! se.write_all(data)?; +//! producer.submit(se)?; +//! } +//! ``` +//! +//! ## Completion Batching with Event Suppression +//! +//! To receive a single notification when multiple requests complete: +//! +//! ```ignore +//! // Submit entries +//! for data in entries { +//! let mut se = producer.chain() +//! .entry(data.len()) +//! .completion(64) +//! .build()?; +//! se.write_all(data)?; +//! producer.submit(se)?; +//! } +//! +//! // Tell device: "notify me only after completing past this cursor" +//! let cursor = producer.used_cursor(); +//! 
producer.set_used_suppression(SuppressionKind::Descriptor(cursor))?; +//! +//! // Wait for single notification, then drain all responses +//! producer.drain(|token, data| { +//! handle_response(token, data); +//! })?; +//! ``` +//! +//! # Event Suppression +//! +//! Both sides can control when they want to be notified using [`SuppressionKind`]: +//! +//! - [`SuppressionKind::Enable`]: Always notify (default, lowest latency) +//! - [`SuppressionKind::Disable`]: Never notify (polling mode, lowest overhead) +//! - [`SuppressionKind::Descriptor`]: Notify at specific ring position (batching) +//! +//! See [`VirtqProducer::set_used_suppression`] and [`VirtqConsumer::set_avail_suppression`]. //! //! # Low-Level API //! +//! For advanced use cases, the ring module exposes lower-level primitives: +//! +//! - [`RingProducer`] / [`RingConsumer`]: Direct ring access with [`BufferChain`] submission +//! - [`BufferChainBuilder`]: Construct scatter-gather buffer lists +//! - [`RingCursor`]: Track ring positions for event suppression +//! +//! Example using low-level API: +//! //! ```ignore //! let chain = BufferChainBuilder::new() //! .readable(header_addr, header_len) @@ -48,16 +151,53 @@ limitations under the License. //! ``` mod access; +mod consumer; mod desc; mod event; +mod pool; +mod producer; mod ring; use core::num::NonZeroU16; pub use access::*; +pub use consumer::*; pub use desc::*; pub use event::*; +pub use pool::*; +pub use producer::*; pub use ring::*; +use thiserror::Error; + +/// A trait for notifying about new requests in the virtqueue. +pub trait Notifier { + fn notify(&self, stats: QueueStats); +} + +/// Errors that can occur in the virtqueue operations. 
+#[derive(Error, Debug)] +pub enum VirtqError { + #[error("Ring error: {0}")] + RingError(#[from] RingError), + #[error("Allocation error: {0}")] + Alloc(#[from] AllocError), + #[error("Invalid token")] + BadToken, + #[error("Invalid chain received")] + BadChain, + #[error("Entry data too large for allocated buffer")] + EntryTooLarge, + #[error("Completion data too large for allocated buffer")] + CqeTooLarge, + #[error("Internal state error")] + InvalidState, + #[error("Memory write error")] + MemoryWriteError, + #[error("Memory read error")] + MemoryReadError, + #[error("No readable buffer in this entry")] + NoReadableBuffer, +} /// Layout of a packed virtqueue ring in shared memory. /// @@ -166,8 +306,50 @@ impl Layout { } } +/// Statistics about the current virtqueue state. +/// +/// Provided to the [`Notifier`] when sending notifications, allowing +/// the notifier to make decisions based on queue pressure. +#[derive(Debug, Clone, Copy)] +pub struct QueueStats { + /// Number of free descriptor slots available. + pub num_free: usize, + /// Number of descriptors currently in-flight (submitted but not completed). + pub num_inflight: usize, +} + +/// Event suppression mode for controlling when notifications are sent. +/// +/// This configures when the other side should signal (interrupt/kick) us +/// about new data. Used to optimize batching and reduce interrupt overhead. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SuppressionKind { + /// Always signal after each operation (default behavior). + Enable, + /// Never signal. + Disable, + /// Signal only when reaching a specific descriptor position. + Descriptor(RingCursor), +} + +/// A token representing a sent entry in the virtqueue. +/// +/// Tokens uniquely identify in-flight requests and are used to correlate +/// requests with their responses. The token value corresponds to the +/// descriptor ID in the underlying ring. 
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+pub struct Token(pub u16);
+
+impl From<BufferElement> for Allocation {
+    fn from(value: BufferElement) -> Self {
+        Allocation {
+            addr: value.addr,
+            len: value.len as usize,
+        }
+    }
+}
+
 const _: () = {
-    #[allow(clippy::unwrap_used)]
     const fn verify_layout(num_descs: usize) {
         let base = 0x1000u64;
 
@@ -219,3 +401,771 @@ const _: () = {
     verify_layout(512);
     verify_layout(1024);
 };
+
+/// Shared test utilities for virtqueue tests.
+#[cfg(test)]
+pub(crate) mod test_utils {
+    use alloc::sync::Arc;
+    use core::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
+
+    use super::*;
+    use crate::virtq::ring::tests::{OwnedRing, TestMem};
+
+    /// Simple notifier that tracks notification count.
+    #[derive(Debug, Clone)]
+    pub(crate) struct TestNotifier {
+        pub(crate) count: Arc<AtomicUsize>,
+    }
+
+    impl TestNotifier {
+        pub(crate) fn new() -> Self {
+            Self {
+                count: Arc::new(AtomicUsize::new(0)),
+            }
+        }
+
+        pub(crate) fn notification_count(&self) -> usize {
+            self.count.load(Ordering::Relaxed)
+        }
+    }
+
+    impl Notifier for TestNotifier {
+        fn notify(&self, _stats: QueueStats) {
+            self.count.fetch_add(1, Ordering::Relaxed);
+        }
+    }
+
+    /// Simple test buffer pool that allocates from a range.
+    #[derive(Clone)]
+    pub(crate) struct TestPool {
+        base: u64,
+        next: Arc<AtomicU64>,
+        size: usize,
+    }
+
+    impl TestPool {
+        pub(crate) fn new(base: u64, size: usize) -> Self {
+            Self {
+                base,
+                next: Arc::new(AtomicU64::new(base)),
+                size,
+            }
+        }
+    }
+
+    impl BufferProvider for TestPool {
+        fn alloc(&self, len: usize) -> Result<Allocation, AllocError> {
+            let addr = self.next.fetch_add(len as u64, Ordering::Relaxed);
+            let end = addr + len as u64;
+            if end > self.base + self.size as u64 {
+                return Err(AllocError::OutOfMemory);
+            }
+            Ok(Allocation { addr, len })
+        }
+
+        fn dealloc(&self, _alloc: Allocation) -> Result<(), AllocError> {
+            // Simple pool doesn't track individual allocations
+            Ok(())
+        }
+
+        fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result<Allocation, AllocError> {
+            // Simple implementation: always allocate new
+            self.dealloc(old_alloc)?;
+            self.alloc(new_len)
+        }
+    }
+
+    /// Create test infrastructure: a producer, consumer, and notifier backed
+    /// by the supplied [`OwnedRing`].
+    pub(crate) fn make_test_producer(
+        ring: &OwnedRing,
+    ) -> (
+        VirtqProducer,
+        VirtqConsumer,
+        TestNotifier,
+    ) {
+        let layout = ring.layout();
+        let mem = ring.mem();
+
+        // Pool needs to be in memory accessible via mem - use memory after ring layout
+        let pool_base = mem.base_addr() + Layout::query_size(ring.len()) as u64 + 0x100;
+        let pool = TestPool::new(pool_base, 0x8000);
+        let notifier = TestNotifier::new();
+
+        let producer = VirtqProducer::new(layout, mem.clone(), notifier.clone(), pool);
+        let consumer = VirtqConsumer::new(layout, mem, notifier.clone());
+
+        (producer, consumer, notifier)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use alloc::sync::Arc;
+    use core::sync::atomic::{AtomicUsize, Ordering};
+
+    use super::*;
+    use crate::virtq::ring::tests::{TestMem, make_ring};
+    use crate::virtq::test_utils::*;
+
+    /// Helper: build and submit an entry+completion chain using the chain() builder.
+ fn send_readwrite( + producer: &mut VirtqProducer, + entry_data: &[u8], + cqe_cap: usize, + ) -> Token { + let mut se = producer + .chain() + .entry(entry_data.len()) + .completion(cqe_cap) + .build() + .unwrap(); + se.write_all(entry_data).unwrap(); + producer.submit(se).unwrap() + } + + #[test] + fn test_submit_notifies() { + let ring = make_ring(16); + let (mut producer, mut consumer, notifier) = make_test_producer(&ring); + + let initial_count = notifier.notification_count(); + + let token = send_readwrite(&mut producer, b"hello", 64); + assert!(notifier.notification_count() > initial_count); + + let (entry, _completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.token(), token); + } + + #[test] + fn test_multiple_submits() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let tok1 = send_readwrite(&mut producer, b"request1", 64); + let tok2 = send_readwrite(&mut producer, b"request2", 64); + let tok3 = send_readwrite(&mut producer, b"request3", 64); + + // Consumer sees all requests + for _ in 0..3 { + let (_entry, completion) = consumer.poll(1024).unwrap().unwrap(); + consumer.complete(completion).unwrap(); + } + + // All completions available + let cqe1 = producer.poll().unwrap().unwrap(); + let cqe2 = producer.poll().unwrap().unwrap(); + let cqe3 = producer.poll().unwrap().unwrap(); + assert!( + [cqe1.token, cqe2.token, cqe3.token].contains(&tok1) + && [cqe1.token, cqe2.token, cqe3.token].contains(&tok2) + && [cqe1.token, cqe2.token, cqe3.token].contains(&tok3) + ); + } + + #[test] + fn test_completion_batching_with_suppression() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + // Submit entries + let tok1 = send_readwrite(&mut producer, b"req1", 64); + let tok2 = send_readwrite(&mut producer, b"req2", 64); + let tok3 = send_readwrite(&mut producer, b"req3", 64); + + // Set up completion batching via used suppression + 
let cursor = producer.used_cursor(); + producer + .set_used_suppression(SuppressionKind::Descriptor(cursor)) + .unwrap(); + + // Consumer processes requests + for _ in 0..3 { + let (_entry, completion) = consumer.poll(1024).unwrap().unwrap(); + let SendCompletion::Writable(mut wc) = completion else { + panic!("expected writable completion"); + }; + wc.write_all(b"cqe-data").unwrap(); + consumer.complete(wc.into()).unwrap(); + } + + // Producer can drain all responses + let mut responses = Vec::new(); + producer + .drain(|tok, _data| { + responses.push(tok); + }) + .unwrap(); + + assert_eq!(responses.len(), 3); + assert!(responses.contains(&tok1)); + assert!(responses.contains(&tok2)); + assert!(responses.contains(&tok3)); + } + + #[test] + fn test_notifier_receives_context() { + #[derive(Debug, Clone)] + struct CtxNotifier { + last_num_free: Arc, + last_num_inflight: Arc, + count: Arc, + } + + impl Notifier for CtxNotifier { + fn notify(&self, stats: QueueStats) { + self.last_num_free.store(stats.num_free, Ordering::Relaxed); + self.last_num_inflight + .store(stats.num_inflight, Ordering::Relaxed); + self.count.fetch_add(1, Ordering::Relaxed); + } + } + + let ring = make_ring(16); + let layout = ring.layout(); + let mem = ring.mem(); + let pool_base = mem.base_addr() + Layout::query_size(ring.len()) as u64 + 0x100; + let pool = TestPool::new(pool_base, 0x8000); + let notifier = CtxNotifier { + last_num_free: Arc::new(AtomicUsize::new(0)), + last_num_inflight: Arc::new(AtomicUsize::new(0)), + count: Arc::new(AtomicUsize::new(0)), + }; + + let mut producer = VirtqProducer::new(layout, mem, notifier.clone(), pool); + + let mut se = producer.chain().entry(4).completion(32).build().unwrap(); + se.write_all(b"test").unwrap(); + producer.submit(se).unwrap(); + assert_eq!(notifier.count.load(Ordering::Relaxed), 1); + assert!(notifier.last_num_inflight.load(Ordering::Relaxed) > 0); + } + + #[test] + fn test_chain_zero_copy_batch() { + let ring = make_ring(16); + let (mut 
producer, mut consumer, notifier) = make_test_producer(&ring); + + let initial_count = notifier.notification_count(); + + // Zero-copy entry via buf_mut + let mut se1 = producer.chain().entry(64).completion(128).build().unwrap(); + let buf = se1.buf_mut().unwrap(); + buf[..6].copy_from_slice(b"zc-ent"); + se1.set_written(6).unwrap(); + let _tok1 = producer.submit(se1).unwrap(); + + // Write-based entry + let mut se2 = producer.chain().entry(64).completion(64).build().unwrap(); + se2.write_all(b"copy-ent").unwrap(); + let _tok2 = producer.submit(se2).unwrap(); + + // Completion-only chain + let se3 = producer.chain().completion(32).build().unwrap(); + let tok3 = producer.submit(se3).unwrap(); + + // Each submit may notify independently + assert!(notifier.notification_count() > initial_count); + + // Consumer sees all three entries + let (entry1, completion1) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry1.data().as_ref(), b"zc-ent"); + consumer.complete(completion1).unwrap(); + + let (entry2, completion2) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry2.data().as_ref(), b"copy-ent"); + consumer.complete(completion2).unwrap(); + + let (_entry3, completion3) = consumer.poll(1024).unwrap().unwrap(); + let SendCompletion::Writable(mut wc) = completion3 else { + panic!("expected writable completion"); + }; + wc.write_all(b"resp").unwrap(); + consumer.complete(wc.into()).unwrap(); + + // Drain completions + let _ = producer.poll().unwrap().unwrap(); + let _ = producer.poll().unwrap().unwrap(); + + let cqe = producer.poll().unwrap().unwrap(); + assert_eq!(cqe.token, tok3); + assert_eq!(&cqe.data[..], b"resp"); + } + + #[test] + fn test_chain_zero_copy_send() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + // Zero-copy send: allocate, write directly, submit + let mut se = producer.chain().entry(64).completion(128).build().unwrap(); + let buf = se.buf_mut().unwrap(); + assert_eq!(buf.len(), 
64); + buf[..5].copy_from_slice(b"hello"); + se.set_written(5).unwrap(); + let token = producer.submit(se).unwrap(); + + // Consumer sees the data + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.token(), token); + assert_eq!(entry.data().as_ref(), b"hello"); + + // Write response + let SendCompletion::Writable(mut wc) = completion else { + panic!("expected writable completion"); + }; + wc.write_all(b"world").unwrap(); + consumer.complete(wc.into()).unwrap(); + let cqe = producer.poll().unwrap().unwrap(); + assert_eq!(&cqe.data[..], b"world"); + } + + #[test] + fn test_full_round_trip() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + // Send an entry + let token = send_readwrite(&mut producer, b"round-trip-entry", 128); + + // Consumer receives and responds + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.token(), token); + assert_eq!(entry.data().as_ref(), b"round-trip-entry"); + + let SendCompletion::Writable(mut wc) = completion else { + panic!("expected writable completion"); + }; + assert!(wc.capacity() >= 128); + wc.write_all(b"round-trip-rsp").unwrap(); + consumer.complete(wc.into()).unwrap(); + + // Producer gets the completion + let cqe = producer.poll().unwrap().unwrap(); + assert_eq!(cqe.token, token); + assert_eq!(&cqe.data[..], b"round-trip-rsp"); + } + + #[test] + fn test_cancel_submits_zero_length() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let token = send_readwrite(&mut producer, b"entry-data", 64); + + let (_entry, completion) = consumer.poll(1024).unwrap().unwrap(); + consumer.complete(completion).unwrap(); + + let cqe = producer.poll().unwrap().unwrap(); + assert_eq!(cqe.token, token); + assert_eq!(cqe.data.len(), 0); + assert!(cqe.data.is_empty()); + } + + #[test] + fn test_hold_completion_and_complete() { + let ring = make_ring(16); + let (mut 
producer, mut consumer, _notifier) = make_test_producer(&ring); + + let token = send_readwrite(&mut producer, b"deferred", 64); + + // Poll and hold the completion + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.token(), token); + assert_eq!(entry.data().as_ref(), b"deferred"); + + let SendCompletion::Writable(mut wc) = completion else { + panic!("expected writable completion"); + }; + wc.write_all(b"deferred-cqe").unwrap(); + consumer.complete(wc.into()).unwrap(); + + let cqe = producer.poll().unwrap().unwrap(); + assert_eq!(cqe.token, token); + assert_eq!(&cqe.data[..], b"deferred-cqe"); + } + + #[test] + fn test_concurrent_pending_completions() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let tok1 = send_readwrite(&mut producer, b"first", 64); + let tok2 = send_readwrite(&mut producer, b"second", 64); + + // Poll both + let (entry1, completion1) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry1.token(), tok1); + assert_eq!(entry1.data().as_ref(), b"first"); + + let (entry2, completion2) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry2.token(), tok2); + assert_eq!(entry2.data().as_ref(), b"second"); + + // Complete second first (out of order) + let SendCompletion::Writable(mut wc2) = completion2 else { + panic!("expected writable"); + }; + wc2.write_all(b"resp2").unwrap(); + consumer.complete(wc2.into()).unwrap(); + + let SendCompletion::Writable(mut wc1) = completion1 else { + panic!("expected writable"); + }; + wc1.write_all(b"resp1").unwrap(); + consumer.complete(wc1.into()).unwrap(); + + let cqe1 = producer.poll().unwrap().unwrap(); + let cqe2 = producer.poll().unwrap().unwrap(); + let mut responses: Vec<_> = vec![ + (cqe1.token, cqe1.data.to_vec()), + (cqe2.token, cqe2.data.to_vec()), + ]; + responses.sort_by_key(|(t, _)| t.0); + + let expected_first = responses.iter().find(|(t, _)| *t == tok1).unwrap(); + let expected_second = 
responses.iter().find(|(t, _)| *t == tok2).unwrap();
+        assert_eq!(&expected_first.1[..], b"resp1");
+        assert_eq!(&expected_second.1[..], b"resp2");
+    }
+}
+#[cfg(all(test, loom))]
+mod fuzz {
+    //! Loom-based concurrency testing for the virtqueue implementation.
+    //!
+    //! Loom explores all possible thread interleavings to find data races
+    //! and other concurrency bugs. However, it has specific requirements that
+    //! make our memory model more involved:
+    //!
+    //! ## Flag-Based Synchronization
+    //!
+    //! The virtqueue protocol uses flag-based synchronization:
+    //! 1. Producer writes descriptor fields (addr, len, id), then writes flags with release semantics
+    //! 2. Consumer reads flags with acquire semantics, then reads descriptor fields
+    //!
+    //! Loom would see this as concurrent access to the same memory and report a race, even though
+    //! acquire/release on flags provides proper synchronization.
+    //!
+    //! ## Shadow Atomics for Flags
+    //!
+    //! We maintain shadow atomics that loom tracks for synchronization:
+    //!
+    //! - `desc_flags`: One `AtomicU16` per descriptor for flags field
+    //! - `drv_flags`: `AtomicU16` for driver event suppression flags
+    //! - `dev_flags`: `AtomicU16` for device event suppression flags
+    //!
+    //! The `load_acquire`/`store_release` operations use these loom atomics,
+    //! while `read`/`write` access the underlying data directly.
+    //!
+    //! ## Memory Regions
+    //!
+    //! We use a `BTreeMap` to map addresses to memory regions:
+    //! - `Desc(idx)`: Individual descriptors in the ring
+    //! - `DrvEvt`: Driver event suppression structure
+    //! - `DevEvt`: Device event suppression structure
+    //! 
- `Pool`: Buffer pool for entry/completion data + + use alloc::collections::BTreeMap; + use alloc::sync::Arc; + use alloc::vec; + use core::num::NonZeroU16; + + use bytemuck::Zeroable; + use loom::sync::atomic::{AtomicU16, AtomicUsize, Ordering}; + use loom::thread; + + use super::*; + use crate::virtq::desc::Descriptor; + use crate::virtq::pool::BufferPoolSync; + + #[derive(Debug)] + pub struct MemErr; + + #[derive(Debug, Clone, Copy)] + enum RegionKind { + Desc(usize), + DrvEvt, + DevEvt, + Pool, + } + + #[derive(Debug, Clone, Copy)] + struct RegionInfo { + kind: RegionKind, + size: usize, + } + + #[derive(Debug)] + pub struct LoomMem { + descs: Vec, + drv: core::cell::UnsafeCell, + dev: core::cell::UnsafeCell, + pool: loom::cell::UnsafeCell>, + + desc_flags: Vec, + drv_flags: AtomicU16, + dev_flags: AtomicU16, + + regions: BTreeMap, + layout: Layout, + } + + unsafe impl Sync for LoomMem {} + unsafe impl Send for LoomMem {} + + impl LoomMem { + pub fn new(ring_base: u64, num_descs: usize, pool_base: u64, pool_size: usize) -> Self { + let descs_nz = NonZeroU16::new(num_descs as u16).unwrap(); + let layout = unsafe { Layout::from_base(ring_base, descs_nz).unwrap() }; + + let descs: Vec<_> = (0..num_descs).map(|_| Descriptor::zeroed()).collect(); + let desc_flags: Vec<_> = (0..num_descs).map(|_| AtomicU16::new(0)).collect(); + + let mut regions = BTreeMap::new(); + + // Register each descriptor as a separate region + for i in 0..num_descs { + let addr = layout.desc_table_addr() + (i * Descriptor::SIZE) as u64; + regions.insert( + addr, + RegionInfo { + kind: RegionKind::Desc(i), + size: Descriptor::SIZE, + }, + ); + } + + regions.insert( + layout.drv_evt_addr(), + RegionInfo { + kind: RegionKind::DrvEvt, + size: EventSuppression::SIZE, + }, + ); + + regions.insert( + layout.dev_evt_addr(), + RegionInfo { + kind: RegionKind::DevEvt, + size: EventSuppression::SIZE, + }, + ); + + regions.insert( + pool_base, + RegionInfo { + kind: RegionKind::Pool, + size: pool_size, + 
}, + ); + + Self { + descs, + drv: core::cell::UnsafeCell::new(EventSuppression::zeroed()), + dev: core::cell::UnsafeCell::new(EventSuppression::zeroed()), + pool: loom::cell::UnsafeCell::new(vec![0u8; pool_size]), + desc_flags, + drv_flags: AtomicU16::new(0), + dev_flags: AtomicU16::new(0), + regions, + layout, + } + } + + pub fn layout(&self) -> Layout { + self.layout + } + + fn region(&self, addr: u64) -> Option<(RegionInfo, usize)> { + let (&base, &info) = self.regions.range(..=addr).next_back()?; + let offset = (addr - base) as usize; + + if offset < info.size { + Some((info, offset)) + } else { + None + } + } + + fn desc_ptr(&self, idx: usize) -> *mut Descriptor { + self.descs.as_ptr().cast_mut().wrapping_add(idx) + } + } + + unsafe impl MemOps for Arc { + type Error = MemErr; + + fn read(&self, addr: u64, dst: &mut [u8]) -> Result<(), Self::Error> { + let (info, offset) = self.region(addr).ok_or(MemErr)?; + + match info.kind { + RegionKind::Desc(idx) => { + let desc = unsafe { &*self.desc_ptr(idx) }; + let bytes = bytemuck::bytes_of(desc); + dst.copy_from_slice(&bytes[offset..offset + dst.len()]); + } + RegionKind::DrvEvt => { + let evt = unsafe { &*self.drv.get() }; + let bytes = bytemuck::bytes_of(evt); + dst.copy_from_slice(&bytes[offset..offset + dst.len()]); + } + RegionKind::DevEvt => { + let evt = unsafe { &*self.dev.get() }; + let bytes = bytemuck::bytes_of(evt); + dst.copy_from_slice(&bytes[offset..offset + dst.len()]); + } + RegionKind::Pool => { + self.pool.with(|buf| { + dst.copy_from_slice(&(unsafe { &*buf })[offset..offset + dst.len()]); + }); + } + } + Ok(()) + } + + fn write(&self, addr: u64, src: &[u8]) -> Result<(), Self::Error> { + let (info, offset) = self.region(addr).ok_or(MemErr)?; + + match info.kind { + RegionKind::Desc(idx) => { + let desc = unsafe { &mut *self.desc_ptr(idx) }; + let bytes = bytemuck::bytes_of_mut(desc); + bytes[offset..offset + src.len()].copy_from_slice(src); + } + RegionKind::DrvEvt => { + let evt = unsafe { &mut 
*self.drv.get() }; + let bytes = bytemuck::bytes_of_mut(evt); + bytes[offset..offset + src.len()].copy_from_slice(src); + } + RegionKind::DevEvt => { + let evt = unsafe { &mut *self.dev.get() }; + let bytes = bytemuck::bytes_of_mut(evt); + bytes[offset..offset + src.len()].copy_from_slice(src); + } + RegionKind::Pool => { + self.pool.with_mut(|buf| { + (unsafe { &mut *buf })[offset..offset + src.len()].copy_from_slice(src); + }); + } + } + Ok(()) + } + + fn load_acquire(&self, addr: u64) -> Result { + let (info, _offset) = self.region(addr).ok_or(MemErr)?; + + Ok(match info.kind { + RegionKind::Desc(idx) => self.desc_flags[idx].load(Ordering::Acquire), + RegionKind::DrvEvt => self.drv_flags.load(Ordering::Acquire), + RegionKind::DevEvt => self.dev_flags.load(Ordering::Acquire), + RegionKind::Pool => return Err(MemErr), + }) + } + + fn store_release(&self, addr: u64, val: u16) -> Result<(), Self::Error> { + let (info, _offset) = self.region(addr).ok_or(MemErr)?; + + match info.kind { + RegionKind::Desc(idx) => self.desc_flags[idx].store(val, Ordering::Release), + RegionKind::DrvEvt => self.drv_flags.store(val, Ordering::Release), + RegionKind::DevEvt => self.dev_flags.store(val, Ordering::Release), + RegionKind::Pool => return Err(MemErr), + } + Ok(()) + } + + unsafe fn as_slice(&self, addr: u64, len: usize) -> Result<&[u8], Self::Error> { + let (info, offset) = self.region(addr).ok_or(MemErr)?; + + match info.kind { + RegionKind::Pool => { + // Safety: pool memory is a contiguous Vec; caller ensures + // no concurrent writes for the lifetime of the returned slice. 
+ let buf = unsafe { &*self.pool.get() }; + Ok(&buf[offset..offset + len]) + } + _ => Err(MemErr), + } + } + + unsafe fn as_mut_slice(&self, addr: u64, len: usize) -> Result<&mut [u8], Self::Error> { + let (info, offset) = self.region(addr).ok_or(MemErr)?; + + match info.kind { + RegionKind::Pool => { + let buf = unsafe { &mut *self.pool.get() }; + Ok(&mut buf[offset..offset + len]) + } + _ => Err(MemErr), + } + } + } + + #[derive(Debug)] + pub struct Notify { + kicks: AtomicUsize, + } + + impl Notify { + pub fn new() -> Self { + Self { + kicks: AtomicUsize::new(0), + } + } + } + + impl Notifier for Arc { + fn notify(&self, _stats: QueueStats) { + self.kicks.fetch_add(1, Ordering::Relaxed); + } + } + + #[test] + fn virtq_ping_pong() { + loom::model(|| { + let ring_base = 0x10000; + let pool_base = 0x40000; + let pool_size = 0x10000; + + let mem = Arc::new(LoomMem::new(ring_base, 8, pool_base, pool_size)); + let pool = BufferPoolSync::<256, 4096>::new(pool_base, pool_size).unwrap(); + let notify = Arc::new(Notify::new()); + + let mut prod = VirtqProducer::new(mem.layout(), mem.clone(), notify.clone(), pool); + let mut cons = VirtqConsumer::new(mem.layout(), mem.clone(), notify.clone()); + + let t_prod = thread::spawn(move || { + let mut se = prod.chain().entry(4).completion(32).build().unwrap(); + se.write_all(b"ping").unwrap(); + let tok = prod.submit(se).unwrap(); + loop { + if let Some(r) = prod.poll().unwrap() { + assert_eq!(r.token, tok); + assert_eq!(&r.data[..], b"pong"); + break; + } + thread::yield_now(); + } + }); + + let t_cons = thread::spawn(move || { + let (entry, completion) = loop { + if let Some(r) = cons.poll(1024).unwrap() { + break r; + } + thread::yield_now(); + }; + assert_eq!(entry.data().as_ref(), b"ping"); + let SendCompletion::Writable(mut wc) = completion else { + panic!("expected writable completion"); + }; + wc.write_all(b"pong").unwrap(); + cons.complete(wc.into()).unwrap(); + }); + + t_prod.join().unwrap(); + t_cons.join().unwrap(); + 
}); + } +} diff --git a/src/hyperlight_common/src/virtq/pool.rs b/src/hyperlight_common/src/virtq/pool.rs new file mode 100644 index 000000000..0324c08fe --- /dev/null +++ b/src/hyperlight_common/src/virtq/pool.rs @@ -0,0 +1,1334 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +//! Simple bitmap-based allocator for virtio buffer management. +//! +//! This module provides two layers: +//! +//! - [`Slab`] - a fixed-size region allocator with a power-of-two slot size `N`, +//! backed by a flat bitmap (`FixedBitSet`). +//! - [`BufferPool`] - a two-tier pool that composes two slabs: one with small +//! slots (e.g. 256 bytes) for control messages / small descriptors, and one +//! with page-sized slots (e.g. 4 KiB) for data buffers. +//! +//! # Design and algorithm +//! +//! The core allocation strategy is a bitmap allocator that performs a linear +//! search over the bitmap, but implemented via `fixedbitset`'s SIMD iteration +//! over zero bits. This is conceptually simpler than tree-based allocators +//! (e.g. linked lists or bitmaps representing a tree as in +//! ), yet for "moderate" region sizes it can +//! be faster in practice: +//! +//! - `FixedBitSet::zeroes()` and related methods use word/SIMD operations to +//! skip over runs of set bits, so the linear search is over words rather than +//! individual bits. +//! - We scan for a contiguous run of free bits corresponding to the required +//! 
number of slots; no auxiliary tree structure is maintained. +//! +//! The tree-based approach (bitmap encoding a tree and doing a binary search +//! in O(log(n)) time) is a natural next step if larger regions or stricter worst +//! case bounds are required; switching to such a representation should be +//! relatively straightforward since all allocation paths go through a single +//! `find_slots` function. +//! +//! # Locality characteristics +//! +//! The allocator tends to preserve spatial locality: +//! +//! - It searches from low indices upward, returning the first run of free +//! slots large enough for the request. Slots are merged if necessary. +//! - Freed runs are cached in `last_free_run` and reused eagerly, which +//! introduces a mild LIFO behavior for recently freed blocks. +//! - As a result, consecutive allocations are likely to end up in nearby slots, +//! which keeps virtqueue descriptors, control buffers, and data buffers +//! clustered in memory and helps cache performance. +//! +//! # Two-tier buffer pool +//! +//! [`BufferPool`] divides the underlying region into two slabs with different +//! slot sizes: +//! +//! - The lower tier (`Slab`, default `L = 256`) is intended for +//! *smaller allocations* - control messages, descriptor metadata, and other +//! small structures. Small allocations first try this tier. +//! - The upper tier (`Slab`, default `U = 4096`) uses page sized slots +//! and is intended for larger data buffers. +//! +//! The split of the region is currently fixed at a constant fraction +//! (`LOWER_FRACTION`) for the lower slab and the remainder for the upper slab. +//! +//! Allocation policy: +//! +//! - Requests `<= L` bytes are first attempted in the lower slab; on +//! `OutOfMemory` they fall back to the upper slab. +//! - Larger requests go directly to the upper slab. +//! - [`BufferPool::resize`] will try to grow or shrink in place within the +//! owning slab (`Slab::resize`) but will never move allocations between +//! 
slabs. + +#[cfg(all(test, loom))] +use alloc::sync::Arc; +use core::cmp::Ordering; + +use atomic_refcell::AtomicRefCell; +use fixedbitset::FixedBitSet; +use thiserror::Error; + +use super::access::MemOps; + +#[derive(Debug, Error, Copy, Clone)] +pub enum AllocError { + #[error("Invalid region addr {0}")] + InvalidAlign(u64), + #[error("Invalid free addr {0} and size {1}")] + InvalidFree(u64, usize), + #[error("Invalid argument")] + InvalidArg, + #[error("Empty region")] + EmptyRegion, + #[error("Out of memory")] + OutOfMemory, + #[error("Overflow")] + Overflow, +} + +/// Allocation result +#[derive(Debug, Clone, Copy)] +pub struct Allocation { + /// Starting address of the allocation + pub addr: u64, + /// Length of the allocation in bytes rounded up to slab size + pub len: usize, +} + +/// Trait for buffer providers. +pub trait BufferProvider { + /// Allocate at least `len` bytes. + fn alloc(&self, len: usize) -> Result; + + /// Free a previously allocated block. + fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError>; + + /// Resize by trying in-place grow; otherwise reserve a new block and free old. + fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result; +} + +/// The owner of a mapped buffer, ensuring its lifetime. +/// +/// Holds a pool allocation and provides direct access to the underlying +/// shared memory via [`MemOps::as_slice`]. Implements `AsRef<[u8]>` so it +/// can be used with [`Bytes::from_owner`](bytes::Bytes::from_owner) for +/// zero-copy `Bytes` backed by shared memory. +/// +/// When dropped, the allocation is returned to the pool. 
+#[derive(Debug, Clone)] +pub struct BufferOwner { + pub(crate) pool: P, + pub(crate) mem: M, + pub(crate) alloc: Allocation, + pub(crate) written: usize, +} + +impl Drop for BufferOwner { + fn drop(&mut self) { + let _ = self.pool.dealloc(self.alloc); + } +} + +impl AsRef<[u8]> for BufferOwner { + fn as_ref(&self) -> &[u8] { + let len = self.written.min(self.alloc.len); + // Safety: BufferOwner keeps both the pool allocation and the M + // alive, so the memory region is valid. Protocol-level descriptor + // ownership transfer guarantees no concurrent writes. + match unsafe { self.mem.as_slice(self.alloc.addr, len) } { + Ok(slice) => slice, + Err(_) => &[], + } + } +} + +/// A guard that runs a cleanup function when dropped, unless dismissed. +pub struct AllocGuard(Option<(Allocation, F)>); + +impl AllocGuard { + pub fn new(alloc: Allocation, cleanup: F) -> Self { + Self(Some((alloc, cleanup))) + } + + pub fn release(mut self) -> Allocation { + self.0.take().unwrap().0 + } +} + +impl core::ops::Deref for AllocGuard { + type Target = Allocation; + + fn deref(&self) -> &Allocation { + &self.0.as_ref().unwrap().0 + } +} + +impl Drop for AllocGuard { + fn drop(&mut self) { + if let Some((alloc, cleanup)) = self.0.take() { + cleanup(alloc) + } + } +} + +#[derive(Debug, Clone)] +pub struct Slab { + /// Base address of the slab + base_addr: u64, + /// Flat bitmap to track allocated/free slots + used_slots: FixedBitSet, + /// Last free allocation cache + last_free_run: Option, +} + +impl Slab { + /// Create a new slab allocator over a fixed region. + /// Region is rounded down to a multiple of N. 
+ pub fn new(base_addr: u64, region_len: usize) -> Result { + let usable = region_len - (region_len % N); + let num_slots = usable / N; + let used_slots = FixedBitSet::with_capacity(num_slots); + + if base_addr % (N as u64) != 0 { + return Err(AllocError::InvalidAlign(base_addr)); + } + + if num_slots == 0 { + return Err(AllocError::EmptyRegion); + } + + Ok(Self { + base_addr, + used_slots, + last_free_run: None, + }) + } + + /// Get the address of a slot by its index + #[inline] + fn addr_of(&self, slot_idx: usize) -> Option { + self.base_addr + .checked_add((slot_idx as u64).checked_mul(N as u64)?) + } + + /// Get the slot index for a given address + #[inline] + fn slot_of(&self, addr: u64) -> usize { + let off = (addr - self.base_addr) as usize; + off / N + } + + /// Invalidate last_free_run cache if it overlaps with the given allocation. + fn maybe_invalidate_last_run(&mut self, alloc: Allocation) { + if let Some(run) = &self.last_free_run { + let new_end = alloc.addr + alloc.len as u64; + let run_end = run.addr + run.len as u64; + + if alloc.addr < run_end && run.addr < new_end { + self.last_free_run = None; + } + } + } + + /// Find a run of slots to satisfy at least `len` bytes starting at `start`. + pub fn find_slots(&mut self, slots_num: usize) -> Option { + debug_assert!(slots_num > 0); + + // Check last free run optimization + if let Some(alloc) = self.last_free_run + && alloc.len >= slots_num * N + { + let pos = self.slot_of(alloc.addr); + let _ = self.last_free_run.take(); + return Some(pos); + } + + // Fallback to full search + self.used_slots.zeroes().find(|&next_free| { + self.used_slots + .count_zeroes(next_free..next_free + slots_num) + == slots_num + }) + } + + /// Allocate at least `len` bytes by merging consecutive slots. 
+ pub fn alloc(&mut self, len: usize) -> Result { + if len == 0 { + return Err(AllocError::InvalidArg); + } + + let total = self.used_slots.len(); + let need_slots = len.div_ceil(N); + if need_slots > total { + return Err(AllocError::OutOfMemory); + } + + let idx = self.find_slots(need_slots).ok_or(AllocError::OutOfMemory)?; + self.used_slots.insert_range(idx..idx + need_slots); + let addr = self.addr_of(idx).ok_or(AllocError::Overflow)?; + + let alloc = Allocation { + addr, + len: need_slots * N, + }; + + self.maybe_invalidate_last_run(alloc); + Ok(alloc) + } + + /// Free a previously allocated slot or multiple slots. + /// + /// `len` must be a multiple of N and `addr` must be N-aligned to base. + pub fn dealloc(&mut self, alloc: Allocation) -> Result<(), AllocError> { + let Allocation { addr, len } = alloc; + if len == 0 || len % N != 0 || addr < self.base_addr { + return Err(AllocError::InvalidFree(addr, len)); + } + let alloc_slots = len / N; + let off = (addr - self.base_addr) as usize; + if off % N != 0 { + return Err(AllocError::InvalidFree(addr, len)); + } + let start = off / N; + let num_slots = self.used_slots.len(); + if start + alloc_slots > num_slots { + return Err(AllocError::InvalidFree(addr, len)); + } + + // Ensure all bits are set (avoid double-free) + if !self + .used_slots + .contains_all_in_range(start..start + alloc_slots) + { + return Err(AllocError::InvalidFree(addr, len)); + } + + // Mark as free + self.used_slots.remove_range(start..start + alloc_slots); + self.last_free_run = Some(alloc); + + Ok(()) + } + + /// Try to grow a block in place by reserving adjacent free slots to the right. + /// + /// Returns Ok(None) if in-place growth is not possible. Returns Err on invalid input. 
+ pub fn try_grow_inplace( + &mut self, + old_alloc: Allocation, + new_len: usize, + ) -> Result, AllocError> { + let Allocation { + addr: old_addr, + len: old_len, + } = old_alloc; + + if new_len <= old_len || old_len == 0 || old_len % N != 0 { + return Err(AllocError::InvalidFree(old_addr, old_len)); + } + + let old_slots = old_len / N; + let need_slots = new_len.div_ceil(N); + let off = (old_addr - self.base_addr) as usize; + if off % N != 0 { + return Err(AllocError::InvalidFree(old_addr, old_len)); + } + + let start = off / N; + if start + need_slots > self.used_slots.len() { + return Ok(None); + } + // Existing range must be allocated + if !self + .used_slots + .contains_all_in_range(start..start + old_slots) + { + return Err(AllocError::InvalidFree(old_addr, old_len)); + } + + // Extension must be free + if self + .used_slots + .count_ones(start + old_slots..start + need_slots) + > 0 + { + return Ok(None); + } + + // Mark extension as allocated + self.used_slots + .insert_range(start + old_slots..start + need_slots); + + let alloc = Allocation { + addr: old_addr, + len: need_slots * N, + }; + + self.maybe_invalidate_last_run(alloc); + Ok(Some(alloc)) + } + + /// Shrink a block in place by freeing excess slots to the right. 
+ pub fn shrink_inplace( + &mut self, + old_alloc: Allocation, + new_len: usize, + ) -> Result { + let Allocation { + addr: old_addr, + len: old_len, + } = old_alloc; + + if new_len >= old_len || old_len == 0 || old_len % N != 0 { + return Err(AllocError::InvalidFree(old_addr, old_len)); + } + + let old_slots = old_len / N; + let need_slots = new_len.div_ceil(N); + let off = (old_addr - self.base_addr) as usize; + if off % N != 0 { + return Err(AllocError::InvalidFree(old_addr, old_len)); + } + + let start = off / N; + if start + old_slots > self.used_slots.len() { + return Err(AllocError::InvalidFree(old_addr, old_len)); + } + // Existing range must be allocated + if !self + .used_slots + .contains_all_in_range(start..start + old_slots) + { + return Err(AllocError::InvalidFree(old_addr, old_len)); + } + + // Free the excess slots + self.used_slots + .remove_range(start + need_slots..start + old_slots); + + Ok(Allocation { + addr: old_addr, + len: need_slots * N, + }) + } + + /// Reallocate by trying in-place grow; otherwise reserve a new run of slots and free old. + /// Caller should copy the payload; this function only manages reservations. + pub fn resize( + &mut self, + old_alloc: Allocation, + new_len: usize, + ) -> Result { + if new_len == 0 { + return Err(AllocError::InvalidArg); + } + + match new_len.cmp(&old_alloc.len) { + Ordering::Greater => { + match self.try_grow_inplace(old_alloc, new_len) { + // in-place growth succeeded + Ok(Some(new_alloc)) => Ok(new_alloc), + // in-place growth failed; allocate new and free old + Ok(None) => { + let new_alloc = self.alloc(new_len)?; + self.dealloc(old_alloc)?; + Ok(new_alloc) + } + // other errors are propagated + Err(err) => Err(err), + } + } + Ordering::Less => self.shrink_inplace(old_alloc, new_len), + Ordering::Equal => Ok(old_alloc), + } + } + + /// Usable size rounded up to slot multiple. 
+ pub fn usable_size(&self, _addr: usize, len: usize) -> usize { + if len == 0 { 0 } else { len.div_ceil(N) * N } + } + + /// Number of free bytes in the slab. + pub fn free_bytes(&self) -> usize { + (self.used_slots.len() - self.used_slots.count_ones(..)) * N + } + + /// Total capacity of the slab in bytes. + pub fn capacity(&self) -> usize { + self.used_slots.len() * N + } + + /// Get the address range covered by this slab. + pub fn range(&self) -> core::ops::Range { + let end = self.base_addr + self.capacity() as u64; + self.base_addr..end + } + + /// Check if an address is within this slab's range. + pub fn contains(&self, addr: u64) -> bool { + self.range().contains(&addr) + } + + /// Get the slot size N. + pub const fn slot_size() -> usize { + N + } +} + +#[inline] +fn align_up(val: usize, align: usize) -> usize { + assert!(align > 0); + if val == 0 { + return 0; + } + val.div_ceil(align) * align +} + +#[derive(Debug)] +struct Inner { + lower: Slab, + upper: Slab, +} + +/// Two tier buffer pool with small and large slabs. +#[derive(Debug)] +pub struct BufferPool { + inner: AtomicRefCell>, +} + +impl BufferPool { + /// Create a new buffer pool over a fixed region. + pub fn new(base_addr: u64, region_len: usize) -> Result { + let inner = Inner::::new(base_addr, region_len)?; + Ok(Self { + inner: inner.into(), + }) + } +} + +#[cfg(all(test, loom))] +#[derive(Debug, Clone)] +pub struct BufferPoolSync { + inner: std::sync::Arc>>, +} + +#[cfg(all(test, loom))] +impl BufferPoolSync { + /// Create a new buffer pool over a fixed region. + pub fn new(base_addr: u64, region_len: usize) -> Result { + let inner = Inner::::new(base_addr, region_len)?; + Ok(Self { + inner: Arc::new(std::sync::Mutex::new(inner)), + }) + } +} + +impl Inner { + /// Create a new buffer pool over a fixed region. 
+ pub fn new(base_addr: u64, region_len: usize) -> Result { + const LOWER_FRACTION: usize = 8; + + let lower_region = region_len / LOWER_FRACTION; + let upper_region = region_len - lower_region; + + let mut aligned = base_addr; + aligned = align_up(aligned as usize, L) as u64; + let lower = Slab::::new(aligned, lower_region)?; + + // advance and align upper base to N + aligned = aligned + .checked_add(lower.capacity() as u64) + .ok_or(AllocError::Overflow)?; + + aligned = align_up(aligned as usize, U) as u64; + let upper = Slab::::new(aligned, upper_region)?; + + Ok(Self { lower, upper }) + } + + /// Allocate at least `len` bytes. + pub fn alloc(&mut self, len: usize) -> Result { + if len <= L { + match self.lower.alloc(len) { + Ok(alloc) => return Ok(alloc), + Err(AllocError::OutOfMemory) => {} + Err(e) => return Err(e), + } + } + + // fallback to upper slab + self.upper.alloc(len) + } + + /// Free a previously allocated block. + pub fn dealloc(&mut self, alloc: Allocation) -> Result<(), AllocError> { + if self.lower.contains(alloc.addr) { + self.lower.dealloc(alloc) + } else { + self.upper.dealloc(alloc) + } + } + + /// Reallocate by trying in-place grow; otherwise reserve a new block and free old. + pub fn resize( + &mut self, + old_alloc: Allocation, + new_len: usize, + ) -> Result { + if self.lower.contains(old_alloc.addr) { + maybe_move(&mut self.lower, &mut self.upper, old_alloc, new_len) + } else { + maybe_move(&mut self.upper, &mut self.lower, old_alloc, new_len) + } + } +} + +/// Try to realloc using slab that owns the old allocation; if that fails, +/// try to allocate in the other slab. The function prefers to move allocations +/// between slabs only when necessary based on size thresholds. 
+#[inline] +fn maybe_move( + slab: &mut Slab, + other: &mut Slab, + old_alloc: Allocation, + new_len: usize, +) -> Result { + let needs_move = if A < B { new_len > A } else { new_len <= B }; + if !needs_move { + return slab.resize(old_alloc, new_len); + } + + let new_alloc = other.alloc(new_len)?; + + slab.dealloc(old_alloc)?; + Ok(new_alloc) +} + +impl BufferProvider for BufferPool { + fn alloc(&self, len: usize) -> Result { + self.inner.borrow_mut().alloc(len) + } + + fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError> { + self.inner.borrow_mut().dealloc(alloc) + } + + fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result { + self.inner.borrow_mut().resize(old_alloc, new_len) + } +} + +#[cfg(all(test, loom))] +impl BufferProvider for BufferPoolSync { + fn alloc(&self, len: usize) -> Result { + self.inner.lock().expect("poisoned mutex").alloc(len) + } + + fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError> { + self.inner.lock().expect("poisoned mutex").dealloc(alloc) + } + + fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result { + self.inner + .lock() + .expect("poisoned mutex") + .resize(old_alloc, new_len) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_slab(size: usize) -> Slab { + let base = align_up(0x10000, N) as u64; + Slab::::new(base, size).unwrap() + } + + fn make_pool(size: usize) -> BufferPool { + let base = align_up(0x10000, L.max(U)) as u64; + BufferPool::::new(base, size).unwrap() + } + + #[test] + fn test_slab_new_success() { + let slab = Slab::<256>::new(0x10000, 1024).unwrap(); + assert_eq!(slab.capacity(), 1024); + assert_eq!(slab.free_bytes(), 1024); + } + + #[test] + fn test_slab_new_misaligned() { + let result = Slab::<256>::new(0x10001, 1024); + assert!(matches!(result, Err(AllocError::InvalidAlign(0x10001)))); + } + + #[test] + fn test_slab_new_empty_region() { + let result = Slab::<256>::new(0x10000, 100); + assert!(matches!(result, Err(AllocError::EmptyRegion))); + } + + 
#[test] + fn test_slab_alloc_single_slot() { + let mut slab = make_slab::<256>(1024); + let alloc = slab.alloc(128).unwrap(); + assert_eq!(alloc.len, 256); + assert_eq!(slab.free_bytes(), 1024 - 256); + } + + #[test] + fn test_slab_alloc_multiple_slots() { + let mut slab = make_slab::<256>(1024); + let alloc = slab.alloc(600).unwrap(); + assert_eq!(alloc.len, 768); // 3 slots × 256 bytes + assert_eq!(slab.free_bytes(), 1024 - 768); + } + + #[test] + fn test_slab_alloc_zero_length() { + let mut slab = make_slab::<256>(1024); + let result = slab.alloc(0); + assert!(matches!(result, Err(AllocError::InvalidArg))); + } + + #[test] + fn test_slab_alloc_too_large() { + let mut slab = make_slab::<256>(1024); + let result = slab.alloc(2048); + assert!(matches!(result, Err(AllocError::OutOfMemory))); + } + + #[test] + fn test_slab_alloc_until_full() { + let mut slab = make_slab::<256>(1024); + + // Allocate all 4 slots + let _a1 = slab.alloc(256).unwrap(); + let a2 = slab.alloc(256).unwrap(); + let _a3 = slab.alloc(256).unwrap(); + let _a4 = slab.alloc(256).unwrap(); + + assert_eq!(slab.free_bytes(), 0); + + // Next allocation should fail + let result = slab.alloc(256); + assert!(matches!(result, Err(AllocError::OutOfMemory))); + + // Free one and retry + slab.dealloc(a2).unwrap(); + let a5 = slab.alloc(256).unwrap(); + assert_eq!(a5.addr, a2.addr); // Should reuse same slot + } + + #[test] + fn test_slab_free_success() { + let mut slab = make_slab::<256>(1024); + let alloc = slab.alloc(256).unwrap(); + assert_eq!(slab.free_bytes(), 768); + + slab.dealloc(alloc).unwrap(); + assert_eq!(slab.free_bytes(), 1024); + } + + #[test] + fn test_slab_free_invalid_length() { + let mut slab = make_slab::<256>(1024); + let mut alloc = slab.alloc(256).unwrap(); + alloc.len = 100; // Invalid: not multiple of N + + let result = slab.dealloc(alloc); + assert!(matches!(result, Err(AllocError::InvalidFree(_, 100)))); + } + + #[test] + fn test_slab_free_double_free() { + let mut slab = 
make_slab::<256>(1024); + let alloc = slab.alloc(256).unwrap(); + + slab.dealloc(alloc).unwrap(); + let result = slab.dealloc(alloc); + assert!(matches!(result, Err(AllocError::InvalidFree(_, _)))); + } + + #[test] + fn test_slab_free_invalid_address() { + let mut slab = make_slab::<256>(1024); + let alloc = Allocation { + addr: 0x99999, + len: 256, + }; + + let result = slab.dealloc(alloc); + assert!(matches!(result, Err(AllocError::InvalidFree(0x99999, _)))); + } + + #[test] + fn test_slab_cursor_optimization_lifo() { + let mut slab = make_slab::<256>(1024); + + let a1 = slab.alloc(256).unwrap(); + let addr1 = a1.addr; + + slab.dealloc(a1).unwrap(); + + // Next allocation should reuse same slot (cursor moved back) + let a2 = slab.alloc(256).unwrap(); + assert_eq!(a2.addr, addr1); + } + + #[test] + fn test_slab_cursor_rewind_for_single_slot() { + let mut slab = make_slab::<256>(1024); + + let _a1 = slab.alloc(256).unwrap(); + let a2 = slab.alloc(256).unwrap(); + let _a3 = slab.alloc(256).unwrap(); + + // Free single-slot at position 1, before cursor at 3 + slab.dealloc(a2).unwrap(); + + // Cursor should rewind to 1 + let a4 = slab.alloc(256).unwrap(); + // Should reuse slot 1 + assert_eq!(a4.addr, a2.addr); + } + + #[test] + fn test_slab_grow_inplace_success() { + let mut slab = make_slab::<256>(1024); + let alloc = slab.alloc(256).unwrap(); + + // Grow from 256 to 512 (adjacent slot is free) + let grown = slab.try_grow_inplace(alloc, 512).unwrap(); + assert!(grown.is_some()); + assert_eq!(grown.unwrap().len, 512); + assert_eq!(grown.unwrap().addr, alloc.addr); + } + + #[test] + fn test_slab_grow_inplace_blocked() { + let mut slab = make_slab::<256>(1024); + let a1 = slab.alloc(256).unwrap(); + let _a2 = slab.alloc(256).unwrap(); // Blocks growth + + // Can't grow because next slot is allocated + let result = slab.try_grow_inplace(a1, 512).unwrap(); + assert!(result.is_none()); + } + + #[test] + fn test_slab_shrink_inplace() { + let mut slab = 
make_slab::<256>(1024); + let alloc = slab.alloc(512).unwrap(); // 2 slots + + let shrunk = slab.shrink_inplace(alloc, 256).unwrap(); + assert_eq!(shrunk.len, 256); + assert_eq!(shrunk.addr, alloc.addr); + assert_eq!(slab.free_bytes(), 1024 - 256); + } + + #[test] + fn test_slab_realloc_grow_inplace() { + let mut slab = make_slab::<256>(1024); + let alloc = slab.alloc(256).unwrap(); + + let new_alloc = slab.resize(alloc, 512).unwrap(); + assert_eq!(new_alloc.addr, alloc.addr); // Same address (in-place) + assert_eq!(new_alloc.len, 512); + } + + #[test] + fn test_slab_realloc_grow_relocate() { + let mut slab = make_slab::<256>(1024); + let a1 = slab.alloc(256).unwrap(); + let _a2 = slab.alloc(256).unwrap(); // Blocks growth + + let new_alloc = slab.resize(a1, 512).unwrap(); + assert_ne!(new_alloc.addr, a1.addr); // Different address (relocated) + assert_eq!(new_alloc.len, 512); + } + + #[test] + fn test_slab_realloc_shrink() { + let mut slab = make_slab::<256>(1024); + let alloc = slab.alloc(512).unwrap(); + + let new_alloc = slab.resize(alloc, 256).unwrap(); + assert_eq!(new_alloc.addr, alloc.addr); + assert_eq!(new_alloc.len, 256); + } + + #[test] + fn test_slab_realloc_same_size() { + let mut slab = make_slab::<256>(1024); + let alloc = slab.alloc(256).unwrap(); + + let new_alloc = slab.resize(alloc, 256).unwrap(); + assert_eq!(new_alloc.addr, alloc.addr); + assert_eq!(new_alloc.len, alloc.len); + } + + #[test] + fn test_slab_fragmentation_handling() { + let mut slab = make_slab::<256>(1024); + + // Create fragmentation: [U][F][U][F] + let a1 = slab.alloc(256).unwrap(); + let a2 = slab.alloc(256).unwrap(); + let _a3 = slab.alloc(256).unwrap(); + let _a4 = slab.alloc(256).unwrap(); + + slab.dealloc(a2).unwrap(); + slab.dealloc(a1).unwrap(); + + // Should still be able to allocate 2-slot buffer + let big = slab.alloc(512).unwrap(); + assert_eq!(big.len, 512); + } + + #[test] + fn test_pool_new_success() { + let pool = BufferPool::<256, 4096>::new(0x10000, 1024 * 
1024).unwrap(); + assert!(pool.inner.borrow().lower.capacity() > 0); + assert!(pool.inner.borrow().upper.capacity() > 0); + } + + #[test] + fn test_pool_alloc_small_to_lower() { + let pool = make_pool::<256, 4096>(1024 * 1024); + let alloc = pool.alloc(128).unwrap(); + + // Should come from lower slab + assert!(pool.inner.borrow().lower.contains(alloc.addr)); + assert_eq!(alloc.len, 256); + } + + #[test] + fn test_pool_alloc_large_to_upper() { + let pool = make_pool::<256, 4096>(1024 * 1024); + let alloc = pool.alloc(1500).unwrap(); + + // Should come from upper slab + assert!(pool.inner.borrow().upper.contains(alloc.addr)); + assert_eq!(alloc.len, 4096); + } + + #[test] + fn test_pool_alloc_fallback_to_upper() { + let pool = make_pool::<256, 4096>(1024 * 1024); + + // Fill lower slab completely + let mut allocations = Vec::new(); + while pool.inner.borrow().lower.free_bytes() > 0 { + allocations.push(pool.inner.borrow_mut().lower.alloc(256).unwrap()); + } + + // Small allocation should fallback to upper slab + let alloc = pool.alloc(128).unwrap(); + assert!(pool.inner.borrow().upper.contains(alloc.addr)); + } + + #[test] + fn test_pool_free_from_lower() { + let pool = make_pool::<256, 4096>(1024 * 1024); + let alloc = pool.alloc(128).unwrap(); + + let free_before = pool.inner.borrow().lower.free_bytes(); + pool.dealloc(alloc).unwrap(); + assert_eq!( + pool.inner.borrow().lower.free_bytes(), + free_before + alloc.len + ); + } + + #[test] + fn test_pool_free_from_upper() { + let pool = make_pool::<256, 4096>(1024 * 1024); + let alloc = pool.alloc(1500).unwrap(); + + let free_before = pool.inner.borrow().upper.free_bytes(); + pool.dealloc(alloc).unwrap(); + assert_eq!( + pool.inner.borrow().upper.free_bytes(), + free_before + alloc.len + ); + } + + #[test] + fn test_pool_realloc_within_same_tier() { + let pool = make_pool::<256, 4096>(1024 * 1024); + let alloc = pool.alloc(128).unwrap(); + + // Realloc within lower tier (128 -> 200, both fit in 256 slots) + let 
new_alloc = pool.resize(alloc, 200).unwrap(); + assert!(pool.inner.borrow().lower.contains(new_alloc.addr)); + } + + #[test] + fn test_pool_realloc_move_to_different_tier() { + let pool = make_pool::<256, 4096>(1024 * 1024); + let alloc = pool.alloc(128).unwrap(); + assert!(pool.inner.borrow().lower.contains(alloc.addr)); + + // Realloc to size that needs upper tier + let new_alloc = pool.resize(alloc, 1500).unwrap(); + assert!(pool.inner.borrow().upper.contains(new_alloc.addr)); + } + + #[test] + fn test_pool_realloc_shrink_stays_in_tier() { + let pool = make_pool::<256, 4096>(1024 * 1024); + let alloc = pool.alloc(1500).unwrap(); + assert!(pool.inner.borrow().upper.contains(alloc.addr)); + + // Shrink but stay in upper tier + let new_alloc = pool.resize(alloc, 1000).unwrap(); + assert!(pool.inner.borrow().upper.contains(new_alloc.addr)); + } + + #[test] + fn test_pool_stress_many_allocations() { + let pool = make_pool::<256, 4096>(4 * 1024 * 1024); + let mut allocations = Vec::new(); + + // Allocate many buffers + for i in 0..100 { + let size = if i % 2 == 0 { 128 } else { 1500 }; + allocations.push(pool.alloc(size).unwrap()); + } + + // Free half of them + for i in (0..100).step_by(2) { + pool.dealloc(allocations[i]).unwrap(); + } + + // Should be able to allocate again + for i in 0..50 { + let size = if i % 2 == 0 { 128 } else { 1500 }; + let _alloc = pool.alloc(size).unwrap(); + } + } + + #[test] + fn test_pool_mixed_workload() { + let pool = make_pool::<256, 4096>(2 * 1024 * 1024); + + // Simulate virtio-net workload + let desc_buf = pool.alloc(64).unwrap(); // Control message + let rx_buf1 = pool.alloc(1500).unwrap(); // MTU packet + let rx_buf2 = pool.alloc(1500).unwrap(); // MTU packet + let tx_buf = pool.alloc(4096).unwrap(); // Large buffer + + // Free and reallocate + pool.dealloc(rx_buf1).unwrap(); + let rx_buf3 = pool.alloc(1500).unwrap(); + + // Should reuse freed buffer (LIFO) + assert_eq!(rx_buf3.addr, rx_buf1.addr); + + 
pool.dealloc(desc_buf).unwrap(); + pool.dealloc(rx_buf2).unwrap(); + pool.dealloc(rx_buf3).unwrap(); + pool.dealloc(tx_buf).unwrap(); + } + + #[test] + fn test_pool_zero_allocation_error() { + let pool = make_pool::<256, 4096>(1024 * 1024); + let result = pool.alloc(0); + assert!(matches!(result, Err(AllocError::InvalidArg))); + } + + #[test] + fn test_pool_too_large_allocation() { + let pool = make_pool::<256, 4096>(1024 * 1024); + let result = pool.alloc(2 * 1024 * 1024); // Larger than pool + assert!(matches!(result, Err(AllocError::OutOfMemory))); + } + + #[test] + fn test_align_up_helper() { + assert_eq!(align_up(0, 256), 0); + assert_eq!(align_up(1, 256), 256); + assert_eq!(align_up(256, 256), 256); + assert_eq!(align_up(257, 256), 512); + assert_eq!(align_up(511, 256), 512); + assert_eq!(align_up(512, 256), 512); + } + + #[test] + fn test_slab_usable_size() { + let slab = make_slab::<256>(1024); + assert_eq!(slab.usable_size(0, 0), 0); + assert_eq!(slab.usable_size(0, 1), 256); + assert_eq!(slab.usable_size(0, 256), 256); + assert_eq!(slab.usable_size(0, 257), 512); + } + + #[test] + fn test_slab_contains() { + let slab = make_slab::<256>(1024); + let range = slab.range(); + + assert!(slab.contains(range.start)); + assert!(!slab.contains(range.end)); // Exclusive end + assert!(!slab.contains(0)); + } + + // Edge case: allocation exactly at boundary + #[test] + fn test_pool_boundary_allocation() { + let pool = make_pool::<256, 4096>(1024 * 1024); + + // Allocate exactly at boundary + let alloc = pool.alloc(256).unwrap(); + assert!(pool.inner.borrow().lower.contains(alloc.addr)); + + // Allocate just over boundary + let alloc2 = pool.alloc(257).unwrap(); + assert!(pool.inner.borrow().upper.contains(alloc2.addr)); + } + + // Test overflow protection + #[test] + fn test_addr_of_overflow_protection() { + let slab = make_slab::<4096>(8192); + + // This should not panic due to overflow checks + let addr = slab.addr_of(usize::MAX); + assert!(addr.is_none()); + } + + 
#[test] + fn test_no_overlapping_allocations() { + let mut slab = make_slab::<4096>(32768); // 8 slots + + // Allocate slot 0-1 + let a1 = slab.alloc(8000).unwrap(); + assert_eq!(a1.len, 8192); + + // Shrink to slot 0 only + let a2 = slab.shrink_inplace(a1, 4000).unwrap(); + assert_eq!(a2.len, 4096); + + // Allocate at slot 1-2 + let a3 = slab.alloc(8000).unwrap(); + assert_eq!(a3.len, 8192); + let slot1_addr = a2.addr + 4096; + assert_eq!(a3.addr, slot1_addr); + + // Free slot 0 + slab.dealloc(a2).unwrap(); + + // Try to allocate 2 slots - should NOT get slot 0-1 because slot 1 is occupied! + let a4 = slab.alloc(8000).unwrap(); + assert_ne!(a4.addr, a2.addr); // Should be at a different location + + slab.dealloc(a3).unwrap(); + slab.dealloc(a4).unwrap(); + } +} + +#[cfg(test)] +mod fuzz { + use quickcheck::{Arbitrary, Gen, QuickCheck}; + + use super::*; + + const MAX_OPS: usize = 10; + const MAX_ALLOC_SIZE: usize = 8192; + + #[derive(Clone, Debug)] + enum Op { + Alloc(usize), + Dealloc(usize), + Resize(usize, usize), + } + + impl Arbitrary for Op { + fn arbitrary(g: &mut Gen) -> Self { + match u8::arbitrary(g) % 3 { + 0 => Op::Alloc(usize::arbitrary(g) % MAX_ALLOC_SIZE + 1), + 1 => Op::Dealloc(usize::arbitrary(g)), + 2 => Op::Resize( + usize::arbitrary(g), + usize::arbitrary(g) % MAX_ALLOC_SIZE + 1, + ), + _ => unreachable!(), + } + } + } + + #[derive(Clone, Debug)] + struct Scenario { + pool_size: usize, + ops: Vec, + } + + impl Arbitrary for Scenario { + fn arbitrary(g: &mut Gen) -> Self { + let pool_size = (usize::arbitrary(g) % (4 * 1024 * 1024)) + (1024 * 1024); + let num_ops = usize::arbitrary(g) % MAX_OPS + 1; + let ops = (0..num_ops).map(|_| Op::arbitrary(g)).collect(); + + Scenario { pool_size, ops } + } + } + + fn run_scenario(s: Scenario) -> bool { + let base = align_up(0x10000, 4096) as u64; + let pool = match BufferPool::<256, 4096>::new(base, s.pool_size) { + Ok(p) => p, + Err(_) => return true, + }; + + let mut allocations: Vec = Vec::new(); + + for 
op in &s.ops { + match op { + Op::Alloc(size) => match pool.alloc(*size) { + Ok(alloc) => { + assert!(alloc.len >= *size); + allocations.push(alloc); + } + Err(AllocError::OutOfMemory) => {} + Err(_) => { + return false; + } + }, + Op::Dealloc(idx) => { + if allocations.is_empty() { + continue; + } + + let idx = idx % allocations.len(); + let alloc = allocations.swap_remove(idx); + + match pool.dealloc(alloc) { + Ok(_) => {} + Err(_) => return false, + } + } + Op::Resize(idx, new_size) => { + if allocations.is_empty() { + continue; + } + + let idx = idx % allocations.len(); + let old_alloc = allocations[idx]; + + match pool.resize(old_alloc, *new_size) { + Ok(new_alloc) => { + assert!(new_alloc.len >= *new_size); + allocations[idx] = new_alloc; + } + Err(AllocError::OutOfMemory) => {} + Err(_) => return false, + } + } + } + + if check_pool_invariants(&pool, &allocations).is_err() { + return false; + } + } + + // Cleanup + for alloc in &allocations { + if pool.dealloc(*alloc).is_err() { + return false; + } + } + + check_pool_invariants(&pool, &allocations).is_ok() + } + + fn check_slab_invariants(slab: &Slab) -> Result<(), &'static str> { + // Check that number of used + free slots equals total + let used = slab.used_slots.count_ones(..); + let free = slab.used_slots.count_zeroes(..); + if used + free != slab.used_slots.len() { + return Err("used + free != total slots"); + } + + let expected_free = free * N; + if slab.free_bytes() != expected_free { + return Err("free_bytes doesn't match bitmap"); + } + + if let Some(alloc) = slab.last_free_run { + if alloc.len == 0 || alloc.len % N != 0 { + return Err("last_free_run has invalid length"); + } + if !slab.contains(alloc.addr) { + return Err("last_free_run addr outside range"); + } + } + + Ok(()) + } + + fn check_pool_invariants( + pool: &BufferPool, + allocations: &[Allocation], + ) -> Result<(), &'static str> { + check_slab_invariants(&pool.inner.borrow().lower)?; + check_slab_invariants(&pool.inner.borrow().upper)?; 
+ + if pool.inner.borrow().lower.range().end > pool.inner.borrow().upper.range().start { + return Err("lower and upper ranges overlap"); + } + + let mut seen = std::collections::HashSet::new(); + + for alloc in allocations { + if !pool.inner.borrow().lower.contains(alloc.addr) + && !pool.inner.borrow().upper.contains(alloc.addr) + { + return Err("allocation address outside pool ranges"); + } + + if alloc.len % L != 0 && alloc.len % U != 0 { + return Err("allocation length not aligned to any tier"); + } + + if !seen.insert(alloc.addr) { + return Err("duplicate allocation address in tracking"); + } + } + + Ok(()) + } + + #[test] + fn prop_allocator_invariants() { + #[cfg(miri)] + let tests = 10; + #[cfg(not(miri))] + let tests = 1000; + + QuickCheck::new() + .tests(tests) + .quickcheck(run_scenario as fn(Scenario) -> bool); + } +} diff --git a/src/hyperlight_common/src/virtq/producer.rs b/src/hyperlight_common/src/virtq/producer.rs new file mode 100644 index 000000000..95db0b7ba --- /dev/null +++ b/src/hyperlight_common/src/virtq/producer.rs @@ -0,0 +1,790 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +use alloc::vec; +use alloc::vec::Vec; + +use bytes::Bytes; + +use super::*; + +/// A completion received by the driver (producer) side. +/// +/// Contains the completion data and metadata about the completed entry. 
+/// The `data` field is a zero-copy [`Bytes`] backed by a shared-memory
+/// pool allocation that is returned when the last `Bytes` clone is dropped.
+#[derive(Debug)]
+pub struct RecvCompletion {
+    /// Token identifying which entry this completion corresponds to.
+    pub token: Token,
+    /// Completion data from the device.
+    pub data: Bytes,
+}
+
+/// Allocation tracking for an in-flight descriptor chain.
+///
+/// Each variant corresponds to a buffer layout submitted by the driver
+/// (guest/producer) and consumed by the device (host/consumer).
+/// "Readable" and "writable" are from the device's perspective, following
+/// the virtio convention.
+#[derive(Debug, Clone, Copy)]
+pub(crate) enum Inflight {
+    /// Driver sends data, device only acknowledges (fire-and-forget).
+    /// The readable buffer carries the entry; no writable buffer for a
+    /// device response.
+    ReadOnly { entry: Allocation },
+    /// Driver pre-posts a writable buffer for the device to fill.
+    /// No readable entry - the device writes a response into the
+    /// completion buffer unprompted (e.g. event delivery).
+    WriteOnly { completion: Allocation },
+    /// Bidirectional: driver sends an entry, device writes a response.
+    /// The readable buffer carries the entry, the writable buffer
+    /// receives the completion (typical request/response pattern).
+    ReadWrite {
+        entry: Allocation,
+        completion: Allocation,
+    },
+}
+
+impl Inflight {
+    fn entry(&self) -> Option<Allocation> {
+        match self {
+            Inflight::ReadOnly { entry } => Some(*entry),
+            Inflight::ReadWrite { entry, .. } => Some(*entry),
+            Inflight::WriteOnly { .. } => None,
+        }
+    }
+
+    fn completion(&self) -> Option<Allocation> {
+        match self {
+            Inflight::WriteOnly { completion } => Some(*completion),
+            Inflight::ReadWrite { completion, .. } => Some(*completion),
+            Inflight::ReadOnly { ..
} => None, + } + } + + fn try_into_chain(self, entry_len: usize) -> Result { + if let Some(entry) = self.entry() + && entry_len > entry.len + { + return Err(VirtqError::EntryTooLarge); + } + + Ok(match self { + Inflight::ReadOnly { entry } => BufferChainBuilder::new() + .readable(entry.addr, entry_len as u32) + .build()?, + Inflight::WriteOnly { completion } => BufferChainBuilder::new() + .writable(completion.addr, completion.len as u32) + .build()?, + Inflight::ReadWrite { entry, completion } => BufferChainBuilder::new() + .readable(entry.addr, entry_len as u32) + .writable(completion.addr, completion.len as u32) + .build()?, + }) + } +} + +/// A high-level virtqueue producer (driver side). +/// +/// The producer sends entries to the consumer (device), and receives completions. +/// This is typically used on the driver/guest side. +/// +/// # Example +/// +/// ```ignore +/// let mut producer = VirtqProducer::new(layout, mem, notifier, pool); +/// +/// // Build and submit an entry +/// let mut se = producer.chain().entry(64).completion(64).build()?; +/// se.write_all(b"hello")?; +/// let token = producer.submit(se)?; +/// +/// // Later, poll for completion +/// if let Some(cqe) = producer.poll()? { +/// assert_eq!(cqe.token, token); +/// println!("Got completion: {:?}", cqe.data); +/// } +/// ``` +pub struct VirtqProducer { + inner: RingProducer, + notifier: N, + pool: P, + inflight: Vec>, +} + +impl VirtqProducer +where + M: MemOps + Clone, + N: Notifier, + P: BufferProvider + Clone, +{ + /// Create a new virtqueue producer. 
+    ///
+    /// # Arguments
+    ///
+    /// * `layout` - Ring memory layout (descriptor table and event suppression addresses)
+    /// * `mem` - Memory operations implementation for reading/writing to shared memory
+    /// * `notifier` - Callback for notifying the device (consumer) about new entries
+    /// * `pool` - Buffer allocator for entry/completion data
+    pub fn new(layout: Layout, mem: M, notifier: N, pool: P) -> Self {
+        let inner = RingProducer::new(layout, mem);
+        let inflight = vec![None; inner.len()];
+
+        Self {
+            inner,
+            pool,
+            notifier,
+            inflight,
+        }
+    }
+
+    /// Poll for a single completion from the device.
+    ///
+    /// Returns `Ok(Some(completion))` if a completion is available, `Ok(None)` if no
+    /// completions are ready (would block), or an error if the device misbehaved.
+    ///
+    /// The returned [`RecvCompletion::data`] is a zero-copy [`Bytes`] backed by the
+    /// shared-memory allocation via [`BufferOwner`]. The pool allocation is
+    /// held alive as long as any `Bytes` clone exists, and is returned to the
+    /// pool when the last clone is dropped.
+    ///
+    /// # Errors
+    ///
+    /// - [`VirtqError::InvalidState`] - Device returned invalid descriptor ID or
+    ///   wrote more data than the completion buffer capacity
+    pub fn poll(&mut self) -> Result<Option<RecvCompletion>, VirtqError>
+    where
+        M: Send + Sync + 'static,
+        P: Send + Sync + 'static,
+    {
+        let used = match self.inner.poll_used() {
+            Ok(u) => u,
+            Err(RingError::WouldBlock) => return Ok(None),
+            Err(e) => return Err(e.into()),
+        };
+
+        let id = used.id as usize;
+        let inf = self
+            .inflight
+            .get_mut(id)
+            .ok_or(VirtqError::InvalidState)?
+ .take() + .ok_or(VirtqError::InvalidState)?; + + let written = used.len as usize; + + // Free entry buffers (request data no longer needed) + if let Some(entry) = inf.entry() { + self.pool.dealloc(entry)?; + } + + // Read completion data + let data = match inf.completion() { + Some(buf) => { + if written > buf.len { + let _ = self.pool.dealloc(buf); + return Err(VirtqError::InvalidState); + } + let owner = BufferOwner { + pool: self.pool.clone(), + mem: self.inner.mem().clone(), + alloc: buf, + written, + }; + Bytes::from_owner(owner) + } + None => Bytes::new(), + }; + + Ok(Some(RecvCompletion { + token: Token(used.id), + data, + })) + } + + /// Drain all available completions, calling the provided closure for each. + /// + /// This is a convenience method that repeatedly calls [`poll`](Self::poll) + /// until no more completions are available. + /// + /// # Arguments + /// + /// * `f` - Closure called for each completion with its token and data + /// + /// # Example + /// + /// ```ignore + /// producer.drain(|token, data| { + /// println!("Got {:?}: {} bytes", token, data.len()); + /// })?; + /// ``` + pub fn drain(&mut self, mut f: impl FnMut(Token, Bytes)) -> Result<(), VirtqError> + where + M: Send + Sync + 'static, + P: Send + Sync + 'static, + { + while let Some(cqe) = self.poll()? { + f(cqe.token, cqe.data); + } + + Ok(()) + } + + /// Begin building a descriptor chain for submission. + /// + /// Returns a [`ChainBuilder`] that allocates buffers from the pool. + pub fn chain(&self) -> ChainBuilder { + ChainBuilder::new(self.inner.mem().clone(), self.pool.clone()) + } + + /// Submit a [`SendEntry`] to the ring. + /// + /// Publishes the descriptor chain, stores the in-flight tracking state, + /// and notifies the consumer if event suppression allows. 
+ /// + /// # Errors + /// + /// - [`VirtqError::EntryTooLarge`] - written exceeds entry buffer capacity + /// - [`VirtqError::RingError`] - ring is full + /// - [`VirtqError::InvalidState`] - descriptor ID collision + pub fn submit(&mut self, mut entry: SendEntry) -> Result { + let written = entry.written; + let inflight = entry.inflight.take().ok_or(VirtqError::InvalidState)?; + + let cursor_before = self.inner.avail_cursor(); + let chain = inflight.try_into_chain(written)?; + let id = self.inner.submit_available(&chain)?; + + let slot = self + .inflight + .get_mut(id as usize) + .ok_or(VirtqError::InvalidState)?; + + if slot.is_some() { + return Err(VirtqError::InvalidState); + } + + *slot = Some(inflight); + + let should_notify = self.inner.should_notify_since(cursor_before)?; + if should_notify { + self.notifier.notify(QueueStats { + num_free: self.inner.num_free(), + num_inflight: self.inner.num_inflight(), + }); + } + + Ok(Token(id)) + } + + /// Get the current used cursor position. + /// + /// Useful for setting up descriptor-based event suppression. + #[inline] + pub fn used_cursor(&self) -> RingCursor { + self.inner.used_cursor() + } + + /// Configure event suppression for used buffer notifications. 
+ /// + /// This controls when the device (consumer) signals us about completed buffers: + /// + /// - [`SuppressionKind::Enable`]: Always signal (default) - good for latency + /// - [`SuppressionKind::Disable`]: Never signal - caller must poll + /// - [`SuppressionKind::Descriptor`]: Signal only at specific cursor position + /// + /// # Example: Completion Batching + /// + /// ```ignore + /// // Submit entries, then suppress notifications until all complete + /// let mut se = producer.chain().entry(64).completion(128).build()?; + /// se.write_all(b"entry1")?; + /// producer.submit(se)?; + /// let cursor = producer.used_cursor(); + /// producer.set_used_suppression(SuppressionKind::Descriptor(cursor))?; + /// // Device will notify only after reaching that cursor position + /// ``` + pub fn set_used_suppression(&mut self, kind: SuppressionKind) -> Result<(), VirtqError> { + match kind { + SuppressionKind::Enable => self.inner.enable_used_notifications()?, + SuppressionKind::Disable => self.inner.disable_used_notifications()?, + SuppressionKind::Descriptor(cursor) => self + .inner + .enable_used_notifications_desc(cursor.head(), cursor.wrap())?, + } + Ok(()) + } +} + +/// Builder for configuring a descriptor chain's buffer layout. +/// +/// If dropped without building, no resources are leaked (allocations are +/// deferred to [`build`](Self::build)). +#[must_use = "call .build() to create a SendEntry"] +pub struct ChainBuilder { + mem: M, + pool: P, + entry_cap: Option, + cqe_cap: Option, +} + +impl ChainBuilder { + fn new(mem: M, pool: P) -> Self { + Self { + mem, + pool, + entry_cap: None, + cqe_cap: None, + } + } + + fn alloc( + &self, + size: usize, + ) -> Result>, VirtqError> { + let alloc = self.pool.alloc(size)?; + let pool = self.pool.clone(); + + Ok(AllocGuard::new(alloc, move |a| { + let _ = pool.dealloc(a); + })) + } + + /// Request an entry buffer of `cap` bytes. + /// + /// The entry holds data sent from the driver to the consumer (device). 
+ /// The actual allocation is deferred to [`build`](Self::build). + pub fn entry(mut self, cap: usize) -> Self { + self.entry_cap = Some(cap); + self + } + + /// Request a completion buffer of `cap` bytes. + /// + /// The completion buffer is filled by the consumer and returned via + /// [`VirtqProducer::poll`] as [`RecvCompletion`]. + pub fn completion(mut self, cap: usize) -> Self { + self.cqe_cap = Some(cap); + self + } + + /// Allocate buffers and return a [`SendEntry`] for writing. + /// + /// # Errors + /// + /// - [`VirtqError::InvalidState`] - Neither entry nor completion requested + /// - [`VirtqError::Alloc`] - Pool exhausted + pub fn build(self) -> Result, VirtqError> { + if self.entry_cap.is_none() && self.cqe_cap.is_none() { + return Err(VirtqError::InvalidState); + } + + let entry_alloc = self.entry_cap.map(|cap| self.alloc(cap)).transpose()?; + let completion_alloc = self.cqe_cap.map(|cap| self.alloc(cap)).transpose()?; + + let inflight = match (entry_alloc, completion_alloc) { + (Some(entry), Some(cqe)) => Inflight::ReadWrite { + entry: entry.release(), + completion: cqe.release(), + }, + (Some(entry), None) => Inflight::ReadOnly { + entry: entry.release(), + }, + (None, Some(cqe)) => Inflight::WriteOnly { + completion: cqe.release(), + }, + (None, None) => unreachable!(), + }; + + Ok(SendEntry { + mem: self.mem, + pool: self.pool, + inflight: Some(inflight), + written: 0, + }) + } +} + +/// A configured entry ready for writing and submission. +/// +/// Created by [`ChainBuilder::build`]. Write data into the entry buffer +/// with [`write_all`](Self::write_all), +/// or use [`buf_mut`](Self::buf_mut) for zero-copy direct access. +/// Then submit via [`VirtqProducer::submit`]. 
+/// +/// # Examples +/// +/// ```ignore +/// let mut se = producer.chain().entry(64).completion(128).build()?; +/// se.write_all(b"header")?; +/// se.write_all(b" body")?; +/// let tok = producer.submit(se)?; +/// +/// // Zero-copy direct access +/// let mut se = producer.chain().entry(128).build()?; +/// let buf = se.buf_mut()?; +/// let n = serialize_into(buf); +/// se.set_written(n)?; +/// let tok = producer.submit(se)?; +/// ``` +/// +/// If dropped without submitting, allocated buffers are returned to the pool. +#[must_use = "dropping without submitting deallocates the buffers"] +pub struct SendEntry { + mem: M, + pool: P, + written: usize, + inflight: Option, +} + +impl SendEntry { + fn entry(&self) -> Result { + self.inflight + .as_ref() + .and_then(|i| i.entry()) + .ok_or(VirtqError::NoReadableBuffer) + } + + /// Total entry buffer capacity in bytes. + /// + /// Returns 0 when there are no entry buffers. + pub fn capacity(&self) -> usize { + self.inflight + .as_ref() + .and_then(|i| i.entry()) + .map_or(0, |a| a.len) + } + + /// Number of bytes written so far via [`write_all`](Self::write_all) + /// or [`set_written`](Self::set_written). + pub fn written(&self) -> usize { + self.written + } + + /// Set the write cursor to an explicit byte count. + /// + /// Use this after [`buf_mut`](Self::buf_mut) where you wrote directly + /// into the buffer. The value tells the consumer how many bytes of + /// the entry buffer are valid. + /// + /// # Errors + /// + /// - [`VirtqError::EntryTooLarge`] - `written` exceeds entry buffer capacity + pub fn set_written(&mut self, written: usize) -> Result<(), VirtqError> { + if written > self.capacity() { + return Err(VirtqError::EntryTooLarge); + } + + self.written = written; + Ok(()) + } + + /// Remaining writable capacity in the entry buffer. + pub fn remaining(&self) -> usize { + self.capacity() - self.written + } + + /// Write the entire buffer into the entry. + /// + /// Appends at the current write position. 
Uses [`MemOps::write`] + /// (volatile on host side). + /// + /// # Errors + /// + /// - [`VirtqError::EntryTooLarge`] - buf exceeds remaining capacity + /// - [`VirtqError::NoReadableBuffer`] - no entry buffer allocated + /// - [`VirtqError::MemoryWriteError`] - underlying write failed + pub fn write_all(&mut self, buf: &[u8]) -> Result<(), VirtqError> { + let alloc = self.entry()?; + + if buf.len() > self.remaining() { + return Err(VirtqError::EntryTooLarge); + } + + let addr = alloc.addr + self.written as u64; + self.mem + .write(addr, buf) + .map_err(|_| VirtqError::MemoryWriteError)?; + + self.written += buf.len(); + Ok(()) + } + + /// Zero-copy access to the full entry buffer in shared memory. + /// + /// Returns `&mut [u8]` pointing directly into the allocated buffer. + /// This is safe on the guest side (producer). After writing, call + /// [`set_written`](Self::set_written) to record how many bytes are valid. + /// + /// **Note**: This bypasses the write cursor. Use either `buf_mut()` + + /// `set_written(n)` or the `write_all` method, not both. 
+ /// + /// # Errors + /// + /// - [`VirtqError::NoReadableBuffer`] - no entry buffer allocated + /// - [`VirtqError::MemoryWriteError`] - failed to access shared memory + pub fn buf_mut(&mut self) -> Result<&mut [u8], VirtqError> { + let alloc = self.entry()?; + unsafe { + self.mem + .as_mut_slice(alloc.addr, alloc.len) + .map_err(|_| VirtqError::MemoryWriteError) + } + } +} + +impl Drop for SendEntry { + fn drop(&mut self) { + let inf = match self.inflight.take() { + Some(i) => i, + None => return, // already submitted + }; + if let Some(a) = inf.entry() { + let _ = self.pool.dealloc(a); + } + if let Some(a) = inf.completion() { + let _ = self.pool.dealloc(a); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::virtq::ring::tests::make_ring; + use crate::virtq::test_utils::*; + + #[test] + fn test_chain_readwrite_build() { + let ring = make_ring(16); + let (producer, _consumer, _notifier) = make_test_producer(&ring); + + let se = producer.chain().entry(64).completion(128).build().unwrap(); + assert_eq!(se.capacity(), 64); + assert_eq!(se.written(), 0); + assert_eq!(se.remaining(), 64); + } + + #[test] + fn test_chain_entry_only_build() { + let ring = make_ring(16); + let (producer, _consumer, _notifier) = make_test_producer(&ring); + + let se = producer.chain().entry(32).build().unwrap(); + assert_eq!(se.capacity(), 32); + } + + #[test] + fn test_chain_completion_only_build() { + let ring = make_ring(16); + let (producer, _consumer, _notifier) = make_test_producer(&ring); + + let se = producer.chain().completion(64).build().unwrap(); + assert_eq!(se.capacity(), 0); + } + + #[test] + fn test_chain_empty_build_fails() { + let ring = make_ring(16); + let (producer, _consumer, _notifier) = make_test_producer(&ring); + + let result = producer.chain().build(); + assert!(matches!(result, Err(VirtqError::InvalidState))); + } + + #[test] + fn test_send_entry_write_all_and_submit() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) 
= make_test_producer(&ring); + + let mut se = producer.chain().entry(64).completion(128).build().unwrap(); + + se.write_all(b"hello").unwrap(); + se.write_all(b" world").unwrap(); + assert_eq!(se.written(), 11); + assert_eq!(se.remaining(), 53); + let tok = producer.submit(se).unwrap(); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.token(), tok); + assert_eq!(entry.data().as_ref(), b"hello world"); + consumer.complete(completion).unwrap(); + } + + #[test] + fn test_send_entry_buf_mut() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let mut se = producer.chain().entry(64).completion(128).build().unwrap(); + let buf = se.buf_mut().unwrap(); + assert_eq!(buf.len(), 64); + buf[..5].copy_from_slice(b"hello"); + se.set_written(5).unwrap(); + let _tok = producer.submit(se).unwrap(); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.data().as_ref(), b"hello"); + consumer.complete(completion).unwrap(); + } + + #[test] + fn test_send_entry_write_too_large() { + let ring = make_ring(16); + let (producer, _consumer, _notifier) = make_test_producer(&ring); + + let mut se = producer.chain().entry(4).build().unwrap(); + let err = se.write_all(b"too long").unwrap_err(); + assert!(matches!(err, VirtqError::EntryTooLarge)); + } + + #[test] + fn test_writeonly_has_no_entry_buffer() { + let ring = make_ring(16); + let (producer, _consumer, _notifier) = make_test_producer(&ring); + + let mut se = producer.chain().completion(32).build().unwrap(); + let err = se.write_all(b"data").unwrap_err(); + assert!(matches!(err, VirtqError::NoReadableBuffer)); + } + + #[test] + fn test_drop_chain_builder_deallocs() { + let ring = make_ring(16); + let (mut producer, _consumer, _notifier) = make_test_producer(&ring); + + { + let _builder = producer.chain().entry(64).completion(128); + // dropped without build + } + + // Ring should still be fully usable + let 
se = producer.chain().entry(64).completion(128).build().unwrap(); + let tok = producer.submit(se).unwrap(); + assert!(tok.0 < 16); + } + + #[test] + fn test_drop_send_entry_deallocs() { + let ring = make_ring(16); + let (mut producer, _consumer, _notifier) = make_test_producer(&ring); + + { + let _se = producer.chain().entry(64).completion(128).build().unwrap(); + // dropped without submit + } + + // Ring should still be fully usable + let se = producer.chain().entry(64).completion(128).build().unwrap(); + let tok = producer.submit(se).unwrap(); + assert!(tok.0 < 16); + } + + #[test] + fn test_submit_notifies() { + let ring = make_ring(16); + let (mut producer, _consumer, notifier) = make_test_producer(&ring); + + let initial_count = notifier.notification_count(); + + let mut se = producer.chain().entry(64).completion(128).build().unwrap(); + se.write_all(b"hello").unwrap(); + producer.submit(se).unwrap(); + + assert!(notifier.notification_count() > initial_count); + } + + #[test] + fn test_set_written_too_large() { + let ring = make_ring(16); + let (producer, _consumer, _notifier) = make_test_producer(&ring); + + let mut se = producer.chain().entry(32).completion(64).build().unwrap(); + let err = se.set_written(64).unwrap_err(); + assert!(matches!(err, VirtqError::EntryTooLarge)); + } + + #[test] + fn test_write_only_round_trip() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let se = producer.chain().completion(32).build().unwrap(); + let token = producer.submit(se).unwrap(); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.token(), token); + assert!(entry.data().is_empty()); + + if let SendCompletion::Writable(mut wc) = completion { + wc.write_all(b"filled-by-consumer").unwrap(); + consumer.complete(wc.into()).unwrap(); + } else { + panic!("expected Writable"); + } + + let cqe = producer.poll().unwrap().unwrap(); + assert_eq!(cqe.token, token); + 
assert_eq!(cqe.data.len(), b"filled-by-consumer".len()); + assert_eq!(&cqe.data[..], b"filled-by-consumer"); + } + + #[test] + fn test_read_only_round_trip() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let mut se = producer.chain().entry(32).build().unwrap(); + se.write_all(b"fire-and-forget").unwrap(); + let token = producer.submit(se).unwrap(); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.token(), token); + assert_eq!(entry.data().as_ref(), b"fire-and-forget"); + assert!(matches!(completion, SendCompletion::Ack(_))); + consumer.complete(completion).unwrap(); + + let cqe = producer.poll().unwrap().unwrap(); + assert_eq!(cqe.token, token); + assert_eq!(cqe.data.len(), 0); + assert!(cqe.data.is_empty()); + } + + #[test] + fn test_readwrite_round_trip() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let mut se = producer.chain().entry(64).completion(128).build().unwrap(); + se.write_all(b"request data").unwrap(); + let token = producer.submit(se).unwrap(); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.data().as_ref(), b"request data"); + if let SendCompletion::Writable(mut wc) = completion { + wc.write_all(b"response data").unwrap(); + consumer.complete(wc.into()).unwrap(); + } else { + panic!("expected Writable"); + } + + let cqe = producer.poll().unwrap().unwrap(); + assert_eq!(cqe.token, token); + assert_eq!(&cqe.data[..], b"response data"); + } +} diff --git a/src/hyperlight_guest/src/error.rs b/src/hyperlight_guest/src/error.rs index 0a33bce79..9f256684b 100644 --- a/src/hyperlight_guest/src/error.rs +++ b/src/hyperlight_guest/src/error.rs @@ -17,9 +17,10 @@ limitations under the License. 
use alloc::format; use alloc::string::{String, ToString as _}; +use anyhow; pub use hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode; use hyperlight_common::func::Error as FuncError; -use {anyhow, serde_json}; +use serde_json; pub type Result = core::result::Result; @@ -171,10 +172,10 @@ impl GuestErrorContext for core::result::Result { #[macro_export] macro_rules! bail { ($ec:expr => $($msg:tt)*) => { - return ::core::result::Result::Err($crate::error::HyperlightGuestError::new($ec, ::alloc::format!($($msg)*))); + return ::core::result::Result::Err($crate::error::HyperlightGuestError::new($ec, ::alloc::format!($($msg)*))) }; ($($msg:tt)*) => { - $crate::bail!($crate::error::ErrorCode::GuestError => $($msg)*); + $crate::bail!($crate::error::ErrorCode::GuestError => $($msg)*) }; } From 95069bd19209ad5fa4bb1ab8b398e187087bf1e1 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Wed, 25 Mar 2026 13:51:23 +0100 Subject: [PATCH 02/31] feat(virtq): add virtqueue ring plumbing in scratch region Place G2H and H2G packed virtqueue descriptor rings at deterministic offsets in the scratch region. 
Signed-off-by: Tomasz Andrzejak --- .gitmodules | 3 + Cargo.lock | 1 + .../src/arch/aarch64/layout.rs | 7 +- .../src/arch/amd64/layout.rs | 15 +- src/hyperlight_common/src/arch/i686/layout.rs | 13 +- src/hyperlight_common/src/layout.rs | 45 +- src/hyperlight_common/src/outb.rs | 3 + src/hyperlight_common/src/virtq/desc.rs.rej | 23 + src/hyperlight_common/src/virtq/mod.rs | 1 + src/hyperlight_common/src/virtq/msg.rs | 75 +++ src/hyperlight_common/src/virtq/stream.rs | 579 ++++++++++++++++++ src/hyperlight_guest_bin/bindgen_wrapper.h | 8 + src/hyperlight_guest_bin/src/lib.rs | 4 + src/hyperlight_guest_bin/src/virtq_init.rs | 50 ++ src/hyperlight_host/Cargo.toml | 1 + src/hyperlight_host/src/mem/layout.rs | 56 +- src/hyperlight_host/src/mem/mgr.rs | 47 +- src/hyperlight_host/src/mem/shared_mem.rs | 60 +- src/hyperlight_host/src/sandbox/config.rs | 22 + .../src/sandbox/initialized_multi_use.rs | 2 + src/hyperlight_host/src/sandbox/outb.rs | 4 + src/hyperlight_libc/third_party/mimalloc | 1 + src/tests/rust_guests/dummyguest/Cargo.lock | 33 +- src/tests/rust_guests/simpleguest/Cargo.lock | 33 +- src/tests/rust_guests/witguest/Cargo.lock | 29 +- 25 files changed, 1037 insertions(+), 78 deletions(-) create mode 100644 src/hyperlight_common/src/virtq/desc.rs.rej create mode 100644 src/hyperlight_common/src/virtq/msg.rs create mode 100644 src/hyperlight_common/src/virtq/stream.rs create mode 100644 src/hyperlight_guest_bin/bindgen_wrapper.h create mode 100644 src/hyperlight_guest_bin/src/virtq_init.rs create mode 160000 src/hyperlight_libc/third_party/mimalloc diff --git a/.gitmodules b/.gitmodules index b7b8a5c57..8c039127b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -2,3 +2,6 @@ path = src/hyperlight_libc/third_party/picolibc url = https://github.com/hyperlight-dev/picolibc-bsd.git shallow = true +[submodule "src/hyperlight_libc/third_party/mimalloc"] + path = src/hyperlight_libc/third_party/mimalloc + url = https://github.com/microsoft/mimalloc diff --git a/Cargo.lock 
b/Cargo.lock index d5d98755b..887ece62d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1608,6 +1608,7 @@ dependencies = [ "bitflags 2.11.1", "blake3", "built", + "bytemuck", "cfg-if", "cfg_aliases", "chrono", diff --git a/src/hyperlight_common/src/arch/aarch64/layout.rs b/src/hyperlight_common/src/arch/aarch64/layout.rs index 20f17026c..25bd99a1e 100644 --- a/src/hyperlight_common/src/arch/aarch64/layout.rs +++ b/src/hyperlight_common/src/arch/aarch64/layout.rs @@ -20,6 +20,11 @@ pub const SNAPSHOT_PT_GVA_MIN: usize = 0xffff_8000_0000_0000; pub const SNAPSHOT_PT_GVA_MAX: usize = 0xffff_80ff_ffff_ffff; pub const MAX_GPA: usize = 0x0000_000f_ffff_ffff; -pub fn min_scratch_size(_input_data_size: usize, _output_data_size: usize) -> usize { +pub fn min_scratch_size( + _input_data_size: usize, + _output_data_size: usize, + _g2h_num_descs: usize, + _h2g_num_descs: usize, +) -> usize { unimplemented!("min_scratch_size") } diff --git a/src/hyperlight_common/src/arch/amd64/layout.rs b/src/hyperlight_common/src/arch/amd64/layout.rs index 14a9cd62a..4731f21b2 100644 --- a/src/hyperlight_common/src/arch/amd64/layout.rs +++ b/src/hyperlight_common/src/arch/amd64/layout.rs @@ -37,8 +37,17 @@ pub const MAX_GPA: usize = 0x0000_000f_ffff_ffff; /// - A page for the smallest possible non-exception stack /// - (up to) 3 pages for mapping that /// - Two pages for the exception stack and metadata -/// - A page-aligned amount of memory for I/O buffers (for now) -pub fn min_scratch_size(input_data_size: usize, output_data_size: usize) -> usize { - (input_data_size + output_data_size).next_multiple_of(crate::vmem::PAGE_SIZE) +/// - A page-aligned amount of memory for I/O buffers and virtqueue rings +pub fn min_scratch_size( + input_data_size: usize, + output_data_size: usize, + g2h_num_descs: usize, + h2g_num_descs: usize, +) -> usize { + let g2h_ring_size = crate::virtq::Layout::query_size(g2h_num_descs); + let h2g_ring_size = crate::virtq::Layout::query_size(h2g_num_descs); + + 
(input_data_size + output_data_size + g2h_ring_size + h2g_ring_size) + .next_multiple_of(crate::vmem::PAGE_SIZE) + 12 * crate::vmem::PAGE_SIZE } diff --git a/src/hyperlight_common/src/arch/i686/layout.rs b/src/hyperlight_common/src/arch/i686/layout.rs index cdc3af7d1..ff93f1d09 100644 --- a/src/hyperlight_common/src/arch/i686/layout.rs +++ b/src/hyperlight_common/src/arch/i686/layout.rs @@ -21,10 +21,11 @@ pub const MAX_GVA: usize = 0xffff_ffff; /// regions are large enough to reach that address. pub const MAX_GPA: usize = 0xFEDF_FFFF; -/// Minimum scratch region size: IO buffers (page-aligned) plus 12 pages -/// for bookkeeping and the exception stack. Page table space is validated -/// separately by `set_pt_size()`. -pub fn min_scratch_size(input_data_size: usize, output_data_size: usize) -> usize { - (input_data_size + output_data_size).next_multiple_of(crate::vmem::PAGE_SIZE) - + 12 * crate::vmem::PAGE_SIZE +pub fn min_scratch_size( + _input_data_size: usize, + _output_data_size: usize, + _g2h_num_descs: usize, + _h2g_num_descs: usize, +) -> usize { + crate::vmem::PAGE_SIZE } diff --git a/src/hyperlight_common/src/layout.rs b/src/hyperlight_common/src/layout.rs index 1a7ca0880..234ad6f78 100644 --- a/src/hyperlight_common/src/layout.rs +++ b/src/hyperlight_common/src/layout.rs @@ -33,12 +33,28 @@ pub use arch::{MAX_GPA, MAX_GVA}; ))] pub use arch::{SNAPSHOT_PT_GVA_MAX, SNAPSHOT_PT_GVA_MIN}; -// offsets down from the top of scratch memory for various things pub const SCRATCH_TOP_SIZE_OFFSET: u64 = 0x08; pub const SCRATCH_TOP_ALLOCATOR_OFFSET: u64 = 0x10; pub const SCRATCH_TOP_SNAPSHOT_PT_GPA_BASE_OFFSET: u64 = 0x18; pub const SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET: u64 = 0x20; -pub const SCRATCH_TOP_EXN_STACK_OFFSET: u64 = 0x30; +pub const SCRATCH_TOP_G2H_RING_GVA_OFFSET: u64 = 0x28; +pub const SCRATCH_TOP_H2G_RING_GVA_OFFSET: u64 = 0x30; +pub const SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET: u64 = 0x38; +pub const SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET: u64 = 0x3a; +pub 
const SCRATCH_TOP_EXN_STACK_OFFSET: u64 = 0x40; + +// fields must not overlap, and exception stack address must be 16-byte aligned. +const _: () = { + assert!(SCRATCH_TOP_SIZE_OFFSET + 8 <= SCRATCH_TOP_ALLOCATOR_OFFSET); + assert!(SCRATCH_TOP_ALLOCATOR_OFFSET + 8 <= SCRATCH_TOP_SNAPSHOT_PT_GPA_BASE_OFFSET); + assert!(SCRATCH_TOP_SNAPSHOT_PT_GPA_BASE_OFFSET + 8 <= SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET); + assert!(SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET + 8 <= SCRATCH_TOP_G2H_RING_GVA_OFFSET); + assert!(SCRATCH_TOP_G2H_RING_GVA_OFFSET + 8 <= SCRATCH_TOP_H2G_RING_GVA_OFFSET); + assert!(SCRATCH_TOP_H2G_RING_GVA_OFFSET + 8 <= SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET); + assert!(SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET + 2 <= SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET); + assert!(SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET + 2 <= SCRATCH_TOP_EXN_STACK_OFFSET); + assert!(SCRATCH_TOP_EXN_STACK_OFFSET % 0x10 == 0); +}; /// Offset from the top of scratch memory for a shared host-guest u64 counter. /// @@ -56,5 +72,30 @@ pub fn scratch_base_gva(size: usize) -> u64 { (MAX_GVA - size + 1) as u64 } +/// Compute the byte offset from the scratch base to the G2H ring. +/// +/// TODO(ring): Remove input/output +pub const fn g2h_ring_scratch_offset(input_data_size: usize, output_data_size: usize) -> usize { + let io_off = input_data_size + output_data_size; + let align = crate::virtq::Descriptor::ALIGN; + + (io_off + align - 1) & !(align - 1) +} + +/// Compute the byte offset from the scratch base to the H2G ring. +/// +/// TODO(ring): Remove input/output +pub const fn h2g_ring_scratch_offset( + input_data_size: usize, + output_data_size: usize, + g2h_num_descs: usize, +) -> usize { + let g2h_offset = g2h_ring_scratch_offset(input_data_size, output_data_size); + let g2h_size = crate::virtq::Layout::query_size(g2h_num_descs); + let align = crate::virtq::Descriptor::ALIGN; + + (g2h_offset + g2h_size + align - 1) & !(align - 1) +} + /// Compute the minimum scratch region size needed for a sandbox. 
pub use arch::min_scratch_size; diff --git a/src/hyperlight_common/src/outb.rs b/src/hyperlight_common/src/outb.rs index 3bfb99848..0f9c25e00 100644 --- a/src/hyperlight_common/src/outb.rs +++ b/src/hyperlight_common/src/outb.rs @@ -105,6 +105,8 @@ pub enum OutBAction { TraceMemoryAlloc = 105, #[cfg(feature = "mem_profile")] TraceMemoryFree = 106, + /// Notification that entries are available on a virtqueue. + VirtqNotify = 109, } /// IO-port actions intercepted at the hypervisor level (in `run_vcpu`) @@ -137,6 +139,7 @@ impl TryFrom for OutBAction { 105 => Ok(OutBAction::TraceMemoryAlloc), #[cfg(feature = "mem_profile")] 106 => Ok(OutBAction::TraceMemoryFree), + 109 => Ok(OutBAction::VirtqNotify), _ => Err(anyhow::anyhow!("Invalid OutBAction value: {}", val)), } } diff --git a/src/hyperlight_common/src/virtq/desc.rs.rej b/src/hyperlight_common/src/virtq/desc.rs.rej new file mode 100644 index 000000000..2172452ba --- /dev/null +++ b/src/hyperlight_common/src/virtq/desc.rs.rej @@ -0,0 +1,23 @@ +diff a/src/hyperlight_common/src/virtq/desc.rs b/src/hyperlight_common/src/virtq/desc.rs (rejected hunks) +@@ -58,12 +58,15 @@ pub struct Descriptor { + pub flags: u16, + } + +-const _: () = assert!(core::mem::size_of::() == 16); +-const _: () = assert!(Descriptor::ALIGN == 16); +-const _: () = assert!(Descriptor::ADDR_OFFSET == 0); +-const _: () = assert!(Descriptor::LEN_OFFSET == 8); +-const _: () = assert!(Descriptor::ID_OFFSET == 12); +-const _: () = assert!(Descriptor::FLAGS_OFFSET == 14); ++#[allow(clippy::disallowed_macros)] ++const _: () = { ++ assert!(core::mem::size_of::() == 16); ++ assert!(Descriptor::ALIGN == 16); ++ assert!(Descriptor::ADDR_OFFSET == 0); ++ assert!(Descriptor::LEN_OFFSET == 8); ++ assert!(Descriptor::ID_OFFSET == 12); ++ assert!(Descriptor::FLAGS_OFFSET == 14); ++}; + + impl Descriptor { + // VIRTIO spec requires 16-byte alignment for descriptors diff --git a/src/hyperlight_common/src/virtq/mod.rs b/src/hyperlight_common/src/virtq/mod.rs index 
dd648e3dc..2e10491b2 100644 --- a/src/hyperlight_common/src/virtq/mod.rs +++ b/src/hyperlight_common/src/virtq/mod.rs @@ -154,6 +154,7 @@ mod access; mod consumer; mod desc; mod event; +pub mod msg; mod pool; mod producer; mod ring; diff --git a/src/hyperlight_common/src/virtq/msg.rs b/src/hyperlight_common/src/virtq/msg.rs new file mode 100644 index 000000000..9c7f69947 --- /dev/null +++ b/src/hyperlight_common/src/virtq/msg.rs @@ -0,0 +1,75 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Wire format header for all virtqueue messages. +//! +//! Every payload on both the G2H and H2G queues starts with this +//! fixed 8-byte header, enabling message type discrimination and +//! request/response correlation. + +/// Message types for the virtqueue wire protocol. +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MsgKind { + /// A function call request (FunctionCall payload follows). + Request = 0x01, + /// A function call response (FunctionCallResult payload follows). + Response = 0x02, + /// A stream data chunk. + StreamChunk = 0x03, + /// End-of-stream marker. + StreamEnd = 0x04, + /// Cancel a pending request. + Cancel = 0x05, +} + +/// Wire header for all virtqueue messages +#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)] +#[repr(C)] +pub struct VirtqMsgHeader { + /// Discriminates the message type. + pub kind: u8, + /// Per-type flags TODO(ring): add flags type. 
+ pub flags: u8, + /// Caller-assigned correlation ID. Responses echo the request's ID. + pub req_id: u16, + /// Byte length of the payload following this header. + pub payload_len: u32, +} + +impl VirtqMsgHeader { + pub const SIZE: usize = core::mem::size_of::(); + + /// Create a new message header. + pub const fn new(kind: MsgKind, req_id: u16, payload_len: u32) -> Self { + Self { + kind: kind as u8, + flags: 0, + req_id, + payload_len, + } + } + + /// Create a new header with flags. + pub const fn with_flags(kind: MsgKind, flags: u8, req_id: u16, payload_len: u32) -> Self { + Self { + kind: kind as u8, + flags, + req_id, + payload_len, + } + } +} diff --git a/src/hyperlight_common/src/virtq/stream.rs b/src/hyperlight_common/src/virtq/stream.rs new file mode 100644 index 000000000..8dbc7da27 --- /dev/null +++ b/src/hyperlight_common/src/virtq/stream.rs @@ -0,0 +1,579 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Stream table and flow-control primitives for virtqueue byte streams. +//! +//! This module provides the `no_std`-friendly core types used on both +//! the guest and host sides: +//! +//! - [`StreamId`], [`StreamHandle`]: identification and cross-boundary +//! handle shape. +//! - [`StreamDirection`]: G2H vs H2G per-direction ID spaces. +//! - [`StreamTable`]: per-direction table of open streams, with a +//! generation counter (bumped on sandbox reset) that gates stale +//! messages. +//! 
- Credit accounting: [`WriterCredit`] / [`ReaderCredit`] wrappers +//! plus signed-arithmetic helpers that mirror virtio-vsock's +//! wrap-safe math. +//! +//! Only the data model and arithmetic lives here. Wiring to the +//! producer/consumer virtqueues lives in the guest and host crates +//! (Phase 2/3). + +use alloc::vec::Vec; + +use crate::virtq::msg::{STREAM_GEN_MASK, STREAM_ID_MAX}; + +/// Default initial credit (bytes) a newly-created local stream end +/// advertises to its peer. One BufferPool slot's worth for MVP. +pub const STREAM_INITIAL_CREDIT: u32 = 4096; + +/// Direction a stream flows over the virtqueue pair. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum StreamDirection { + /// Guest produces, host consumes. + Guest2Host, + /// Host produces, guest consumes. + Host2Guest, +} + +/// 12-bit stream id within a given direction and generation. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct StreamId(u16); + +impl StreamId { + /// Construct from a raw 12-bit value. Returns `None` if out of range. + pub const fn from_u16(id: u16) -> Option<Self> { + if id > STREAM_ID_MAX { + None + } else { + Some(Self(id)) + } + } + + pub const fn as_u16(self) -> u16 { + self.0 + } +} + +/// Opaque handle identifying a stream across the VM boundary. +/// +/// Carries the minimum information needed to route messages and to +/// reject stale traffic after sandbox reset. The `initial_credit` is +/// the buffer capacity the receiver advertises at handle-transfer time +/// (bootstraps the peer's `buf_alloc`). `None` means "use the default".
+#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct StreamHandle { + pub direction: StreamDirection, + pub generation: u8, + pub stream_id: StreamId, + pub initial_credit: Option<u32>, +} + +impl StreamHandle { + pub const fn new( + direction: StreamDirection, + generation: u8, + stream_id: StreamId, + initial_credit: Option<u32>, + ) -> Self { + Self { + direction, + generation, + stream_id, + initial_credit, + } + } +} + +/// Writer-side credit accounting. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub struct WriterCredit { + /// Total bytes this writer has transmitted. + pub tx_cnt: u32, + /// Last `fwd_cnt` advertised by the reader. + pub peer_fwd_cnt: u32, + /// Last `buf_alloc` advertised by the reader. + pub peer_buf_alloc: u32, +} + +impl WriterCredit { + /// Initial writer state. `peer_buf_alloc` should be seeded to the + /// reader's bootstrap credit. + pub const fn new(peer_buf_alloc: u32) -> Self { + Self { + tx_cnt: 0, + peer_fwd_cnt: 0, + peer_buf_alloc, + } + } + + /// Bytes the writer may transmit right now without overrunning the + /// reader's advertised buffer. Uses signed 64-bit arithmetic so + /// that a peer shrinking its buffer below the in-flight count + /// returns 0 instead of underflowing. + pub fn available(&self) -> u32 { + available_credit(self.peer_buf_alloc, self.tx_cnt, self.peer_fwd_cnt) + } + + /// Record that `n` bytes have been placed on the wire. Saturating + /// to avoid spurious panics on pathological peer state; the writer + /// is expected to have clamped `n` to `available()` beforehand. + pub fn record_sent(&mut self, n: u32) { + self.tx_cnt = self.tx_cnt.wrapping_add(n); + } + + /// Absorb a preamble observed from the reader. + /// + /// Only monotonically-newer counters are adopted: `fwd_cnt` uses + /// signed wrap-safe comparison; `buf_alloc` is whatever the reader + /// last said (it can grow or shrink).
+ pub fn observe_peer(&mut self, peer_fwd_cnt: u32, peer_buf_alloc: u32) { + if wrap_gt(peer_fwd_cnt, self.peer_fwd_cnt) { + self.peer_fwd_cnt = peer_fwd_cnt; + } + self.peer_buf_alloc = peer_buf_alloc; + } +} + +/// Reader-side credit accounting. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub struct ReaderCredit { + /// Total bytes this reader has delivered to the consumer. + pub fwd_cnt: u32, + /// Buffer capacity advertised to the writer. May be lowered by + /// raising it only when a threshold of space has been freed. + pub buf_alloc: u32, +} + +impl ReaderCredit { + pub const fn new(buf_alloc: u32) -> Self { + Self { + fwd_cnt: 0, + buf_alloc, + } + } + + /// Record that `n` bytes have been consumed by the reader. + pub fn record_consumed(&mut self, n: u32) { + self.fwd_cnt = self.fwd_cnt.wrapping_add(n); + } +} + +/// Signed wrap-safe "strictly greater than". Returns `true` if `a` is +/// a newer counter value than `b`, tolerating u32 wraparound. +pub fn wrap_gt(a: u32, b: u32) -> bool { + (a.wrapping_sub(b) as i32) > 0 +} + +/// Available credit (in bytes) given the reader's advertised buffer +/// and the writer's current tx/peer-fwd counters. Implements the +/// vsock-equivalent calculation with signed arithmetic to survive +/// peer_buf_alloc shrinking. +pub fn available_credit(peer_buf_alloc: u32, tx_cnt: u32, peer_fwd_cnt: u32) -> u32 { + let in_flight = tx_cnt.wrapping_sub(peer_fwd_cnt) as i64; + let avail = peer_buf_alloc as i64 - in_flight; + if avail <= 0 { + 0 + } else if avail > u32::MAX as i64 { + u32::MAX + } else { + avail as u32 + } +} + +/// Lifecycle state of a single stream entry. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamLifecycle { + /// Open - data and control messages flow normally. + Open, + /// Writer has sent StreamEnd; reader may still drain buffered + /// chunks, but no new data will arrive. + WriterClosed, + /// Reader has sent Cancel; writer must stop producing. 
+ Cancelled, + /// Fully closed; entry will be reaped on the next sweep. + Closed, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum EndRole { + /// This side is the writer for the stream (peer reads). + Writer, + /// This side is the reader for the stream (peer writes). + Reader, +} + +#[derive(Debug)] +struct StreamEntry { + id: StreamId, + #[allow(dead_code)] // Used in later phases for role-based routing. + role: EndRole, + writer: WriterCredit, + reader: ReaderCredit, + state: StreamLifecycle, +} + +/// Per-direction stream table. +/// +/// - Allocates [`StreamId`]s monotonically (no ID reuse within a +/// generation) to avoid races with late Cancel/CreditUpdate messages. +/// - Tracks a `generation` counter bumped on sandbox reset; inbound +/// messages carrying a mismatched generation are considered stale +/// and dropped by the routing layer. +#[derive(Debug)] +pub struct StreamTable { + direction: StreamDirection, + generation: u8, + next_id: u16, + entries: Vec, +} + +/// Errors produced by [`StreamTable`] mutations. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamTableError { + /// All stream IDs for the current generation have been used. + IdSpaceExhausted, + /// Referenced stream id is not present in the table. + Unknown(StreamId), +} + +impl StreamTable { + pub fn new(direction: StreamDirection) -> Self { + Self { + direction, + generation: 0, + next_id: 0, + entries: Vec::new(), + } + } + + pub fn direction(&self) -> StreamDirection { + self.direction + } + + pub fn generation(&self) -> u8 { + self.generation + } + + pub fn len(&self) -> usize { + self.entries.iter().filter(|e| !e.is_reaped()).count() + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Bump the generation and drop all open streams. Called on sandbox + /// reset/restore. Returns the new generation. 
+ pub fn reset(&mut self) -> u8 { + self.generation = (self.generation.wrapping_add(1)) & (STREAM_GEN_MASK as u8); + self.next_id = 0; + self.entries.clear(); + self.generation + } + + /// Allocate a new stream id and register it as a locally-owned end. + /// `role_is_writer` indicates whether this side is the writer + /// (peer reads) or reader (peer writes). + fn allocate( + &mut self, + role: EndRole, + local_buf_alloc: u32, + peer_buf_alloc: u32, + ) -> Result<StreamId, StreamTableError> { + if self.next_id > STREAM_ID_MAX { + return Err(StreamTableError::IdSpaceExhausted); + } + let id = StreamId(self.next_id); + self.next_id += 1; + self.entries.push(StreamEntry { + id, + role, + writer: WriterCredit::new(peer_buf_alloc), + reader: ReaderCredit::new(local_buf_alloc), + state: StreamLifecycle::Open, + }); + Ok(id) + } + + /// Register a locally-owned writer end. Typically called by the + /// side that is producing into the stream. + pub fn open_writer( + &mut self, + local_buf_alloc: u32, + peer_buf_alloc: u32, + ) -> Result<StreamId, StreamTableError> { + self.allocate(EndRole::Writer, local_buf_alloc, peer_buf_alloc) + } + + /// Register a locally-owned reader end.
+ pub fn open_reader( + &mut self, + local_buf_alloc: u32, + peer_buf_alloc: u32, + ) -> Result<StreamId, StreamTableError> { + self.allocate(EndRole::Reader, local_buf_alloc, peer_buf_alloc) + } + + fn find_mut(&mut self, id: StreamId) -> Result<&mut StreamEntry, StreamTableError> { + self.entries + .iter_mut() + .find(|e| e.id == id && !e.is_reaped()) + .ok_or(StreamTableError::Unknown(id)) + } + + fn find(&self, id: StreamId) -> Result<&StreamEntry, StreamTableError> { + self.entries + .iter() + .find(|e| e.id == id && !e.is_reaped()) + .ok_or(StreamTableError::Unknown(id)) + } + + pub fn lifecycle(&self, id: StreamId) -> Result<StreamLifecycle, StreamTableError> { + self.find(id).map(|e| e.state) + } + + pub fn writer_credit(&self, id: StreamId) -> Result<WriterCredit, StreamTableError> { + self.find(id).map(|e| e.writer) + } + + pub fn reader_credit(&self, id: StreamId) -> Result<ReaderCredit, StreamTableError> { + self.find(id).map(|e| e.reader) + } + + /// Apply a preamble observed from the peer for this stream. Updates + /// writer-side peer counters (only fields the peer-as-reader + /// advertises are meaningful). + pub fn observe_peer( + &mut self, + id: StreamId, + peer_fwd_cnt: u32, + peer_buf_alloc: u32, + ) -> Result<(), StreamTableError> { + let entry = self.find_mut(id)?; + entry.writer.observe_peer(peer_fwd_cnt, peer_buf_alloc); + Ok(()) + } + + /// Record that the local writer end has put `n` bytes on the wire. + pub fn record_sent(&mut self, id: StreamId, n: u32) -> Result<(), StreamTableError> { + let entry = self.find_mut(id)?; + entry.writer.record_sent(n); + Ok(()) + } + + /// Record that the local reader end has consumed `n` bytes. + pub fn record_consumed(&mut self, id: StreamId, n: u32) -> Result<(), StreamTableError> { + let entry = self.find_mut(id)?; + entry.reader.record_consumed(n); + Ok(()) + } + + /// Mark the writer as closed (sent StreamEnd).
+ pub fn mark_writer_closed(&mut self, id: StreamId) -> Result<(), StreamTableError> { + let entry = self.find_mut(id)?; + entry.state = match entry.state { + StreamLifecycle::Open => StreamLifecycle::WriterClosed, + other => other, + }; + Ok(()) + } + + /// Mark the stream as cancelled. + pub fn mark_cancelled(&mut self, id: StreamId) -> Result<(), StreamTableError> { + let entry = self.find_mut(id)?; + entry.state = StreamLifecycle::Cancelled; + Ok(()) + } + + /// Fully close and reap the entry. + pub fn close(&mut self, id: StreamId) -> Result<(), StreamTableError> { + let entry = self.find_mut(id)?; + entry.state = StreamLifecycle::Closed; + Ok(()) + } + + /// Returns true if a message with the given generation should be + /// accepted for this table. + pub fn accepts_generation(&self, generation: u8) -> bool { + generation == self.generation + } +} + +impl StreamEntry { + fn is_reaped(&self) -> bool { + matches!(self.state, StreamLifecycle::Closed) + } + + #[cfg(test)] + fn role(&self) -> EndRole { + self.role + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn available_credit_basic() { + assert_eq!(available_credit(100, 0, 0), 100); + assert_eq!(available_credit(100, 30, 0), 70); + assert_eq!(available_credit(100, 100, 0), 0); + assert_eq!(available_credit(100, 100, 30), 30); + } + + #[test] + fn available_credit_peer_shrink() { + // peer shrinks buf_alloc below in-flight bytes + assert_eq!(available_credit(10, 50, 30), 0); + } + + #[test] + fn available_credit_wrap() { + // tx_cnt has wrapped past peer_fwd_cnt + let tx = 5u32; + let peer_fwd = u32::MAX - 10; + let in_flight = tx.wrapping_sub(peer_fwd); + assert_eq!(in_flight, 16); + assert_eq!(available_credit(100, tx, peer_fwd), 84); + } + + #[test] + fn wrap_gt_near_zero_and_max() { + assert!(wrap_gt(10, 5)); + assert!(!wrap_gt(5, 10)); + // 1 is "newer" than u32::MAX despite smaller absolute value + assert!(wrap_gt(1, u32::MAX)); + assert!(!wrap_gt(u32::MAX, 1)); + } + + #[test] + fn 
writer_observe_monotonic() { + let mut w = WriterCredit::new(100); + w.observe_peer(10, 200); + assert_eq!(w.peer_fwd_cnt, 10); + assert_eq!(w.peer_buf_alloc, 200); + // Older fwd_cnt is ignored + w.observe_peer(5, 150); + assert_eq!(w.peer_fwd_cnt, 10); + assert_eq!(w.peer_buf_alloc, 150); + } + + #[test] + fn writer_record_sent_and_available() { + let mut w = WriterCredit::new(100); + assert_eq!(w.available(), 100); + w.record_sent(60); + assert_eq!(w.available(), 40); + w.observe_peer(50, 100); + assert_eq!(w.available(), 90); + } + + #[test] + fn table_allocates_and_closes() { + let mut t = StreamTable::new(StreamDirection::Guest2Host); + assert!(t.is_empty()); + let a = t.open_writer(4096, 4096).unwrap(); + let b = t.open_reader(4096, 4096).unwrap(); + assert_ne!(a, b); + assert_eq!(t.len(), 2); + assert_eq!(t.lifecycle(a), Ok(StreamLifecycle::Open)); + t.mark_writer_closed(a).unwrap(); + assert_eq!(t.lifecycle(a), Ok(StreamLifecycle::WriterClosed)); + t.close(a).unwrap(); + assert_eq!(t.lifecycle(a), Err(StreamTableError::Unknown(a))); + assert_eq!(t.len(), 1); + } + + #[test] + fn table_ids_are_not_reused() { + let mut t = StreamTable::new(StreamDirection::Host2Guest); + let a = t.open_writer(4096, 4096).unwrap(); + t.close(a).unwrap(); + let b = t.open_writer(4096, 4096).unwrap(); + assert_ne!(a, b); + } + + #[test] + fn table_reset_bumps_generation_and_drops_entries() { + let mut t = StreamTable::new(StreamDirection::Guest2Host); + let _a = t.open_writer(4096, 4096).unwrap(); + let _b = t.open_reader(4096, 4096).unwrap(); + assert_eq!(t.generation(), 0); + let g = t.reset(); + assert_eq!(g, 1); + assert!(t.is_empty()); + let c = t.open_writer(4096, 4096).unwrap(); + assert_eq!(c, StreamId(0)); + assert!(t.accepts_generation(1)); + assert!(!t.accepts_generation(0)); + } + + #[test] + fn table_reset_wraps_generation() { + let mut t = StreamTable::new(StreamDirection::Guest2Host); + for _ in 0..16 { + t.reset(); + } + assert_eq!(t.generation(), 0); + } + + 
#[test] + fn table_id_space_exhaustion() { + let mut t = StreamTable::new(StreamDirection::Guest2Host); + // Drain to the end of the 12-bit space. + t.next_id = STREAM_ID_MAX; + let id = t.open_writer(4096, 4096).unwrap(); + assert_eq!(id.as_u16(), STREAM_ID_MAX); + let err = t.open_writer(4096, 4096).unwrap_err(); + assert_eq!(err, StreamTableError::IdSpaceExhausted); + } + + #[test] + fn observe_and_counters_flow() { + let mut t = StreamTable::new(StreamDirection::Guest2Host); + let id = t.open_writer(4096, 1024).unwrap(); + t.record_sent(id, 500).unwrap(); + let credit = t.writer_credit(id).unwrap(); + assert_eq!(credit.tx_cnt, 500); + assert_eq!(credit.available(), 524); + + t.observe_peer(id, 200, 2048).unwrap(); + let credit = t.writer_credit(id).unwrap(); + assert_eq!(credit.peer_fwd_cnt, 200); + assert_eq!(credit.peer_buf_alloc, 2048); + assert_eq!(credit.available(), 2048 - (500 - 200)); + } + + #[test] + fn unknown_ids_are_rejected() { + let mut t = StreamTable::new(StreamDirection::Guest2Host); + let fake = StreamId(7); + assert_eq!(t.lifecycle(fake), Err(StreamTableError::Unknown(fake))); + assert_eq!(t.record_sent(fake, 1), Err(StreamTableError::Unknown(fake))); + } + + #[test] + fn role_is_tracked() { + let mut t = StreamTable::new(StreamDirection::Guest2Host); + let a = t.open_writer(4096, 4096).unwrap(); + let b = t.open_reader(4096, 4096).unwrap(); + assert_eq!(t.find(a).unwrap().role(), EndRole::Writer); + assert_eq!(t.find(b).unwrap().role(), EndRole::Reader); + } +} diff --git a/src/hyperlight_guest_bin/bindgen_wrapper.h b/src/hyperlight_guest_bin/bindgen_wrapper.h new file mode 100644 index 000000000..404384514 --- /dev/null +++ b/src/hyperlight_guest_bin/bindgen_wrapper.h @@ -0,0 +1,8 @@ +/* Bindgen wrapper for picolibc types used by hyperlight guest */ + +/* Enable POSIX clock definitions that picolibc guards behind __rtems__ */ +#define _POSIX_MONOTONIC_CLOCK 200112L + +#include +#include +#include diff --git 
a/src/hyperlight_guest_bin/src/lib.rs b/src/hyperlight_guest_bin/src/lib.rs index 450b54930..b1e8030c2 100644 --- a/src/hyperlight_guest_bin/src/lib.rs +++ b/src/hyperlight_guest_bin/src/lib.rs @@ -53,6 +53,7 @@ pub mod host_comm; pub mod memory; #[cfg(target_arch = "x86_64")] pub mod paging; +mod virtq_init; /// Bridge between picolibc's POSIX expectations and the Hyperlight host. /// cbindgen:ignore @@ -256,6 +257,9 @@ pub(crate) extern "C" fn generic_init( OS_PAGE_SIZE = ops as u32; } + // Initialize virtqueues + virtq_init::init_virtqueues(); + // set up the logger let guest_log_level_filter = GuestLogFilter::try_from(max_log_level).expect("Invalid log level"); diff --git a/src/hyperlight_guest_bin/src/virtq_init.rs b/src/hyperlight_guest_bin/src/virtq_init.rs new file mode 100644 index 000000000..1f24f5d9b --- /dev/null +++ b/src/hyperlight_guest_bin/src/virtq_init.rs @@ -0,0 +1,50 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Guest-side virtqueue initialization. +//! +//! The host places virtqueue rings at deterministic offsets in the +//! scratch region and writes ring GVAs and queue depths to scratch-top +//! metadata. + +use hyperlight_common::layout::{ + self, SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET, SCRATCH_TOP_G2H_RING_GVA_OFFSET, + SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET, SCRATCH_TOP_H2G_RING_GVA_OFFSET, +}; +use hyperlight_common::virtq::Layout as VirtqLayout; + +/// Read a value from a scratch-top metadata slot. 
+unsafe fn read_scratch_top(offset: u64) -> T { + let addr = (layout::MAX_GVA as u64 - offset + 1) as *const T; + unsafe { core::ptr::read_volatile(addr) } +} + +/// Initialize virtqueue ring memory in the scratch region. +pub(crate) fn init_virtqueues() { + let g2h_gva: u64 = unsafe { read_scratch_top(SCRATCH_TOP_G2H_RING_GVA_OFFSET) }; + let g2h_depth: u16 = unsafe { read_scratch_top(SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET) }; + let h2g_gva: u64 = unsafe { read_scratch_top(SCRATCH_TOP_H2G_RING_GVA_OFFSET) }; + let h2g_depth: u16 = unsafe { read_scratch_top(SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET) }; + + assert!(g2h_depth > 0 && h2g_depth > 0); + assert!(g2h_gva != 0 && h2g_gva != 0); + + let size = VirtqLayout::query_size(g2h_depth as usize); + unsafe { core::ptr::write_bytes(g2h_gva as *mut u8, 0, size) }; + + let size = VirtqLayout::query_size(h2g_depth as usize); + unsafe { core::ptr::write_bytes(h2g_gva as *mut u8, 0, size) }; +} diff --git a/src/hyperlight_host/Cargo.toml b/src/hyperlight_host/Cargo.toml index f65780504..985642f9f 100644 --- a/src/hyperlight_host/Cargo.toml +++ b/src/hyperlight_host/Cargo.toml @@ -21,6 +21,7 @@ bench = false # see https://bheisler.github.io/criterion.rs/book/faq.html#cargo- workspace = true [dependencies] +bytemuck = { version = "1.25", features = ["derive"] } gdbstub = { version = "0.7.10", optional = true } gdbstub_arch = { version = "0.3.3", optional = true } goblin = { version = "0.10", default-features = false, features = ["std", "elf32", "elf64", "endian_fd"] } diff --git a/src/hyperlight_host/src/mem/layout.rs b/src/hyperlight_host/src/mem/layout.rs index 26615d579..183cdc2af 100644 --- a/src/hyperlight_host/src/mem/layout.rs +++ b/src/hyperlight_host/src/mem/layout.rs @@ -340,6 +340,8 @@ impl SandboxMemoryLayout { let min_scratch_size = hyperlight_common::layout::min_scratch_size( cfg.get_input_data_size(), cfg.get_output_data_size(), + cfg.get_g2h_queue_depth(), + cfg.get_h2g_queue_depth(), ); if scratch_size < 
min_scratch_size { return Err(MemoryRequestTooSmall(scratch_size, min_scratch_size)); @@ -483,13 +485,55 @@ impl SandboxMemoryLayout { 0 } + /// Get the offset into the scratch region of the G2H ring. + fn get_g2h_ring_scratch_offset(&self) -> usize { + hyperlight_common::layout::g2h_ring_scratch_offset( + self.sandbox_memory_config.get_input_data_size(), + self.sandbox_memory_config.get_output_data_size(), + ) + } + + /// Get the size of the G2H ring in bytes. + fn get_g2h_ring_size(&self) -> usize { + hyperlight_common::virtq::Layout::query_size( + self.sandbox_memory_config.get_g2h_queue_depth(), + ) + } + + /// Get the offset into the scratch region of the H2G ring. + fn get_h2g_ring_scratch_offset(&self) -> usize { + hyperlight_common::layout::h2g_ring_scratch_offset( + self.sandbox_memory_config.get_input_data_size(), + self.sandbox_memory_config.get_output_data_size(), + self.sandbox_memory_config.get_g2h_queue_depth(), + ) + } + + /// Get the size of the H2G ring in bytes. + fn get_h2g_ring_size(&self) -> usize { + hyperlight_common::virtq::Layout::query_size( + self.sandbox_memory_config.get_h2g_queue_depth(), + ) + } + + /// Get the GVA of the G2H ring in guest address space. + pub(crate) fn get_g2h_ring_gva(&self) -> u64 { + hyperlight_common::layout::scratch_base_gva(self.scratch_size) + + self.get_g2h_ring_scratch_offset() as u64 + } + + /// Get the GVA of the H2G ring in guest address space. 
+ pub(crate) fn get_h2g_ring_gva(&self) -> u64 { + hyperlight_common::layout::scratch_base_gva(self.scratch_size) + + self.get_h2g_ring_scratch_offset() as u64 + } + /// Get the offset from the beginning of the scratch region to the /// location where page tables will be eagerly copied on restore #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub(crate) fn get_pt_base_scratch_offset(&self) -> usize { - (self.sandbox_memory_config.get_input_data_size() - + self.sandbox_memory_config.get_output_data_size()) - .next_multiple_of(hyperlight_common::vmem::PAGE_SIZE) + let after_rings = self.get_h2g_ring_scratch_offset() + self.get_h2g_ring_size(); + after_rings.next_multiple_of(hyperlight_common::vmem::PAGE_SIZE) } /// Get the base GPA to which the page tables will be eagerly @@ -594,6 +638,8 @@ impl SandboxMemoryLayout { let min_fixed_scratch = hyperlight_common::layout::min_scratch_size( self.sandbox_memory_config.get_input_data_size(), self.sandbox_memory_config.get_output_data_size(), + self.sandbox_memory_config.get_g2h_queue_depth(), + self.sandbox_memory_config.get_h2g_queue_depth(), ); let min_scratch = min_fixed_scratch + size; if self.scratch_size < min_scratch { @@ -816,6 +862,10 @@ impl SandboxMemoryLayout { // initialised here, because they are in the scratch // region---they are instead set in // [`SandboxMemoryManager::update_scratch_bookkeeping`]. + // + // Virtqueue ring layouts are also communicated via scratch-top + // metadata (queue depths), not the PEB. Both host and guest + // compute ring addresses from shared offset functions. Ok(()) } diff --git a/src/hyperlight_host/src/mem/mgr.rs b/src/hyperlight_host/src/mem/mgr.rs index 68f35ff7d..38a283614 100644 --- a/src/hyperlight_host/src/mem/mgr.rs +++ b/src/hyperlight_host/src/mem/mgr.rs @@ -15,6 +15,7 @@ limitations under the License. 
*/ #[cfg(feature = "nanvix-unstable")] use std::mem::offset_of; +use std::num::NonZeroU16; use flatbuffers::FlatBufferBuilder; use hyperlight_common::flatbuffer_wrappers::function_call::{ @@ -22,7 +23,8 @@ use hyperlight_common::flatbuffer_wrappers::function_call::{ }; use hyperlight_common::flatbuffer_wrappers::function_types::FunctionCallResult; use hyperlight_common::flatbuffer_wrappers::guest_log_data::GuestLogData; -use hyperlight_common::vmem::{self, PAGE_TABLE_SIZE}; +use hyperlight_common::virtq::Layout as VirtqLayout; +use hyperlight_common::vmem::{self, PAGE_TABLE_SIZE, PageTableEntry, PhysAddr}; #[cfg(all(feature = "crashdump", not(feature = "i686-guest")))] use hyperlight_common::vmem::{BasicMapping, MappingKind}; use tracing::{Span, instrument}; @@ -612,6 +614,25 @@ impl SandboxMemoryManager { SandboxMemoryLayout::STACK_POINTER_SIZE_BYTES, )?; + // Write virtqueue metadata to scratch-top so the guest can + // discover ring locations without reading the PEB. + self.update_scratch_bookkeeping_item( + SCRATCH_TOP_G2H_RING_GVA_OFFSET, + self.layout.get_g2h_ring_gva(), + )?; + self.update_scratch_bookkeeping_item( + SCRATCH_TOP_H2G_RING_GVA_OFFSET, + self.layout.get_h2g_ring_gva(), + )?; + self.scratch_mem.write::( + scratch_size - SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET as usize, + self.layout.sandbox_memory_config.get_g2h_queue_depth() as u16, + )?; + self.scratch_mem.write::( + scratch_size - SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET as usize, + self.layout.sandbox_memory_config.get_h2g_queue_depth() as u16, + )?; + // Copy page tables from `shared_mem` into scratch. PT bytes // are appended to the snapshot blob at build time and live // just past the end of the guest-visible KVM slot (see @@ -856,6 +877,30 @@ impl SandboxMemoryManager { }) })?? } + + /// Compute the G2H virtqueue Layout from scratch region addresses. 
+ pub(crate) fn g2h_virtq_layout(&self) -> Result { + let base = self.layout.get_g2h_ring_gva(); + let depth = self.layout.sandbox_memory_config.get_g2h_queue_depth(); + + let nz = NonZeroU16::new(depth as u16) + .ok_or_else(|| new_error!("G2H queue depth is zero"))?; + + unsafe { VirtqLayout::from_base(base, nz) } + .map_err(|e| new_error!("Invalid G2H virtq layout: {:?}", e)) + } + + /// Compute the H2G virtqueue Layout from scratch region addresses. + pub(crate) fn h2g_virtq_layout(&self) -> Result { + let base = self.layout.get_h2g_ring_gva(); + let depth = self.layout.sandbox_memory_config.get_h2g_queue_depth(); + + let nz = NonZeroU16::new(depth as u16) + .ok_or_else(|| new_error!("H2G queue depth is zero"))?; + + unsafe { VirtqLayout::from_base(base, nz) } + .map_err(|e| new_error!("Invalid H2G virtq layout: {:?}", e)) + } } #[cfg(test)] diff --git a/src/hyperlight_host/src/mem/shared_mem.rs b/src/hyperlight_host/src/mem/shared_mem.rs index 5f975f605..db1b407c7 100644 --- a/src/hyperlight_host/src/mem/shared_mem.rs +++ b/src/hyperlight_host/src/mem/shared_mem.rs @@ -878,57 +878,25 @@ impl SharedMemory for GuestSharedMemory { } } -/// An unsafe marker trait for types for which all bit patterns are valid. -/// This is required in order for it to be safe to read a value of a particular -/// type out of the sandbox from the HostSharedMemory. -/// -/// # Safety -/// This must only be implemented for types for which all bit patterns -/// are valid. It requires that any (non-undef/poison) value of the -/// correct size can be transmuted to the type. 
-pub unsafe trait AllValid {} -unsafe impl AllValid for u8 {} -unsafe impl AllValid for u16 {} -unsafe impl AllValid for u32 {} -unsafe impl AllValid for u64 {} -unsafe impl AllValid for i8 {} -unsafe impl AllValid for i16 {} -unsafe impl AllValid for i32 {} -unsafe impl AllValid for i64 {} -unsafe impl AllValid for [u8; 16] {} - impl HostSharedMemory { - /// Read a value of type T, whose representation is the same - /// between the sandbox and the host, and which has no invalid bit - /// patterns - pub fn read(&self, offset: usize) -> Result { + /// Read a value of type T from the sandbox at the given offset. + /// + /// T must implement [`bytemuck::Pod`] which guarantees all bit + /// patterns are valid and there is no padding. + pub fn read(&self, offset: usize) -> Result { bounds_check!(offset, std::mem::size_of::(), self.mem_size()); - unsafe { - let mut ret: core::mem::MaybeUninit = core::mem::MaybeUninit::uninit(); - { - let slice: &mut [u8] = core::slice::from_raw_parts_mut( - ret.as_mut_ptr() as *mut u8, - std::mem::size_of::(), - ); - self.copy_to_slice(slice, offset)?; - } - Ok(ret.assume_init()) - } + let mut val = T::zeroed(); + self.copy_to_slice(bytemuck::bytes_of_mut(&mut val), offset)?; + Ok(val) } - /// Write a value of type T, whose representation is the same - /// between the sandbox and the host, and which has no invalid bit - /// patterns - pub fn write(&self, offset: usize, data: T) -> Result<()> { + /// Write a value of type T into the sandbox at the given offset. + /// + /// T must implement [`bytemuck::Pod`] which guarantees all bit + /// patterns are valid and there is no padding. 
+ pub fn write(&self, offset: usize, data: T) -> Result<()> { bounds_check!(offset, std::mem::size_of::(), self.mem_size()); - unsafe { - let slice: &[u8] = core::slice::from_raw_parts( - core::ptr::addr_of!(data) as *const u8, - std::mem::size_of::(), - ); - self.copy_from_slice(slice, offset)?; - } - Ok(()) + self.copy_from_slice(bytemuck::bytes_of(&data), offset) } /// Copy the contents of the slice into the sandbox at the diff --git a/src/hyperlight_host/src/sandbox/config.rs b/src/hyperlight_host/src/sandbox/config.rs index f12387a0b..120aa06cd 100644 --- a/src/hyperlight_host/src/sandbox/config.rs +++ b/src/hyperlight_host/src/sandbox/config.rs @@ -74,6 +74,12 @@ pub struct SandboxConfiguration { interrupt_vcpu_sigrtmin_offset: u8, /// How much writable memory to offer the guest scratch_size: usize, + /// Number of descriptors for the G2H (guest-to-host) virtqueue. Must be a power of 2. + /// Default: 64 sized to 2x H2G depth for deadlock prevention. + g2h_queue_depth: usize, + /// Number of descriptors for the host-to-guest virtqueue. Must be a power of 2. + /// Default: 32 + h2g_queue_depth: usize, } impl SandboxConfiguration { @@ -93,6 +99,10 @@ impl SandboxConfiguration { pub const DEFAULT_HEAP_SIZE: u64 = 131072; /// The default size of the scratch region pub const DEFAULT_SCRATCH_SIZE: usize = 0x48000; + /// The default G2H virtqueue depth (number of descriptors, must be power of 2) + pub const DEFAULT_G2H_QUEUE_DEPTH: usize = 64; + /// The default H2G virtqueue depth (number of descriptors, must be power of 2) + pub const DEFAULT_H2G_QUEUE_DEPTH: usize = 32; #[allow(clippy::too_many_arguments)] /// Create a new configuration for a sandbox with the given sizes. 
@@ -114,6 +124,8 @@ impl SandboxConfiguration { scratch_size, interrupt_retry_delay, interrupt_vcpu_sigrtmin_offset, + g2h_queue_depth: Self::DEFAULT_G2H_QUEUE_DEPTH, + h2g_queue_depth: Self::DEFAULT_H2G_QUEUE_DEPTH, #[cfg(gdb)] guest_debug_info, #[cfg(crashdump)] @@ -209,6 +221,16 @@ impl SandboxConfiguration { self.scratch_size } + /// Get the G2H virtqueue depth (number of descriptors). + pub fn get_g2h_queue_depth(&self) -> usize { + self.g2h_queue_depth + } + + /// Get the H2G virtqueue depth (number of descriptors). + pub fn get_h2g_queue_depth(&self) -> usize { + self.h2g_queue_depth + } + /// Set the size of the scratch regiong #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub fn set_scratch_size(&mut self, scratch_size: usize) { diff --git a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs index 241622cab..8b3cf8db2 100644 --- a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs +++ b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs @@ -1144,6 +1144,8 @@ mod tests { let min_scratch = hyperlight_common::layout::min_scratch_size( cfg.get_input_data_size(), cfg.get_output_data_size(), + cfg.get_g2h_queue_depth(), + cfg.get_h2g_queue_depth(), ); cfg.set_scratch_size(min_scratch + 0x10000 + 0x10000); diff --git a/src/hyperlight_host/src/sandbox/outb.rs b/src/hyperlight_host/src/sandbox/outb.rs index 9704a1fe3..bb73763a6 100644 --- a/src/hyperlight_host/src/sandbox/outb.rs +++ b/src/hyperlight_host/src/sandbox/outb.rs @@ -227,6 +227,10 @@ pub(crate) fn handle_outb( eprint!("{}", ch); Ok(()) } + OutBAction::VirtqNotify => { + // TODO(ring): acknowledge notification but no-op for now. 
+ Ok(()) + } #[cfg(feature = "trace_guest")] OutBAction::TraceBatch => Ok(()), #[cfg(feature = "mem_profile")] diff --git a/src/hyperlight_libc/third_party/mimalloc b/src/hyperlight_libc/third_party/mimalloc new file mode 160000 index 000000000..09a27098a --- /dev/null +++ b/src/hyperlight_libc/third_party/mimalloc @@ -0,0 +1 @@ +Subproject commit 09a27098aa6e9286518bd9c74e6ffa7199c3f04e diff --git a/src/tests/rust_guests/dummyguest/Cargo.lock b/src/tests/rust_guests/dummyguest/Cargo.lock index 9cb6885bb..f2085335f 100644 --- a/src/tests/rust_guests/dummyguest/Cargo.lock +++ b/src/tests/rust_guests/dummyguest/Cargo.lock @@ -17,6 +17,12 @@ version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" +[[package]] +name = "atomic_refcell" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21e4227379beff4205943696e6c3e0cd809bacdf3f0edd6e3dd153e2269571a4" + [[package]] name = "bindgen" version = "0.71.1" @@ -72,11 +78,17 @@ dependencies = [ "syn", ] +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + [[package]] name = "cc" -version = "1.2.60" +version = "1.2.62" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43c5703da9466b66a946814e1adf53ea2c90f10063b86290cc9eb67ce3478a20" +checksum = "a1dce859f0832a7d088c4f1119888ab94ef4b5d6795d1ce05afb7fe159d79f98" dependencies = [ "find-msvc-tools", "shlex", @@ -134,6 +146,12 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" +[[package]] +name = "fixedbitset" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" + [[package]] name = "flatbuffers" version = "25.12.19" @@ -152,17 +170,20 @@ checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" [[package]] name = "hashbrown" -version = "0.17.0" +version = "0.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f467dd6dccf739c208452f8014c75c18bb8301b050ad1cfb27153803edb0f51" +checksum = "ed5909b6e89a2db4456e54cd5f673791d7eca6732202bbf2a9cc504fe2f9b84a" [[package]] name = "hyperlight-common" version = "0.15.0" dependencies = [ "anyhow", + "atomic_refcell", "bitflags", "bytemuck", + "bytes", + "fixedbitset", "flatbuffers", "log", "smallvec", @@ -593,9 +614,9 @@ checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" [[package]] name = "winnow" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09dac053f1cd375980747450bfc7250c264eaae0583872e845c0c7cd578872b5" +checksum = "2ee1708bef14716a11bae175f579062d4554d95be2c6829f518df847b7b3fdd0" dependencies = [ "memchr", ] diff --git a/src/tests/rust_guests/simpleguest/Cargo.lock b/src/tests/rust_guests/simpleguest/Cargo.lock index 23475a643..455139d0d 100644 --- a/src/tests/rust_guests/simpleguest/Cargo.lock +++ b/src/tests/rust_guests/simpleguest/Cargo.lock @@ -17,6 +17,12 @@ version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" +[[package]] +name = "atomic_refcell" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21e4227379beff4205943696e6c3e0cd809bacdf3f0edd6e3dd153e2269571a4" + [[package]] name = "bindgen" version = "0.71.1" @@ -72,11 +78,17 @@ dependencies = [ "syn", ] +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + [[package]] name = "cc" -version = "1.2.60" +version = "1.2.62" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43c5703da9466b66a946814e1adf53ea2c90f10063b86290cc9eb67ce3478a20" +checksum = "a1dce859f0832a7d088c4f1119888ab94ef4b5d6795d1ce05afb7fe159d79f98" dependencies = [ "find-msvc-tools", "shlex", @@ -126,6 +138,12 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" +[[package]] +name = "fixedbitset" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" + [[package]] name = "flatbuffers" version = "25.12.19" @@ -144,17 +162,20 @@ checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" [[package]] name = "hashbrown" -version = "0.17.0" +version = "0.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f467dd6dccf739c208452f8014c75c18bb8301b050ad1cfb27153803edb0f51" +checksum = "ed5909b6e89a2db4456e54cd5f673791d7eca6732202bbf2a9cc504fe2f9b84a" [[package]] name = "hyperlight-common" version = "0.15.0" dependencies = [ "anyhow", + "atomic_refcell", "bitflags", "bytemuck", + "bytes", + "fixedbitset", "flatbuffers", "log", "smallvec", @@ -598,9 +619,9 @@ checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" [[package]] name = "winnow" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09dac053f1cd375980747450bfc7250c264eaae0583872e845c0c7cd578872b5" +checksum = "2ee1708bef14716a11bae175f579062d4554d95be2c6829f518df847b7b3fdd0" dependencies = [ "memchr", ] diff --git a/src/tests/rust_guests/witguest/Cargo.lock b/src/tests/rust_guests/witguest/Cargo.lock index f9a6ffa6c..70f41063d 100644 --- 
a/src/tests/rust_guests/witguest/Cargo.lock +++ b/src/tests/rust_guests/witguest/Cargo.lock @@ -67,6 +67,12 @@ version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" +[[package]] +name = "atomic_refcell" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21e4227379beff4205943696e6c3e0cd809bacdf3f0edd6e3dd153e2269571a4" + [[package]] name = "bindgen" version = "0.71.1" @@ -122,11 +128,17 @@ dependencies = [ "syn", ] +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + [[package]] name = "cc" -version = "1.2.61" +version = "1.2.62" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d16d90359e986641506914ba71350897565610e87ce0ad9e6f28569db3dd5c6d" +checksum = "a1dce859f0832a7d088c4f1119888ab94ef4b5d6795d1ce05afb7fe159d79f98" dependencies = [ "find-msvc-tools", "shlex", @@ -205,6 +217,12 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" +[[package]] +name = "fixedbitset" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" + [[package]] name = "flatbuffers" version = "25.12.19" @@ -229,9 +247,9 @@ checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" [[package]] name = "hashbrown" -version = "0.17.0" +version = "0.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f467dd6dccf739c208452f8014c75c18bb8301b050ad1cfb27153803edb0f51" +checksum = "ed5909b6e89a2db4456e54cd5f673791d7eca6732202bbf2a9cc504fe2f9b84a" dependencies = [ "foldhash", "serde", @@ -243,8 +261,11 @@ name = 
"hyperlight-common" version = "0.15.0" dependencies = [ "anyhow", + "atomic_refcell", "bitflags", "bytemuck", + "bytes", + "fixedbitset", "flatbuffers", "log", "smallvec", From 4f8e4449c9ee6ded073cb408a88ddbf8f2dea03c Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Wed, 25 Mar 2026 15:05:52 +0100 Subject: [PATCH 03/31] feat(virtq): add MemOps for host and guest Signed-off-by: Tomasz Andrzejak --- src/hyperlight_guest/src/lib.rs | 1 + src/hyperlight_guest/src/virtq_mem.rs | 67 ++++++++++++ src/hyperlight_host/src/mem/mgr.rs | 10 +- src/hyperlight_host/src/mem/mod.rs | 2 + src/hyperlight_host/src/mem/virtq_mem.rs | 124 +++++++++++++++++++++++ 5 files changed, 198 insertions(+), 6 deletions(-) create mode 100644 src/hyperlight_guest/src/virtq_mem.rs create mode 100644 src/hyperlight_host/src/mem/virtq_mem.rs diff --git a/src/hyperlight_guest/src/lib.rs b/src/hyperlight_guest/src/lib.rs index 19e5ac5f2..1aa456b31 100644 --- a/src/hyperlight_guest/src/lib.rs +++ b/src/hyperlight_guest/src/lib.rs @@ -26,6 +26,7 @@ pub mod exit; pub mod layout; pub mod prim_alloc; pub mod types; +pub mod virtq_mem; pub mod guest_handle { pub mod handle; diff --git a/src/hyperlight_guest/src/virtq_mem.rs b/src/hyperlight_guest/src/virtq_mem.rs new file mode 100644 index 000000000..8309deb79 --- /dev/null +++ b/src/hyperlight_guest/src/virtq_mem.rs @@ -0,0 +1,67 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! 
Guest-side [`MemOps`] implementation for virtqueue access. + +use core::convert::Infallible; +use core::sync::atomic::{AtomicU16, Ordering}; +use core::{ptr, slice}; + +use hyperlight_common::virtq::MemOps; + +/// Guest-side memory accessor for virtqueue operations. Treats virtq +/// addresses as guest virtual addresses that map directly to memory. +#[derive(Clone, Copy, Debug)] +pub struct GuestMemOps; + +impl MemOps for GuestMemOps { + type Error = Infallible; + + fn read(&self, addr: u64, dst: &mut [u8]) -> Result { + let src = addr as *const u8; + unsafe { + ptr::copy_nonoverlapping(src, dst.as_mut_ptr(), dst.len()); + } + Ok(dst.len()) + } + + fn write(&self, addr: u64, src: &[u8]) -> Result { + let dst = addr as *mut u8; + unsafe { + ptr::copy_nonoverlapping(src.as_ptr(), dst, src.len()); + } + Ok(src.len()) + } + + fn load_acquire(&self, addr: u64) -> Result { + let ptr = addr as *const AtomicU16; + Ok(unsafe { (*ptr).load(Ordering::Acquire) }) + } + + fn store_release(&self, addr: u64, val: u16) -> Result<(), Self::Error> { + let ptr = addr as *const AtomicU16; + unsafe { (*ptr).store(val, Ordering::Release) }; + Ok(()) + } + + unsafe fn as_slice(&self, addr: u64, len: usize) -> Result<&[u8], Self::Error> { + Ok(unsafe { slice::from_raw_parts(addr as *const u8, len) }) + } + + unsafe fn as_mut_slice(&self, addr: u64, len: usize) -> Result<&mut [u8], Self::Error> { + Ok(unsafe { slice::from_raw_parts_mut(addr as *mut u8, len) }) + } +} diff --git a/src/hyperlight_host/src/mem/mgr.rs b/src/hyperlight_host/src/mem/mgr.rs index 38a283614..7356e12a0 100644 --- a/src/hyperlight_host/src/mem/mgr.rs +++ b/src/hyperlight_host/src/mem/mgr.rs @@ -881,10 +881,9 @@ impl SandboxMemoryManager { /// Compute the G2H virtqueue Layout from scratch region addresses. 
pub(crate) fn g2h_virtq_layout(&self) -> Result { let base = self.layout.get_g2h_ring_gva(); - let depth = self.layout.sandbox_memory_config.get_g2h_queue_depth(); + let depth = self.layout.sandbox_memory_config.get_g2h_queue_depth() as u16; - let nz = NonZeroU16::new(depth as u16) - .ok_or_else(|| new_error!("G2H queue depth is zero"))?; + let nz = NonZeroU16::new(depth).ok_or_else(|| new_error!("G2H queue depth is zero"))?; unsafe { VirtqLayout::from_base(base, nz) } .map_err(|e| new_error!("Invalid G2H virtq layout: {:?}", e)) @@ -893,10 +892,9 @@ impl SandboxMemoryManager { /// Compute the H2G virtqueue Layout from scratch region addresses. pub(crate) fn h2g_virtq_layout(&self) -> Result { let base = self.layout.get_h2g_ring_gva(); - let depth = self.layout.sandbox_memory_config.get_h2g_queue_depth(); + let depth = self.layout.sandbox_memory_config.get_h2g_queue_depth() as u16; - let nz = NonZeroU16::new(depth as u16) - .ok_or_else(|| new_error!("H2G queue depth is zero"))?; + let nz = NonZeroU16::new(depth).ok_or_else(|| new_error!("H2G queue depth is zero"))?; unsafe { VirtqLayout::from_base(base, nz) } .map_err(|e| new_error!("Invalid H2G virtq layout: {:?}", e)) diff --git a/src/hyperlight_host/src/mem/mod.rs b/src/hyperlight_host/src/mem/mod.rs index 64f5db2fe..4882bc75c 100644 --- a/src/hyperlight_host/src/mem/mod.rs +++ b/src/hyperlight_host/src/mem/mod.rs @@ -38,3 +38,5 @@ pub mod shared_mem; /// Utilities for writing shared memory tests #[cfg(all(test, not(miri)))] // uses proptest which isn't miri-compatible pub(crate) mod shared_mem_tests; +/// Host-side [`hyperlight_common::virtq::MemOps`] for virtqueue access. +pub(crate) mod virtq_mem; diff --git a/src/hyperlight_host/src/mem/virtq_mem.rs b/src/hyperlight_host/src/mem/virtq_mem.rs new file mode 100644 index 000000000..f96674c1d --- /dev/null +++ b/src/hyperlight_host/src/mem/virtq_mem.rs @@ -0,0 +1,124 @@ +/* +Copyright 2026 The Hyperlight Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Host-side [`MemOps`] implementation for virtqueue access. +//! +//! Translates guest virtual addresses used in virtqueue descriptors +//! to offsets into the scratch [`HostSharedMemory`], reusing its +//! volatile access and locking patterns. + +use core::sync::atomic::{AtomicU16, Ordering}; + +use hyperlight_common::virtq::MemOps; + +use super::shared_mem::{HostSharedMemory, SharedMemory}; + +/// Error type for host memory operations. +#[derive(Debug, thiserror::Error)] +pub enum HostMemError { + #[error("address {addr:#x} out of bounds scratch_size={scratch_size}")] + OutOfBounds { addr: u64, scratch_size: usize }, + #[error("shared memory error: {0}")] + SharedMem(String), + #[error("as_slice/as_mut_slice not supported on host")] + DirectSliceNotSupported, +} + +/// Host-side memory accessor for virtqueue operations. +/// +/// Owns a clone of the scratch [`HostSharedMemory`] and translates +/// guest virtual addresses (in the scratch region) to offsets for the +/// existing volatile read/write methods. +#[derive(Clone)] +pub(crate) struct HostMemOps { + /// Cloned handle to the scratch shared memory + scratch: HostSharedMemory, + /// The guest virtual address that corresponds to scratch offset 0. + scratch_base_gva: u64, +} + +impl HostMemOps { + /// Create a new `HostMemOps` backed by shared memory. 
+ pub(crate) fn new(scratch: &HostSharedMemory, scratch_base_gva: u64) -> Self { + Self { + scratch: scratch.clone(), + scratch_base_gva, + } + } + + /// Translate a guest virtual address to a scratch offset. + fn to_offset(&self, addr: u64) -> Result { + addr.checked_sub(self.scratch_base_gva) + .map(|o| o as usize) + .ok_or(HostMemError::OutOfBounds { + addr, + scratch_size: self.scratch.mem_size(), + }) + } + + /// Get a raw pointer into scratch memory at the given guest address. + fn raw_ptr(&self, addr: u64, len: usize) -> Result<*mut u8, HostMemError> { + let offset = self.to_offset(addr)?; + let scratch_size = self.scratch.mem_size(); + + if offset.checked_add(len).is_none_or(|end| end > scratch_size) { + return Err(HostMemError::OutOfBounds { addr, scratch_size }); + } + + Ok(self.scratch.base_ptr().wrapping_add(offset)) + } +} + +impl MemOps for HostMemOps { + type Error = HostMemError; + + fn read(&self, addr: u64, dst: &mut [u8]) -> Result { + let offset = self.to_offset(addr)?; + self.scratch + .copy_to_slice(dst, offset) + .map_err(|e| HostMemError::SharedMem(e.to_string()))?; + Ok(dst.len()) + } + + fn write(&self, addr: u64, src: &[u8]) -> Result { + let offset = self.to_offset(addr)?; + self.scratch + .copy_from_slice(src, offset) + .map_err(|e| HostMemError::SharedMem(e.to_string()))?; + Ok(src.len()) + } + + fn load_acquire(&self, addr: u64) -> Result { + let ptr = self.raw_ptr(addr, core::mem::size_of::())?; + let atomic = unsafe { &*(ptr as *const AtomicU16) }; + Ok(atomic.load(Ordering::Acquire)) + } + + fn store_release(&self, addr: u64, val: u16) -> Result<(), Self::Error> { + let ptr = self.raw_ptr(addr, core::mem::size_of::())?; + let atomic = unsafe { &*(ptr as *const AtomicU16) }; + atomic.store(val, Ordering::Release); + Ok(()) + } + + unsafe fn as_slice(&self, _addr: u64, _len: usize) -> Result<&[u8], Self::Error> { + Err(HostMemError::DirectSliceNotSupported) + } + + unsafe fn as_mut_slice(&self, _addr: u64, _len: usize) -> Result<&mut 
[u8], Self::Error> { + Err(HostMemError::DirectSliceNotSupported) + } +} From 4ebb11012f4586a90f4a2b371d609d3e583f44fc Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Wed, 25 Mar 2026 17:11:20 +0100 Subject: [PATCH 04/31] feat(virtq): create G2H producer during guest init Signed-off-by: Tomasz Andrzejak --- src/hyperlight_common/src/layout.rs | 11 ++- src/hyperlight_common/src/virtq/pool.rs | 24 ++++++ src/hyperlight_guest_bin/src/lib.rs | 4 +- src/hyperlight_guest_bin/src/virtq/mod.rs | 68 +++++++++++++++++ src/hyperlight_guest_bin/src/virtq/state.rs | 75 +++++++++++++++++++ src/hyperlight_guest_bin/src/virtq_init.rs | 50 ------------- src/hyperlight_host/src/mem/mgr.rs | 4 + src/hyperlight_host/src/sandbox/config.rs | 26 +++++++ src/hyperlight_host/tests/integration_test.rs | 10 ++- 9 files changed, 216 insertions(+), 56 deletions(-) create mode 100644 src/hyperlight_guest_bin/src/virtq/mod.rs create mode 100644 src/hyperlight_guest_bin/src/virtq/state.rs delete mode 100644 src/hyperlight_guest_bin/src/virtq_init.rs diff --git a/src/hyperlight_common/src/layout.rs b/src/hyperlight_common/src/layout.rs index 234ad6f78..70622343b 100644 --- a/src/hyperlight_common/src/layout.rs +++ b/src/hyperlight_common/src/layout.rs @@ -41,9 +41,9 @@ pub const SCRATCH_TOP_G2H_RING_GVA_OFFSET: u64 = 0x28; pub const SCRATCH_TOP_H2G_RING_GVA_OFFSET: u64 = 0x30; pub const SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET: u64 = 0x38; pub const SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET: u64 = 0x3a; +pub const SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET: u64 = 0x3c; pub const SCRATCH_TOP_EXN_STACK_OFFSET: u64 = 0x40; -// fields must not overlap, and exception stack address must be 16-byte aligned. 
const _: () = { assert!(SCRATCH_TOP_SIZE_OFFSET + 8 <= SCRATCH_TOP_ALLOCATOR_OFFSET); assert!(SCRATCH_TOP_ALLOCATOR_OFFSET + 8 <= SCRATCH_TOP_SNAPSHOT_PT_GPA_BASE_OFFSET); @@ -52,7 +52,8 @@ const _: () = { assert!(SCRATCH_TOP_G2H_RING_GVA_OFFSET + 8 <= SCRATCH_TOP_H2G_RING_GVA_OFFSET); assert!(SCRATCH_TOP_H2G_RING_GVA_OFFSET + 8 <= SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET); assert!(SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET + 2 <= SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET); - assert!(SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET + 2 <= SCRATCH_TOP_EXN_STACK_OFFSET); + assert!(SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET + 2 <= SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET); + assert!(SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET + 2 <= SCRATCH_TOP_EXN_STACK_OFFSET); assert!(SCRATCH_TOP_EXN_STACK_OFFSET % 0x10 == 0); }; @@ -72,9 +73,13 @@ pub fn scratch_base_gva(size: usize) -> u64 { (MAX_GVA - size + 1) as u64 } +pub const fn scratch_top_ptr(offset: u64) -> *mut T { + (MAX_GVA as u64 - offset + 1) as *mut T +} + /// Compute the byte offset from the scratch base to the G2H ring. 
/// -/// TODO(ring): Remove input/output +/// TODO(virtq): Remove input/output pub const fn g2h_ring_scratch_offset(input_data_size: usize, output_data_size: usize) -> usize { let io_off = input_data_size + output_data_size; let align = crate::virtq::Descriptor::ALIGN; diff --git a/src/hyperlight_common/src/virtq/pool.rs b/src/hyperlight_common/src/virtq/pool.rs index 0324c08fe..99d8f4109 100644 --- a/src/hyperlight_common/src/virtq/pool.rs +++ b/src/hyperlight_common/src/virtq/pool.rs @@ -126,6 +126,30 @@ pub trait BufferProvider { fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result; } +impl BufferProvider for alloc::rc::Rc { + fn alloc(&self, len: usize) -> Result { + (**self).alloc(len) + } + fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError> { + (**self).dealloc(alloc) + } + fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result { + (**self).resize(old_alloc, new_len) + } +} + +impl BufferProvider for alloc::sync::Arc { + fn alloc(&self, len: usize) -> Result { + (**self).alloc(len) + } + fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError> { + (**self).dealloc(alloc) + } + fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result { + (**self).resize(old_alloc, new_len) + } +} + /// The owner of a mapped buffer, ensuring its lifetime. /// /// Holds a pool allocation and provides direct access to the underlying diff --git a/src/hyperlight_guest_bin/src/lib.rs b/src/hyperlight_guest_bin/src/lib.rs index b1e8030c2..88b449267 100644 --- a/src/hyperlight_guest_bin/src/lib.rs +++ b/src/hyperlight_guest_bin/src/lib.rs @@ -53,7 +53,7 @@ pub mod host_comm; pub mod memory; #[cfg(target_arch = "x86_64")] pub mod paging; -mod virtq_init; +mod virtq; /// Bridge between picolibc's POSIX expectations and the Hyperlight host. 
/// cbindgen:ignore @@ -258,7 +258,7 @@ pub(crate) extern "C" fn generic_init( } // Initialize virtqueues - virtq_init::init_virtqueues(); + virtq::init_virtqueues(); // set up the logger let guest_log_level_filter = diff --git a/src/hyperlight_guest_bin/src/virtq/mod.rs b/src/hyperlight_guest_bin/src/virtq/mod.rs new file mode 100644 index 000000000..50c0dd6d9 --- /dev/null +++ b/src/hyperlight_guest_bin/src/virtq/mod.rs @@ -0,0 +1,68 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Guest-side virtqueue initialization. +//! +//! Zeroes ring memory and creates VirtqProducer instances by allocating +//! buffer pool pages from the scratch page allocator. + +pub(crate) mod state; + +use hyperlight_common::layout::{ + SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET, SCRATCH_TOP_G2H_RING_GVA_OFFSET, + SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET, SCRATCH_TOP_H2G_RING_GVA_OFFSET, + SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET, scratch_top_ptr, +}; +use hyperlight_common::mem::PAGE_SIZE_USIZE; +use hyperlight_common::virtq::Layout as VirtqLayout; +use hyperlight_guest::prim_alloc::alloc_phys_pages; + +use crate::paging::phys_to_virt; + +/// Initialize virtqueue producers for G2H and H2G queues. 
+pub(crate) fn init_virtqueues() { + let g2h_gva = unsafe { *scratch_top_ptr::(SCRATCH_TOP_G2H_RING_GVA_OFFSET) }; + let g2h_depth = unsafe { *scratch_top_ptr::(SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET) }; + let h2g_gva = unsafe { *scratch_top_ptr::(SCRATCH_TOP_H2G_RING_GVA_OFFSET) }; + let h2g_depth = unsafe { *scratch_top_ptr::(SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET) }; + let pool_pages = unsafe { *scratch_top_ptr::(SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET) } as u64; + + assert!(g2h_depth > 0 && h2g_depth > 0); + assert!(g2h_gva != 0 && h2g_gva != 0); + assert!(pool_pages > 0); + + // Zero ring memory + let g2h_ring_size = VirtqLayout::query_size(g2h_depth as usize); + unsafe { core::ptr::write_bytes(g2h_gva as *mut u8, 0, g2h_ring_size) }; + + let h2g_ring_size = VirtqLayout::query_size(h2g_depth as usize); + unsafe { core::ptr::write_bytes(h2g_gva as *mut u8, 0, h2g_ring_size) }; + + // Allocate buffer pool from physical pages + let pool_gpa = unsafe { alloc_phys_pages(pool_pages) }; + let pool_ptr = phys_to_virt(pool_gpa).expect("failed to map pool pages"); + let pool_gva = pool_ptr as u64; + let pool_size = pool_pages as usize * PAGE_SIZE_USIZE; + unsafe { core::ptr::write_bytes(pool_ptr, 0, pool_size) }; + + // Create G2H producer + unsafe { + state::init_g2h_producer(g2h_gva, g2h_depth, pool_gva, pool_size); + } + + // TODO(virtq): add other direction's producer + let _ = (h2g_gva, h2g_depth); +} diff --git a/src/hyperlight_guest_bin/src/virtq/state.rs b/src/hyperlight_guest_bin/src/virtq/state.rs new file mode 100644 index 000000000..232726377 --- /dev/null +++ b/src/hyperlight_guest_bin/src/virtq/state.rs @@ -0,0 +1,75 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Guest-side virtqueue state and initialization. +//! +//! Holds the global VirtqProducer instances for G2H and H2G queues. +//! The producers are created during guest init (from `hyperlight_guest_bin`) +//! and used by the guest host-call path in `host_comm`. + +use alloc::rc::Rc; +use core::cell::RefCell; +use core::num::NonZeroU16; + +use hyperlight_common::virtq::{BufferPool, Layout, Notifier, QueueStats, VirtqProducer}; +use hyperlight_guest::virtq_mem::GuestMemOps; + +/// Wrapper to mark types as Sync for single-threaded guest execution. +struct SyncWrap(T); + +// SAFETY: guest execution is single-threaded. +unsafe impl Sync for SyncWrap {} + +/// Guest-side notifier (no-op). +#[derive(Clone, Copy)] +pub struct GuestNotifier; + +impl Notifier for GuestNotifier { + fn notify(&self, _stats: QueueStats) {} +} + +/// Type alias for the guest-side producer. +pub type GuestProducer = VirtqProducer>; +/// Global G2H producer instance, initialized during guest init. +static G2H_PRODUCER: SyncWrap>> = SyncWrap(RefCell::new(None)); + +/// Borrow the G2H producer mutably. +/// +/// # Panics +/// +/// Panics if the G2H producer has not been initialized or is already +/// borrowed. +pub fn with_g2h_producer(f: impl FnOnce(&mut GuestProducer) -> R) -> R { + let mut guard = G2H_PRODUCER.0.borrow_mut(); + let producer = guard.as_mut().expect("G2H producer not initialized"); + f(producer) +} + +/// Initialize the G2H producer +/// +/// # Safety +/// +/// The ring GVA must point to valid, zeroed ring memory of the +/// appropriate size. 
The pool GVA must point to valid, zeroed memory. +pub unsafe fn init_g2h_producer(ring_gva: u64, num_descs: u16, pool_gva: u64, pool_size: usize) { + let nz = NonZeroU16::new(num_descs).expect("G2H queue depth must be non-zero"); + let pool = BufferPool::new(pool_gva, pool_size).expect("failed to create G2H buffer pool"); + + let layout = unsafe { Layout::from_base(ring_gva, nz) }.expect("invalid G2H ring layout"); + let producer = VirtqProducer::new(layout, GuestMemOps, GuestNotifier, Rc::new(pool)); + + *G2H_PRODUCER.0.borrow_mut() = Some(producer); +} diff --git a/src/hyperlight_guest_bin/src/virtq_init.rs b/src/hyperlight_guest_bin/src/virtq_init.rs deleted file mode 100644 index 1f24f5d9b..000000000 --- a/src/hyperlight_guest_bin/src/virtq_init.rs +++ /dev/null @@ -1,50 +0,0 @@ -/* -Copyright 2026 The Hyperlight Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -//! Guest-side virtqueue initialization. -//! -//! The host places virtqueue rings at deterministic offsets in the -//! scratch region and writes ring GVAs and queue depths to scratch-top -//! metadata. - -use hyperlight_common::layout::{ - self, SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET, SCRATCH_TOP_G2H_RING_GVA_OFFSET, - SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET, SCRATCH_TOP_H2G_RING_GVA_OFFSET, -}; -use hyperlight_common::virtq::Layout as VirtqLayout; - -/// Read a value from a scratch-top metadata slot. 
-unsafe fn read_scratch_top(offset: u64) -> T { - let addr = (layout::MAX_GVA as u64 - offset + 1) as *const T; - unsafe { core::ptr::read_volatile(addr) } -} - -/// Initialize virtqueue ring memory in the scratch region. -pub(crate) fn init_virtqueues() { - let g2h_gva: u64 = unsafe { read_scratch_top(SCRATCH_TOP_G2H_RING_GVA_OFFSET) }; - let g2h_depth: u16 = unsafe { read_scratch_top(SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET) }; - let h2g_gva: u64 = unsafe { read_scratch_top(SCRATCH_TOP_H2G_RING_GVA_OFFSET) }; - let h2g_depth: u16 = unsafe { read_scratch_top(SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET) }; - - assert!(g2h_depth > 0 && h2g_depth > 0); - assert!(g2h_gva != 0 && h2g_gva != 0); - - let size = VirtqLayout::query_size(g2h_depth as usize); - unsafe { core::ptr::write_bytes(g2h_gva as *mut u8, 0, size) }; - - let size = VirtqLayout::query_size(h2g_depth as usize); - unsafe { core::ptr::write_bytes(h2g_gva as *mut u8, 0, size) }; -} diff --git a/src/hyperlight_host/src/mem/mgr.rs b/src/hyperlight_host/src/mem/mgr.rs index 7356e12a0..59f8cb11c 100644 --- a/src/hyperlight_host/src/mem/mgr.rs +++ b/src/hyperlight_host/src/mem/mgr.rs @@ -632,6 +632,10 @@ impl SandboxMemoryManager { scratch_size - SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET as usize, self.layout.sandbox_memory_config.get_h2g_queue_depth() as u16, )?; + self.scratch_mem.write::( + scratch_size - SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET as usize, + self.layout.sandbox_memory_config.get_virtq_pool_pages() as u16, + )?; // Copy page tables from `shared_mem` into scratch. PT bytes // are appended to the snapshot blob at build time and live diff --git a/src/hyperlight_host/src/sandbox/config.rs b/src/hyperlight_host/src/sandbox/config.rs index 120aa06cd..a329e5fd5 100644 --- a/src/hyperlight_host/src/sandbox/config.rs +++ b/src/hyperlight_host/src/sandbox/config.rs @@ -80,6 +80,9 @@ pub struct SandboxConfiguration { /// Number of descriptors for the host-to-guest virtqueue. Must be a power of 2. 
/// Default: 32 h2g_queue_depth: usize, + /// Number of physical pages to allocate for each virtqueue's buffer pool. + /// Default: 8 pages (32KB). + virtq_pool_pages: usize, } impl SandboxConfiguration { @@ -103,6 +106,8 @@ impl SandboxConfiguration { pub const DEFAULT_G2H_QUEUE_DEPTH: usize = 64; /// The default H2G virtqueue depth (number of descriptors, must be power of 2) pub const DEFAULT_H2G_QUEUE_DEPTH: usize = 32; + /// The default number of physical pages per virtqueue buffer pool + pub const DEFAULT_VIRTQ_POOL_PAGES: usize = 8; #[allow(clippy::too_many_arguments)] /// Create a new configuration for a sandbox with the given sizes. @@ -126,6 +131,7 @@ impl SandboxConfiguration { interrupt_vcpu_sigrtmin_offset, g2h_queue_depth: Self::DEFAULT_G2H_QUEUE_DEPTH, h2g_queue_depth: Self::DEFAULT_H2G_QUEUE_DEPTH, + virtq_pool_pages: Self::DEFAULT_VIRTQ_POOL_PAGES, #[cfg(gdb)] guest_debug_info, #[cfg(crashdump)] @@ -231,6 +237,26 @@ impl SandboxConfiguration { self.h2g_queue_depth } + /// Get the number of physical pages per virtqueue buffer pool. + pub fn get_virtq_pool_pages(&self) -> usize { + self.virtq_pool_pages + } + + /// Set the G2H virtqueue depth (number of descriptors, must be power of 2). + pub fn set_g2h_queue_depth(&mut self, depth: usize) { + self.g2h_queue_depth = depth; + } + + /// Set the H2G virtqueue depth (number of descriptors, must be power of 2). + pub fn set_h2g_queue_depth(&mut self, depth: usize) { + self.h2g_queue_depth = depth; + } + + /// Set the number of physical pages per virtqueue buffer pool. 
+ pub fn set_virtq_pool_pages(&mut self, pages: usize) { + self.virtq_pool_pages = pages; + } + /// Set the size of the scratch regiong #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub fn set_scratch_size(&mut self, scratch_size: usize) { diff --git a/src/hyperlight_host/tests/integration_test.rs b/src/hyperlight_host/tests/integration_test.rs index cc7b7587d..9e7fe2c91 100644 --- a/src/hyperlight_host/tests/integration_test.rs +++ b/src/hyperlight_host/tests/integration_test.rs @@ -544,6 +544,9 @@ fn guest_malloc_abort() { let mut cfg = SandboxConfiguration::default(); cfg.set_heap_size(heap_size); + cfg.set_g2h_queue_depth(2); + cfg.set_h2g_queue_depth(2); + cfg.set_virtq_pool_pages(2); with_rust_sandbox_cfg(cfg, |mut sbox2| { let err = sbox2 .call::( @@ -620,6 +623,9 @@ fn guest_panic_no_alloc() { let mut cfg = SandboxConfiguration::default(); cfg.set_heap_size(heap_size); + cfg.set_g2h_queue_depth(2); + cfg.set_h2g_queue_depth(2); + cfg.set_virtq_pool_pages(2); with_rust_sandbox_cfg(cfg, |mut sbox| { let res = sbox .call::( @@ -1679,7 +1685,9 @@ fn exception_handler_installation_and_validation() { /// This validates that the exception handling path does not require heap allocations. 
#[test] fn fill_heap_and_cause_exception() { - with_rust_sandbox(|mut sandbox| { + let mut cfg = SandboxConfiguration::default(); + cfg.set_virtq_pool_pages(2); + with_rust_sandbox_cfg(cfg, |mut sandbox| { let result = sandbox.call::<()>("FillHeapAndCauseException", ()); // The call should fail with an exception error since there's no handler installed From 2c8c7df58d4b891e2a880de2abb86c7da2a6a559 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Thu, 26 Mar 2026 11:02:57 +0100 Subject: [PATCH 05/31] feat(virtq): add reset API Signed-off-by: Tomasz Andrzejak --- src/hyperlight_common/src/virtq/consumer.rs | 50 +++++++++++ src/hyperlight_common/src/virtq/pool.rs | 94 +++++++++++++++++++++ src/hyperlight_common/src/virtq/producer.rs | 48 +++++++++++ src/hyperlight_common/src/virtq/ring.rs | 2 - 4 files changed, 192 insertions(+), 2 deletions(-) diff --git a/src/hyperlight_common/src/virtq/consumer.rs b/src/hyperlight_common/src/virtq/consumer.rs index 4c7bbc9ba..9e4e09527 100644 --- a/src/hyperlight_common/src/virtq/consumer.rs +++ b/src/hyperlight_common/src/virtq/consumer.rs @@ -441,6 +441,12 @@ impl VirtqConsumer { Ok(Bytes::from(buf)) } + + /// Reset ring and inflight state to initial values. + pub fn reset(&mut self) { + self.inner.reset(); + self.inflight.fill(None); + } } /// Parse a descriptor chain into entry/completion buffer elements. 
@@ -630,4 +636,48 @@ mod tests { assert_eq!(data.as_ref(), b"abc"); consumer.complete(completion).unwrap(); } + + #[test] + fn test_virtq_consumer_reset() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + // Submit and poll (but do not complete) + let se = producer.chain().completion(16).build().unwrap(); + producer.submit(se).unwrap(); + + let (_entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert!(consumer.inflight.iter().any(|s| s.is_some())); + + // Complete first so we do not leak + consumer.complete(completion).unwrap(); + + consumer.reset(); + + assert!(consumer.inflight.iter().all(|s| s.is_none())); + assert_eq!(consumer.inner.num_inflight(), 0); + } + + #[test] + fn test_virtq_consumer_reset_clears_inflight() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + // Submit two entries and poll both + let se1 = producer.chain().completion(16).build().unwrap(); + producer.submit(se1).unwrap(); + let se2 = producer.chain().completion(16).build().unwrap(); + producer.submit(se2).unwrap(); + + let (_e1, c1) = consumer.poll(1024).unwrap().unwrap(); + let (_e2, c2) = consumer.poll(1024).unwrap().unwrap(); + // Complete both before reset + consumer.complete(c1).unwrap(); + consumer.complete(c2).unwrap(); + + consumer.reset(); + + assert!(consumer.inflight.iter().all(|s| s.is_none())); + assert_eq!(consumer.inner.num_inflight(), 0); + } } diff --git a/src/hyperlight_common/src/virtq/pool.rs b/src/hyperlight_common/src/virtq/pool.rs index 99d8f4109..83178998d 100644 --- a/src/hyperlight_common/src/virtq/pool.rs +++ b/src/hyperlight_common/src/virtq/pool.rs @@ -516,6 +516,12 @@ impl Slab { pub const fn slot_size() -> usize { N } + + /// Reset the slab to initial state which is all slots free. 
+ pub fn reset(&mut self) { + self.used_slots.clear(); + self.last_free_run = None; + } } #[inline] @@ -547,6 +553,13 @@ impl BufferPool { inner: inner.into(), }) } + + /// Reset the pool to initial state + pub fn reset(&self) { + let mut inner = self.inner.borrow_mut(); + inner.lower.reset(); + inner.upper.reset(); + } } #[cfg(all(test, loom))] @@ -1171,6 +1184,87 @@ mod tests { slab.dealloc(a3).unwrap(); slab.dealloc(a4).unwrap(); } + + #[test] + fn test_slab_reset_returns_to_initial_state() { + let mut slab = make_slab::<256>(4096); + let initial_free = slab.free_bytes(); + let initial_cap = slab.capacity(); + + // Allocate some slots + let _a1 = slab.alloc(256).unwrap(); + let _a2 = slab.alloc(512).unwrap(); + assert!(slab.free_bytes() < initial_free); + + slab.reset(); + + assert_eq!(slab.free_bytes(), initial_free); + assert_eq!(slab.capacity(), initial_cap); + assert!(slab.last_free_run.is_none()); + assert_eq!(slab.used_slots.count_ones(..), 0); + + // Should be able to allocate the full capacity again + let a = slab.alloc(initial_cap).unwrap(); + assert_eq!(a.len, initial_cap); + } + + #[test] + fn test_slab_reset_matches_new() { + let base = align_up(0x10000, 256) as u64; + let region = 4096; + + let fresh = Slab::<256>::new(base, region).unwrap(); + + let mut used = Slab::<256>::new(base, region).unwrap(); + let _a = used.alloc(256).unwrap(); + let _b = used.alloc(1024).unwrap(); + used.reset(); + + assert_eq!(used.free_bytes(), fresh.free_bytes()); + assert_eq!(used.capacity(), fresh.capacity()); + assert_eq!( + used.used_slots.count_ones(..), + fresh.used_slots.count_ones(..) 
+ ); + assert!(used.last_free_run.is_none()); + assert!(fresh.last_free_run.is_none()); + } + + #[test] + fn test_buffer_pool_reset_returns_to_initial_state() { + let pool = make_pool::<256, 4096>(0x20000); + + // Allocate from both tiers + let a1 = pool.inner.borrow_mut().alloc(128).unwrap(); + let a2 = pool.inner.borrow_mut().alloc(8192).unwrap(); + assert!(a1.len > 0); + assert!(a2.len > 0); + + pool.reset(); + + let inner = pool.inner.borrow(); + assert_eq!(inner.lower.used_slots.count_ones(..), 0); + assert_eq!(inner.upper.used_slots.count_ones(..), 0); + assert!(inner.lower.last_free_run.is_none()); + assert!(inner.upper.last_free_run.is_none()); + } + + #[test] + fn test_buffer_pool_reset_allows_reallocation() { + let pool = make_pool::<256, 4096>(0x20000); + + // Fill up some allocations + let mut allocs = Vec::new(); + for _ in 0..5 { + allocs.push(pool.inner.borrow_mut().alloc(256).unwrap()); + } + + pool.reset(); + + // Should be able to allocate as if fresh + let a = pool.inner.borrow_mut().alloc(256).unwrap(); + assert!(a.len > 0); + } } #[cfg(test)] diff --git a/src/hyperlight_common/src/virtq/producer.rs b/src/hyperlight_common/src/virtq/producer.rs index 95db0b7ba..28c5dbf3a 100644 --- a/src/hyperlight_common/src/virtq/producer.rs +++ b/src/hyperlight_common/src/virtq/producer.rs @@ -329,6 +329,13 @@ where } Ok(()) } + + /// Reset ring and inflight state to initial values. + /// Does not reset the buffer pool; call pool.reset() separately if needed. + pub fn reset(&mut self) { + self.inner.reset(); + self.inflight.fill(None); + } } /// Builder for configuring a descriptor chain's buffer layout. 
@@ -787,4 +794,45 @@ mod tests { assert_eq!(cqe.token, token); assert_eq!(&cqe.data[..], b"response data"); } + + #[test] + fn test_virtq_producer_reset() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + // Submit and complete a round trip + let mut se = producer.chain().entry(32).completion(64).build().unwrap(); + se.write_all(b"hello").unwrap(); + producer.submit(se).unwrap(); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.data().as_ref(), b"hello"); + consumer.complete(completion).unwrap(); + let _ = producer.poll().unwrap().unwrap(); + + // Now reset + producer.reset(); + + // All inflight slots should be None + assert!(producer.inflight.iter().all(|s| s.is_none())); + // Ring state should be back to initial + assert_eq!(producer.inner.num_free(), producer.inner.len()); + } + + #[test] + fn test_virtq_producer_reset_clears_inflight() { + let ring = make_ring(16); + let (mut producer, _consumer, _notifier) = make_test_producer(&ring); + + // Submit without completing + let se = producer.chain().completion(64).build().unwrap(); + producer.submit(se).unwrap(); + + assert!(producer.inflight.iter().any(|s| s.is_some())); + + producer.reset(); + + assert!(producer.inflight.iter().all(|s| s.is_none())); + assert_eq!(producer.inner.num_free(), producer.inner.len()); + } } diff --git a/src/hyperlight_common/src/virtq/ring.rs b/src/hyperlight_common/src/virtq/ring.rs index 66791d2f0..978c345b5 100644 --- a/src/hyperlight_common/src/virtq/ring.rs +++ b/src/hyperlight_common/src/virtq/ring.rs @@ -910,7 +910,6 @@ impl RingProducer { self.id_num.iter_mut().for_each(|n| *n = 0); self.event_flags_shadow = EventFlags::ENABLE; } - /// Reset the ring to the "N slots submitted, none completed" state. /// /// `ids` contains the descriptor IDs that are in-flight. 
@@ -3258,7 +3257,6 @@ pub(crate) mod tests { consumer.reset(); assert_eq!(consumer.num_inflight, 0); } - #[test] fn test_reset_prefilled_sets_cursors() { let ring = make_ring(8); From f07523cd3514cdd5856c52de1b826ae0434955fb Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Thu, 26 Mar 2026 14:43:41 +0100 Subject: [PATCH 06/31] feat(virtq): replace guest-to-host calls with virtqueue Signed-off-by: Tomasz Andrzejak --- Cargo.lock | 1 + src/hyperlight_common/src/layout.rs | 4 +- src/hyperlight_guest/Cargo.toml | 1 + src/hyperlight_guest/src/error.rs | 25 ++- src/hyperlight_guest/src/lib.rs | 2 +- src/hyperlight_guest/src/virtq/context.rs | 159 ++++++++++++++++++ .../src/{virtq_mem.rs => virtq/mem.rs} | 28 ++- src/hyperlight_guest/src/virtq/mod.rs | 99 +++++++++++ src/hyperlight_guest_bin/Cargo.toml | 4 +- .../src/guest_function/call.rs | 5 + src/hyperlight_guest_bin/src/host_comm.rs | 13 +- src/hyperlight_guest_bin/src/virtq/mod.rs | 19 +-- src/hyperlight_guest_bin/src/virtq/state.rs | 75 --------- src/hyperlight_host/src/mem/mgr.rs | 68 +++++++- src/hyperlight_host/src/mem/virtq_mem.rs | 12 +- src/hyperlight_host/src/sandbox/outb.rs | 71 +++++++- src/tests/rust_guests/dummyguest/Cargo.lock | 1 + src/tests/rust_guests/simpleguest/Cargo.lock | 1 + src/tests/rust_guests/witguest/Cargo.lock | 1 + 19 files changed, 467 insertions(+), 122 deletions(-) create mode 100644 src/hyperlight_guest/src/virtq/context.rs rename src/hyperlight_guest/src/{virtq_mem.rs => virtq/mem.rs} (69%) create mode 100644 src/hyperlight_guest/src/virtq/mod.rs delete mode 100644 src/hyperlight_guest_bin/src/virtq/state.rs diff --git a/Cargo.lock b/Cargo.lock index 887ece62d..a80ae0959 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1556,6 +1556,7 @@ name = "hyperlight-guest" version = "0.15.0" dependencies = [ "anyhow", + "bytemuck", "flatbuffers", "hyperlight-common", "hyperlight-guest-tracing", diff --git a/src/hyperlight_common/src/layout.rs b/src/hyperlight_common/src/layout.rs index 
70622343b..8f7cea9c9 100644 --- a/src/hyperlight_common/src/layout.rs +++ b/src/hyperlight_common/src/layout.rs @@ -42,6 +42,7 @@ pub const SCRATCH_TOP_H2G_RING_GVA_OFFSET: u64 = 0x30; pub const SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET: u64 = 0x38; pub const SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET: u64 = 0x3a; pub const SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET: u64 = 0x3c; +pub const SCRATCH_TOP_VIRTQ_GENERATION_OFFSET: u64 = 0x3e; pub const SCRATCH_TOP_EXN_STACK_OFFSET: u64 = 0x40; const _: () = { @@ -53,7 +54,8 @@ const _: () = { assert!(SCRATCH_TOP_H2G_RING_GVA_OFFSET + 8 <= SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET); assert!(SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET + 2 <= SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET); assert!(SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET + 2 <= SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET); - assert!(SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET + 2 <= SCRATCH_TOP_EXN_STACK_OFFSET); + assert!(SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET + 2 <= SCRATCH_TOP_VIRTQ_GENERATION_OFFSET); + assert!(SCRATCH_TOP_VIRTQ_GENERATION_OFFSET + 2 <= SCRATCH_TOP_EXN_STACK_OFFSET); assert!(SCRATCH_TOP_EXN_STACK_OFFSET % 0x10 == 0); }; diff --git a/src/hyperlight_guest/Cargo.toml b/src/hyperlight_guest/Cargo.toml index d9de514ae..3ab158a7e 100644 --- a/src/hyperlight_guest/Cargo.toml +++ b/src/hyperlight_guest/Cargo.toml @@ -15,6 +15,7 @@ Provides only the essential building blocks for interacting with the host enviro anyhow = { version = "1.0.102", default-features = false } serde_json = { version = "1.0", default-features = false, features = ["alloc"] } hyperlight-common = { workspace = true, default-features = false } +bytemuck = { version = "1.24", features = ["derive"] } flatbuffers = { version= "25.12.19", default-features = false } tracing = { version = "0.1.44", default-features = false, features = ["attributes"] } diff --git a/src/hyperlight_guest/src/error.rs b/src/hyperlight_guest/src/error.rs index 9f256684b..62ca01bda 100644 --- a/src/hyperlight_guest/src/error.rs +++ b/src/hyperlight_guest/src/error.rs @@ 
-17,10 +17,11 @@ limitations under the License. use alloc::format; use alloc::string::{String, ToString as _}; -use anyhow; -pub use hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode; +pub(crate) use hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode; +use hyperlight_common::flatbuffer_wrappers::guest_error::GuestError; use hyperlight_common::func::Error as FuncError; -use serde_json; +use hyperlight_common::virtq::VirtqError; +use {anyhow, serde_json}; pub type Result = core::result::Result; @@ -81,6 +82,24 @@ impl From for HyperlightGuestError { } } +impl From for HyperlightGuestError { + fn from(e: VirtqError) -> Self { + Self { + kind: ErrorCode::GuestError, + message: format!("virtq: {e}"), + } + } +} + +impl From for HyperlightGuestError { + fn from(e: GuestError) -> Self { + Self { + kind: e.code, + message: e.message, + } + } +} + /// Extension trait to add context to `Option` and `Result` types in guest code, /// converting them to `Result`. /// diff --git a/src/hyperlight_guest/src/lib.rs b/src/hyperlight_guest/src/lib.rs index 1aa456b31..9cf64280d 100644 --- a/src/hyperlight_guest/src/lib.rs +++ b/src/hyperlight_guest/src/lib.rs @@ -26,7 +26,7 @@ pub mod exit; pub mod layout; pub mod prim_alloc; pub mod types; -pub mod virtq_mem; +pub mod virtq; pub mod guest_handle { pub mod handle; diff --git a/src/hyperlight_guest/src/virtq/context.rs b/src/hyperlight_guest/src/virtq/context.rs new file mode 100644 index 000000000..ea846cae7 --- /dev/null +++ b/src/hyperlight_guest/src/virtq/context.rs @@ -0,0 +1,159 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Guest virtqueue context. + +use alloc::sync::Arc; +use alloc::vec::Vec; +use core::num::NonZeroU16; +use core::sync::atomic::AtomicU16; +use core::sync::atomic::Ordering::Relaxed; + +use flatbuffers::FlatBufferBuilder; +use hyperlight_common::flatbuffer_wrappers::function_call::{FunctionCall, FunctionCallType}; +use hyperlight_common::flatbuffer_wrappers::function_types::{ + FunctionCallResult, ParameterValue, ReturnType, ReturnValue, +}; +use hyperlight_common::flatbuffer_wrappers::util::estimate_flatbuffer_capacity; +use hyperlight_common::outb::OutBAction; +use hyperlight_common::virtq::msg::{MsgKind, VirtqMsgHeader}; +use hyperlight_common::virtq::{BufferPool, Layout, Notifier, QueueStats, VirtqProducer}; + +use super::GuestMemOps; +use crate::bail; +use crate::error::Result; + +static REQUEST_ID: AtomicU16 = AtomicU16::new(0); +const MAX_RESPONSE_CAP: usize = 4096; + +/// Guest-side notifier that triggers a VM exit via outb. +#[derive(Clone, Copy)] +pub struct GuestNotifier; + +impl Notifier for GuestNotifier { + fn notify(&self, _stats: QueueStats) { + unsafe { crate::exit::out32(OutBAction::VirtqNotify as u16, 0) }; + } +} + +/// Type alias for the guest-side G2H producer. +pub type G2hProducer = VirtqProducer>; + +/// Virtqueue runtime state for guest-host communication. +pub struct GuestContext { + g2h_pool: Arc, + g2h_producer: G2hProducer, + generation: u64, +} + +impl GuestContext { + /// Create a new context with a G2H queue. + /// + /// # Safety + /// + /// `ring_gva` must point to valid, zeroed ring memory. 
+ /// `pool_gva` must point to valid, zeroed memory. + pub unsafe fn new( + ring_gva: u64, + num_descs: u16, + pool_gva: u64, + pool_size: usize, + generation: u64, + ) -> Self { + let pool = Arc::new( + BufferPool::new(pool_gva, pool_size).expect("failed to create G2H buffer pool"), + ); + let nz = NonZeroU16::new(num_descs).expect("G2H queue depth must be non-zero"); + let layout = unsafe { Layout::from_base(ring_gva, nz) }.expect("invalid G2H ring layout"); + let producer = VirtqProducer::new(layout, GuestMemOps, GuestNotifier, pool.clone()); + + Self { + g2h_pool: pool, + g2h_producer: producer, + generation, + } + } + + /// Call a host function via the G2H virtqueue. + pub fn call_host_function>( + &mut self, + function_name: &str, + parameters: Option>, + return_type: ReturnType, + ) -> Result { + let params = parameters.as_deref().unwrap_or_default(); + let estimated_capacity = estimate_flatbuffer_capacity(function_name, params); + + let fc = FunctionCall::new( + function_name.into(), + parameters, + FunctionCallType::Host, + return_type, + ); + + let mut builder = FlatBufferBuilder::with_capacity(estimated_capacity); + let payload = fc.encode(&mut builder); + + let reqid = REQUEST_ID.fetch_add(1, Relaxed); + let hdr = VirtqMsgHeader::new(MsgKind::Request, reqid, payload.len() as u32); + let hdr_bytes = bytemuck::bytes_of(&hdr); + + let entry_len = VirtqMsgHeader::SIZE + payload.len(); + + let mut entry = self + .g2h_producer + .chain() + .entry(entry_len) + .completion(MAX_RESPONSE_CAP) + .build()?; + + entry.write_all(hdr_bytes)?; + entry.write_all(payload)?; + self.g2h_producer.submit(entry)?; + + let Some(completion) = self.g2h_producer.poll()? 
else { + bail!("G2H: no completion received"); + }; + + let result_bytes = &completion.data; + if result_bytes.len() > MAX_RESPONSE_CAP { + bail!("G2H: response is too large"); + } + + let payload_bytes = &result_bytes[VirtqMsgHeader::SIZE..]; + let Ok(fcr) = FunctionCallResult::try_from(payload_bytes) else { + bail!("G2H: malformed response"); + }; + + let ret = fcr.into_inner()?; + let Ok(ret) = T::try_from(ret) else { + bail!("G2H: host return value type mismatch"); + }; + + Ok(ret) + } + + /// Reset ring and pool state after snapshot restore. + pub(super) fn reset(&mut self, new_generation: u64) { + self.g2h_producer.reset(); + self.g2h_pool.reset(); + self.generation = new_generation; + } + + pub(super) fn generation(&self) -> u64 { + self.generation + } +} diff --git a/src/hyperlight_guest/src/virtq_mem.rs b/src/hyperlight_guest/src/virtq/mem.rs similarity index 69% rename from src/hyperlight_guest/src/virtq_mem.rs rename to src/hyperlight_guest/src/virtq/mem.rs index 8309deb79..16375c868 100644 --- a/src/hyperlight_guest/src/virtq_mem.rs +++ b/src/hyperlight_guest/src/virtq/mem.rs @@ -27,33 +27,27 @@ use hyperlight_common::virtq::MemOps; #[derive(Clone, Copy, Debug)] pub struct GuestMemOps; -impl MemOps for GuestMemOps { +// SAFETY: GuestMemOps treats virtqueue addresses as directly mapped guest +// virtual addresses and performs the required acquire/release operations. 
+unsafe impl MemOps for GuestMemOps { type Error = Infallible; - fn read(&self, addr: u64, dst: &mut [u8]) -> Result { - let src = addr as *const u8; - unsafe { - ptr::copy_nonoverlapping(src, dst.as_mut_ptr(), dst.len()); - } - Ok(dst.len()) + fn read(&self, addr: u64, dst: &mut [u8]) -> Result<(), Self::Error> { + unsafe { ptr::copy_nonoverlapping(addr as *const u8, dst.as_mut_ptr(), dst.len()) }; + Ok(()) } - fn write(&self, addr: u64, src: &[u8]) -> Result { - let dst = addr as *mut u8; - unsafe { - ptr::copy_nonoverlapping(src.as_ptr(), dst, src.len()); - } - Ok(src.len()) + fn write(&self, addr: u64, src: &[u8]) -> Result<(), Self::Error> { + unsafe { ptr::copy_nonoverlapping(src.as_ptr(), addr as *mut u8, src.len()) }; + Ok(()) } fn load_acquire(&self, addr: u64) -> Result { - let ptr = addr as *const AtomicU16; - Ok(unsafe { (*ptr).load(Ordering::Acquire) }) + Ok(unsafe { (*(addr as *const AtomicU16)).load(Ordering::Acquire) }) } fn store_release(&self, addr: u64, val: u16) -> Result<(), Self::Error> { - let ptr = addr as *const AtomicU16; - unsafe { (*ptr).store(val, Ordering::Release) }; + unsafe { (*(addr as *const AtomicU16)).store(val, Ordering::Release) }; Ok(()) } diff --git a/src/hyperlight_guest/src/virtq/mod.rs b/src/hyperlight_guest/src/virtq/mod.rs new file mode 100644 index 000000000..6aafc75e1 --- /dev/null +++ b/src/hyperlight_guest/src/virtq/mod.rs @@ -0,0 +1,99 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +//! Guest-side virtqueue support. +//! +//! Global context is installed once via [`set_global_context`] and +//! accessed via [`with_context`]. + +pub mod context; +pub mod mem; + +use core::cell::UnsafeCell; +use core::sync::atomic::{AtomicU8, Ordering}; + +use context::GuestContext; +use hyperlight_common::layout::{SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET, scratch_top_ptr}; +pub use mem::GuestMemOps; + +// Init state machine +const UNINITIALIZED: u8 = 0; +const INITIALIZED: u8 = 1; +static INIT_STATE: AtomicU8 = AtomicU8::new(UNINITIALIZED); + +/// Check if the global context has been initialized. +pub fn is_initialized() -> bool { + INIT_STATE.load(Ordering::Acquire) == INITIALIZED +} + +// Storage: UnsafeCell guarded by atomic init state. +struct SyncWrap(T); +unsafe impl Sync for SyncWrap {} + +static GLOBAL_CONTEXT: SyncWrap>> = SyncWrap(UnsafeCell::new(None)); + +/// Access the global guest context via closure. +/// +/// # Panics +/// +/// Panics if the context has not been initialized. +pub fn with_context(f: impl FnOnce(&mut GuestContext) -> R) -> R { + assert!( + INIT_STATE.load(Ordering::Acquire) == INITIALIZED, + "guest context not initialized" + ); + let ctx = unsafe { &mut *GLOBAL_CONTEXT.0.get() }; + f(ctx.as_mut().unwrap()) +} + +/// Install the global guest context. Called once during guest init. +/// +/// # Panics +/// +/// Panics if called more than once. +pub fn set_global_context(ctx: GuestContext) { + if INIT_STATE + .compare_exchange( + UNINITIALIZED, + INITIALIZED, + Ordering::SeqCst, + Ordering::SeqCst, + ) + .is_err() + { + panic!("guest context already initialized"); + } + unsafe { *GLOBAL_CONTEXT.0.get() = Some(ctx) }; +} + +/// Reset the global context if a snapshot restore was detected. +/// Compares the virtq generation counter in scratch-top metadata. 
+pub fn reset_global_context() { + if !is_initialized() { + return; + } + let current_gen = read_gen(); + with_context(|ctx| { + if current_gen != ctx.generation() { + ctx.reset(current_gen); + } + }); +} + +/// Read the current virtqueue generation from scratch-top metadata. +fn read_gen() -> u64 { + unsafe { *scratch_top_ptr::(SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET) } +} diff --git a/src/hyperlight_guest_bin/Cargo.toml b/src/hyperlight_guest_bin/Cargo.toml index 060f67b88..0e4509865 100644 --- a/src/hyperlight_guest_bin/Cargo.toml +++ b/src/hyperlight_guest_bin/Cargo.toml @@ -13,8 +13,10 @@ and third-party code used by our C-API needed to build a native hyperlight-guest """ [features] -default = ["libc", "macros"] +default = ["libc", "printf", "macros", "virtq"] libc = ["dep:hyperlight-libc"] # compile libc from picolibc +printf = [ "libc" ] # compile printf +virtq = [] # use virtqueue for guest-to-host calls trace_guest = ["hyperlight-common/trace_guest", "hyperlight-guest/trace_guest", "hyperlight-guest-tracing/trace"] mem_profile = ["hyperlight-common/mem_profile"] macros = ["dep:hyperlight-guest-macro", "dep:linkme"] diff --git a/src/hyperlight_guest_bin/src/guest_function/call.rs b/src/hyperlight_guest_bin/src/guest_function/call.rs index 82874c659..5db880f8a 100644 --- a/src/hyperlight_guest_bin/src/guest_function/call.rs +++ b/src/hyperlight_guest_bin/src/guest_function/call.rs @@ -23,6 +23,7 @@ use hyperlight_common::flatbuffer_wrappers::function_types::{FunctionCallResult, use hyperlight_common::flatbuffer_wrappers::guest_error::{ErrorCode, GuestError}; use hyperlight_guest::bail; use hyperlight_guest::error::{HyperlightGuestError, Result}; +use hyperlight_guest::virtq; use tracing::instrument; use crate::{GUEST_HANDLE, REGISTERED_GUEST_FUNCTIONS}; @@ -100,6 +101,10 @@ pub(crate) fn internal_dispatch_function() { let handle = unsafe { GUEST_HANDLE }; + // After snapshot restore, the ring memory is zeroed but the + // producer's cursors are stale. 
Check once per dispatch entry. + virtq::reset_global_context(); + let function_call = handle .try_pop_shared_input_data_into::() .expect("Function call deserialization failed"); diff --git a/src/hyperlight_guest_bin/src/host_comm.rs b/src/hyperlight_guest_bin/src/host_comm.rs index 301462313..a2dbf77e6 100644 --- a/src/hyperlight_guest_bin/src/host_comm.rs +++ b/src/hyperlight_guest_bin/src/host_comm.rs @@ -36,8 +36,17 @@ pub fn call_host_function( where T: TryFrom, { - let handle = unsafe { GUEST_HANDLE }; - handle.call_host_function::(function_name, parameters, return_type) + #[cfg(feature = "virtq")] + { + hyperlight_guest::virtq::with_context(|ctx| { + ctx.call_host_function(function_name, parameters, return_type) + }) + } + #[cfg(not(feature = "virtq"))] + { + let handle = unsafe { GUEST_HANDLE }; + handle.call_host_function::(function_name, parameters, return_type) + } } pub fn call_host(function_name: impl AsRef, args: impl ParameterTuple) -> Result diff --git a/src/hyperlight_guest_bin/src/virtq/mod.rs b/src/hyperlight_guest_bin/src/virtq/mod.rs index 50c0dd6d9..997044dde 100644 --- a/src/hyperlight_guest_bin/src/virtq/mod.rs +++ b/src/hyperlight_guest_bin/src/virtq/mod.rs @@ -15,30 +15,27 @@ limitations under the License. */ //! Guest-side virtqueue initialization. -//! -//! Zeroes ring memory and creates VirtqProducer instances by allocating -//! buffer pool pages from the scratch page allocator. 
- -pub(crate) mod state; use hyperlight_common::layout::{ SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET, SCRATCH_TOP_G2H_RING_GVA_OFFSET, SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET, SCRATCH_TOP_H2G_RING_GVA_OFFSET, - SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET, scratch_top_ptr, + SCRATCH_TOP_VIRTQ_GENERATION_OFFSET, SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET, scratch_top_ptr, }; use hyperlight_common::mem::PAGE_SIZE_USIZE; use hyperlight_common::virtq::Layout as VirtqLayout; use hyperlight_guest::prim_alloc::alloc_phys_pages; +use hyperlight_guest::virtq::context::GuestContext; use crate::paging::phys_to_virt; -/// Initialize virtqueue producers for G2H and H2G queues. +/// Initialize virtqueue context. pub(crate) fn init_virtqueues() { let g2h_gva = unsafe { *scratch_top_ptr::(SCRATCH_TOP_G2H_RING_GVA_OFFSET) }; let g2h_depth = unsafe { *scratch_top_ptr::(SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET) }; let h2g_gva = unsafe { *scratch_top_ptr::(SCRATCH_TOP_H2G_RING_GVA_OFFSET) }; let h2g_depth = unsafe { *scratch_top_ptr::(SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET) }; let pool_pages = unsafe { *scratch_top_ptr::(SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET) } as u64; + let generation = unsafe { *scratch_top_ptr::(SCRATCH_TOP_VIRTQ_GENERATION_OFFSET) }; assert!(g2h_depth > 0 && h2g_depth > 0); assert!(g2h_gva != 0 && h2g_gva != 0); @@ -58,11 +55,9 @@ pub(crate) fn init_virtqueues() { let pool_size = pool_pages as usize * PAGE_SIZE_USIZE; unsafe { core::ptr::write_bytes(pool_ptr, 0, pool_size) }; - // Create G2H producer - unsafe { - state::init_g2h_producer(g2h_gva, g2h_depth, pool_gva, pool_size); - } + // Create and install global context + let ctx = unsafe { GuestContext::new(g2h_gva, g2h_depth, pool_gva, pool_size, generation.into()) }; + hyperlight_guest::virtq::set_global_context(ctx); - // TODO(virtq): add other direction's producer let _ = (h2g_gva, h2g_depth); } diff --git a/src/hyperlight_guest_bin/src/virtq/state.rs b/src/hyperlight_guest_bin/src/virtq/state.rs deleted file mode 100644 index 232726377..000000000 
--- a/src/hyperlight_guest_bin/src/virtq/state.rs +++ /dev/null @@ -1,75 +0,0 @@ -/* -Copyright 2026 The Hyperlight Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -//! Guest-side virtqueue state and initialization. -//! -//! Holds the global VirtqProducer instances for G2H and H2G queues. -//! The producers are created during guest init (from `hyperlight_guest_bin`) -//! and used by the guest host-call path in `host_comm`. - -use alloc::rc::Rc; -use core::cell::RefCell; -use core::num::NonZeroU16; - -use hyperlight_common::virtq::{BufferPool, Layout, Notifier, QueueStats, VirtqProducer}; -use hyperlight_guest::virtq_mem::GuestMemOps; - -/// Wrapper to mark types as Sync for single-threaded guest execution. -struct SyncWrap(T); - -// SAFETY: guest execution is single-threaded. -unsafe impl Sync for SyncWrap {} - -/// Guest-side notifier (no-op). -#[derive(Clone, Copy)] -pub struct GuestNotifier; - -impl Notifier for GuestNotifier { - fn notify(&self, _stats: QueueStats) {} -} - -/// Type alias for the guest-side producer. -pub type GuestProducer = VirtqProducer>; -/// Global G2H producer instance, initialized during guest init. -static G2H_PRODUCER: SyncWrap>> = SyncWrap(RefCell::new(None)); - -/// Borrow the G2H producer mutably. -/// -/// # Panics -/// -/// Panics if the G2H producer has not been initialized or is already -/// borrowed. 
-pub fn with_g2h_producer(f: impl FnOnce(&mut GuestProducer) -> R) -> R { - let mut guard = G2H_PRODUCER.0.borrow_mut(); - let producer = guard.as_mut().expect("G2H producer not initialized"); - f(producer) -} - -/// Initialize the G2H producer -/// -/// # Safety -/// -/// The ring GVA must point to valid, zeroed ring memory of the -/// appropriate size. The pool GVA must point to valid, zeroed memory. -pub unsafe fn init_g2h_producer(ring_gva: u64, num_descs: u16, pool_gva: u64, pool_size: usize) { - let nz = NonZeroU16::new(num_descs).expect("G2H queue depth must be non-zero"); - let pool = BufferPool::new(pool_gva, pool_size).expect("failed to create G2H buffer pool"); - - let layout = unsafe { Layout::from_base(ring_gva, nz) }.expect("invalid G2H ring layout"); - let producer = VirtqProducer::new(layout, GuestMemOps, GuestNotifier, Rc::new(pool)); - - *G2H_PRODUCER.0.borrow_mut() = Some(producer); -} diff --git a/src/hyperlight_host/src/mem/mgr.rs b/src/hyperlight_host/src/mem/mgr.rs index 59f8cb11c..b35077a16 100644 --- a/src/hyperlight_host/src/mem/mgr.rs +++ b/src/hyperlight_host/src/mem/mgr.rs @@ -21,6 +21,20 @@ use flatbuffers::FlatBufferBuilder; use hyperlight_common::flatbuffer_wrappers::function_call::{ FunctionCall, validate_guest_function_call_buffer, }; + +use super::virtq_mem::HostMemOps; + +/// No-op notifier for host-side consumer. +/// The host resumes the VM to notify the guest, not via the ring. +#[derive(Clone, Copy)] +pub(crate) struct HostNotifier; + +impl hyperlight_common::virtq::Notifier for HostNotifier { + fn notify(&self, _stats: hyperlight_common::virtq::QueueStats) {} +} + +/// Type alias for the host-side G2H virtqueue consumer. 
+pub(crate) type G2hConsumer = hyperlight_common::virtq::VirtqConsumer; use hyperlight_common::flatbuffer_wrappers::function_types::FunctionCallResult; use hyperlight_common::flatbuffer_wrappers::guest_log_data::GuestLogData; use hyperlight_common::virtq::Layout as VirtqLayout; @@ -136,7 +150,6 @@ impl ReadonlySharedMemory { pub(crate) use unused_hack::SnapshotSharedMemory; /// A struct that is responsible for laying out and managing the memory /// for a given `Sandbox`. -#[derive(Clone)] pub(crate) struct SandboxMemoryManager { /// Shared memory for the Sandbox pub(crate) shared_mem: SnapshotSharedMemory, @@ -156,6 +169,23 @@ pub(crate) struct SandboxMemoryManager { /// restored snapshot's own generation number so the guest-visible /// counter tracks which snapshot the sandbox is a clone of. pub(crate) snapshot_count: u64, + /// G2H virtqueue consumer, created after sandbox init. + pub(crate) g2h_consumer: Option, +} + +impl Clone for SandboxMemoryManager { + fn clone(&self) -> Self { + Self { + shared_mem: self.shared_mem.clone(), + scratch_mem: self.scratch_mem.clone(), + layout: self.layout, + entrypoint: self.entrypoint, + mapped_rgns: self.mapped_rgns, + abort_buffer: self.abort_buffer.clone(), + snapshot_count: self.snapshot_count, + g2h_consumer: None, // consumer is not cloned; re-init if needed + } + } } /// Buffer for building guest page tables during snapshot creation. 
@@ -291,6 +321,7 @@ where mapped_rgns: 0, abort_buffer: Vec::new(), snapshot_count: 0, + g2h_consumer: None, } } @@ -361,6 +392,7 @@ impl SandboxMemoryManager { mapped_rgns: self.mapped_rgns, abort_buffer: self.abort_buffer, snapshot_count: self.snapshot_count, + g2h_consumer: None, }; let guest_mgr = SandboxMemoryManager { shared_mem: gshm, @@ -370,8 +402,10 @@ impl SandboxMemoryManager { mapped_rgns: self.mapped_rgns, abort_buffer: Vec::new(), // Guest doesn't need abort buffer snapshot_count: self.snapshot_count, + g2h_consumer: None, }; host_mgr.update_scratch_bookkeeping()?; + host_mgr.init_g2h_consumer()?; Ok((host_mgr, guest_mgr)) } } @@ -568,6 +602,7 @@ impl SandboxMemoryManager { self.snapshot_count = snapshot.snapshot_generation(); self.update_scratch_bookkeeping()?; + self.init_g2h_consumer()?; Ok((gsnapshot, gscratch)) } @@ -636,6 +671,11 @@ impl SandboxMemoryManager { scratch_size - SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET as usize, self.layout.sandbox_memory_config.get_virtq_pool_pages() as u16, )?; + // Increment generation so the guest detects stale ring state. + let gen_offset = scratch_size - SCRATCH_TOP_VIRTQ_GENERATION_OFFSET as usize; + let gen_val: u16 = self.scratch_mem.read(gen_offset).unwrap_or(0); + self.scratch_mem + .write::(gen_offset, gen_val.wrapping_add(1))?; // Copy page tables from `shared_mem` into scratch. PT bytes // are appended to the snapshot blob at build time and live @@ -903,6 +943,32 @@ impl SandboxMemoryManager { unsafe { VirtqLayout::from_base(base, nz) } .map_err(|e| new_error!("Invalid H2G virtq layout: {:?}", e)) } + + /// Create a [`HostMemOps`] instance backed by this manager's + /// scratch shared memory. + pub(crate) fn host_mem_ops(&self) -> HostMemOps { + let scratch_base_gva = + hyperlight_common::layout::scratch_base_gva(self.scratch_mem.mem_size()); + HostMemOps::new(&self.scratch_mem, scratch_base_gva) + } + + /// Initialize the G2H virtqueue consumer. + /// Must be called after scratch bookkeeping is written. 
+ pub(crate) fn init_g2h_consumer(&mut self) -> Result<()> { + match &mut self.g2h_consumer { + Some(consumer) => { + consumer.reset(); + } + None => { + let layout = self.g2h_virtq_layout()?; + let mem_ops = self.host_mem_ops(); + let consumer = + hyperlight_common::virtq::VirtqConsumer::new(layout, mem_ops, HostNotifier); + self.g2h_consumer = Some(consumer); + } + } + Ok(()) + } } #[cfg(test)] diff --git a/src/hyperlight_host/src/mem/virtq_mem.rs b/src/hyperlight_host/src/mem/virtq_mem.rs index f96674c1d..8f01c523b 100644 --- a/src/hyperlight_host/src/mem/virtq_mem.rs +++ b/src/hyperlight_host/src/mem/virtq_mem.rs @@ -82,23 +82,25 @@ impl HostMemOps { } } -impl MemOps for HostMemOps { +// SAFETY: HostMemOps bounds-checks guest addresses against scratch memory before +// accessing them and uses atomic operations for acquire/release accesses. +unsafe impl MemOps for HostMemOps { type Error = HostMemError; - fn read(&self, addr: u64, dst: &mut [u8]) -> Result { + fn read(&self, addr: u64, dst: &mut [u8]) -> Result<(), Self::Error> { let offset = self.to_offset(addr)?; self.scratch .copy_to_slice(dst, offset) .map_err(|e| HostMemError::SharedMem(e.to_string()))?; - Ok(dst.len()) + Ok(()) } - fn write(&self, addr: u64, src: &[u8]) -> Result { + fn write(&self, addr: u64, src: &[u8]) -> Result<(), Self::Error> { let offset = self.to_offset(addr)?; self.scratch .copy_from_slice(src, offset) .map_err(|e| HostMemError::SharedMem(e.to_string()))?; - Ok(src.len()) + Ok(()) } fn load_acquire(&self, addr: u64) -> Result { diff --git a/src/hyperlight_host/src/sandbox/outb.rs b/src/hyperlight_host/src/sandbox/outb.rs index bb73763a6..aa40bec3d 100644 --- a/src/hyperlight_host/src/sandbox/outb.rs +++ b/src/hyperlight_host/src/sandbox/outb.rs @@ -16,6 +16,7 @@ limitations under the License. 
use std::sync::{Arc, Mutex}; +use hyperlight_common::flatbuffer_wrappers::function_call::FunctionCall; use hyperlight_common::flatbuffer_wrappers::function_types::{FunctionCallResult, ParameterValue}; use hyperlight_common::flatbuffer_wrappers::guest_error::{ErrorCode, GuestError}; use hyperlight_common::flatbuffer_wrappers::guest_log_data::GuestLogData; @@ -180,6 +181,71 @@ fn outb_abort( Ok(()) } +/// Handle a guest-to-host function call received via the G2H virtqueue. +fn outb_virtq_call( + mem_mgr: &mut SandboxMemoryManager, + host_funcs: &Arc>, +) -> Result<(), HandleOutbError> { + use hyperlight_common::virtq::msg::{MsgKind, VirtqMsgHeader}; + + let consumer = mem_mgr.g2h_consumer.as_mut().ok_or_else(|| { + HandleOutbError::ReadHostFunctionCall("G2H consumer not initialized".into()) + })?; + + let (entry, completion) = consumer + .poll(8192) + .map_err(|e| HandleOutbError::ReadHostFunctionCall(format!("G2H poll: {e}")))? + .ok_or_else(|| HandleOutbError::ReadHostFunctionCall("G2H poll: no entry".into()))?; + + // Parse: skip VirtqMsgHeader, deserialize FunctionCall from remainder + let entry_data = entry.data(); + if entry_data.len() < VirtqMsgHeader::SIZE { + return Err(HandleOutbError::ReadHostFunctionCall( + "G2H entry too short".into(), + )); + } + let payload = &entry_data[VirtqMsgHeader::SIZE..]; + let call = FunctionCall::try_from(payload) + .map_err(|e| HandleOutbError::ReadHostFunctionCall(e.to_string()))?; + + // Dispatch the host function (same as CallFunction path) + let name = call.function_name.clone(); + let args: Vec = call.parameters.unwrap_or(vec![]); + let res = host_funcs + .try_lock() + .map_err(|e| HandleOutbError::LockFailed(file!(), line!(), e.to_string()))? 
+ .call_host_function(&name, args) + .map_err(|e| GuestError::new(ErrorCode::HostFunctionError, e.to_string())); + + // Serialize response: VirtqMsgHeader + FunctionCallResult + let func_result = FunctionCallResult::new(res); + let mut builder = flatbuffers::FlatBufferBuilder::new(); + let result_payload = func_result.encode(&mut builder); + + let resp_header = VirtqMsgHeader::new(MsgKind::Response, 0, result_payload.len() as u32); + let resp_header_bytes = bytemuck::bytes_of(&resp_header); + + // Write response into the completion buffer + match completion { + hyperlight_common::virtq::SendCompletion::Writable(mut wc) => { + wc.write_all(resp_header_bytes) + .map_err(|e| HandleOutbError::WriteHostFunctionResponse(format!("{e}")))?; + wc.write_all(result_payload) + .map_err(|e| HandleOutbError::WriteHostFunctionResponse(format!("{e}")))?; + consumer + .complete(wc.into()) + .map_err(|e| HandleOutbError::WriteHostFunctionResponse(format!("{e}")))?; + } + hyperlight_common::virtq::SendCompletion::Ack(ack) => { + consumer + .complete(ack.into()) + .map_err(|e| HandleOutbError::WriteHostFunctionResponse(format!("{e}")))?; + } + } + + Ok(()) +} + /// Handles OutB operations from the guest. #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] pub(crate) fn handle_outb( @@ -227,10 +293,7 @@ pub(crate) fn handle_outb( eprint!("{}", ch); Ok(()) } - OutBAction::VirtqNotify => { - // TODO(ring): acknowledge notification but no-op for now. 
- Ok(()) - } + OutBAction::VirtqNotify => outb_virtq_call(mem_mgr, host_funcs), #[cfg(feature = "trace_guest")] OutBAction::TraceBatch => Ok(()), #[cfg(feature = "mem_profile")] diff --git a/src/tests/rust_guests/dummyguest/Cargo.lock b/src/tests/rust_guests/dummyguest/Cargo.lock index f2085335f..0f1efc6bd 100644 --- a/src/tests/rust_guests/dummyguest/Cargo.lock +++ b/src/tests/rust_guests/dummyguest/Cargo.lock @@ -197,6 +197,7 @@ name = "hyperlight-guest" version = "0.15.0" dependencies = [ "anyhow", + "bytemuck", "flatbuffers", "hyperlight-common", "hyperlight-guest-tracing", diff --git a/src/tests/rust_guests/simpleguest/Cargo.lock b/src/tests/rust_guests/simpleguest/Cargo.lock index 455139d0d..48adf7195 100644 --- a/src/tests/rust_guests/simpleguest/Cargo.lock +++ b/src/tests/rust_guests/simpleguest/Cargo.lock @@ -189,6 +189,7 @@ name = "hyperlight-guest" version = "0.15.0" dependencies = [ "anyhow", + "bytemuck", "flatbuffers", "hyperlight-common", "hyperlight-guest-tracing", diff --git a/src/tests/rust_guests/witguest/Cargo.lock b/src/tests/rust_guests/witguest/Cargo.lock index 70f41063d..cbefd08cf 100644 --- a/src/tests/rust_guests/witguest/Cargo.lock +++ b/src/tests/rust_guests/witguest/Cargo.lock @@ -306,6 +306,7 @@ name = "hyperlight-guest" version = "0.15.0" dependencies = [ "anyhow", + "bytemuck", "flatbuffers", "hyperlight-common", "hyperlight-guest-tracing", From 502ad6dfc5197054af1e17c0d1c1f56275cb3c00 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Fri, 3 Apr 2026 12:48:06 +0200 Subject: [PATCH 07/31] feat(virtq): replace host-to-guest calls with virtq Signed-off-by: Tomasz Andrzejak --- src/hyperlight_common/src/layout.rs | 30 +-- src/hyperlight_common/src/virtq/mod.rs | 14 +- src/hyperlight_common/src/virtq/pool.rs | 32 ++- src/hyperlight_common/src/virtq/producer.rs | 36 ++- .../src/virtq/recycle_pool.rs | 120 ++++++++++ src/hyperlight_guest/src/virtq/context.rs | 158 ++++++++++--- src/hyperlight_guest/src/virtq/mod.rs | 6 +- 
.../src/guest_function/call.rs | 34 ++- src/hyperlight_guest_bin/src/virtq/mod.rs | 62 +++-- src/hyperlight_host/src/mem/layout.rs | 1 + src/hyperlight_host/src/mem/mgr.rs | 216 +++++++++++++++--- src/hyperlight_host/src/sandbox/config.rs | 80 +++++-- .../src/sandbox/initialized_multi_use.rs | 4 +- src/hyperlight_host/src/sandbox/outb.rs | 53 +++-- .../src/sandbox/uninitialized_evolve.rs | 15 +- src/hyperlight_host/tests/integration_test.rs | 8 +- 16 files changed, 692 insertions(+), 177 deletions(-) create mode 100644 src/hyperlight_common/src/virtq/recycle_pool.rs diff --git a/src/hyperlight_common/src/layout.rs b/src/hyperlight_common/src/layout.rs index 8f7cea9c9..5bb8a2cb5 100644 --- a/src/hyperlight_common/src/layout.rs +++ b/src/hyperlight_common/src/layout.rs @@ -40,22 +40,24 @@ pub const SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET: u64 = 0x20; pub const SCRATCH_TOP_G2H_RING_GVA_OFFSET: u64 = 0x28; pub const SCRATCH_TOP_H2G_RING_GVA_OFFSET: u64 = 0x30; pub const SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET: u64 = 0x38; -pub const SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET: u64 = 0x3a; -pub const SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET: u64 = 0x3c; -pub const SCRATCH_TOP_VIRTQ_GENERATION_OFFSET: u64 = 0x3e; -pub const SCRATCH_TOP_EXN_STACK_OFFSET: u64 = 0x40; +pub const SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET: u64 = 0x3A; +pub const SCRATCH_TOP_G2H_POOL_PAGES_OFFSET: u64 = 0x3C; +pub const SCRATCH_TOP_H2G_POOL_PAGES_OFFSET: u64 = 0x3E; +pub const SCRATCH_TOP_H2G_POOL_GVA_OFFSET: u64 = 0x48; +pub const SCRATCH_TOP_EXN_STACK_OFFSET: u64 = 0x50; const _: () = { - assert!(SCRATCH_TOP_SIZE_OFFSET + 8 <= SCRATCH_TOP_ALLOCATOR_OFFSET); - assert!(SCRATCH_TOP_ALLOCATOR_OFFSET + 8 <= SCRATCH_TOP_SNAPSHOT_PT_GPA_BASE_OFFSET); - assert!(SCRATCH_TOP_SNAPSHOT_PT_GPA_BASE_OFFSET + 8 <= SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET); - assert!(SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET + 8 <= SCRATCH_TOP_G2H_RING_GVA_OFFSET); - assert!(SCRATCH_TOP_G2H_RING_GVA_OFFSET + 8 <= SCRATCH_TOP_H2G_RING_GVA_OFFSET); - 
assert!(SCRATCH_TOP_H2G_RING_GVA_OFFSET + 8 <= SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET); - assert!(SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET + 2 <= SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET); - assert!(SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET + 2 <= SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET); - assert!(SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET + 2 <= SCRATCH_TOP_VIRTQ_GENERATION_OFFSET); - assert!(SCRATCH_TOP_VIRTQ_GENERATION_OFFSET + 2 <= SCRATCH_TOP_EXN_STACK_OFFSET); + assert!(SCRATCH_TOP_ALLOCATOR_OFFSET >= SCRATCH_TOP_SIZE_OFFSET + 8); + assert!(SCRATCH_TOP_SNAPSHOT_PT_GPA_BASE_OFFSET >= SCRATCH_TOP_ALLOCATOR_OFFSET + 8); + assert!(SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET >= SCRATCH_TOP_SNAPSHOT_PT_GPA_BASE_OFFSET + 8); + assert!(SCRATCH_TOP_G2H_RING_GVA_OFFSET >= SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET + 8); + assert!(SCRATCH_TOP_H2G_RING_GVA_OFFSET >= SCRATCH_TOP_G2H_RING_GVA_OFFSET + 8); + assert!(SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET >= SCRATCH_TOP_H2G_RING_GVA_OFFSET + 8); + assert!(SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET >= SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET + 2); + assert!(SCRATCH_TOP_G2H_POOL_PAGES_OFFSET >= SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET + 2); + assert!(SCRATCH_TOP_H2G_POOL_PAGES_OFFSET >= SCRATCH_TOP_G2H_POOL_PAGES_OFFSET + 2); + assert!(SCRATCH_TOP_H2G_POOL_GVA_OFFSET >= SCRATCH_TOP_H2G_POOL_PAGES_OFFSET + 8); + assert!(SCRATCH_TOP_EXN_STACK_OFFSET >= SCRATCH_TOP_H2G_POOL_GVA_OFFSET + 8); assert!(SCRATCH_TOP_EXN_STACK_OFFSET % 0x10 == 0); }; diff --git a/src/hyperlight_common/src/virtq/mod.rs b/src/hyperlight_common/src/virtq/mod.rs index 2e10491b2..9803977b0 100644 --- a/src/hyperlight_common/src/virtq/mod.rs +++ b/src/hyperlight_common/src/virtq/mod.rs @@ -157,6 +157,7 @@ mod event; pub mod msg; mod pool; mod producer; +pub mod recycle_pool; mod ring; use core::num::NonZeroU16; @@ -170,7 +171,7 @@ pub use producer::*; pub use ring::*; use thiserror::Error; -/// A trait for notifying about new requests in the virtqueue. +/// A trait for notifying the consumer about virtqueue events. 
pub trait Notifier { fn notify(&self, stats: QueueStats); } @@ -476,15 +477,12 @@ pub(crate) mod test_utils { } } + type TestProducer = VirtqProducer; + type TestConsumer = VirtqConsumer; + /// Create test infrastructure: a producer, consumer, and notifier backed /// by the supplied [`OwnedRing`]. - pub(crate) fn make_test_producer( - ring: &OwnedRing, - ) -> ( - VirtqProducer, - VirtqConsumer, - TestNotifier, - ) { + pub(crate) fn make_test_producer(ring: &OwnedRing) -> (TestProducer, TestConsumer, TestNotifier) { let layout = ring.layout(); let mem = ring.mem(); diff --git a/src/hyperlight_common/src/virtq/pool.rs b/src/hyperlight_common/src/virtq/pool.rs index 83178998d..cf0915fdf 100644 --- a/src/hyperlight_common/src/virtq/pool.rs +++ b/src/hyperlight_common/src/virtq/pool.rs @@ -79,7 +79,6 @@ limitations under the License. //! owning slab (`Slab::resize`) but will never move allocations between //! slabs. -#[cfg(all(test, loom))] use alloc::sync::Arc; use core::cmp::Ordering; @@ -124,6 +123,9 @@ pub trait BufferProvider { /// Resize by trying in-place grow; otherwise reserve a new block and free old. fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result; + + /// Reset the pool to initial state. + fn reset(&self) {} } impl BufferProvider for alloc::rc::Rc { @@ -136,9 +138,12 @@ impl BufferProvider for alloc::rc::Rc { fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result { (**self).resize(old_alloc, new_len) } + fn reset(&self) { + (**self).reset() + } } -impl BufferProvider for alloc::sync::Arc { +impl BufferProvider for Arc { fn alloc(&self, len: usize) -> Result { (**self).alloc(len) } @@ -148,6 +153,9 @@ impl BufferProvider for alloc::sync::Arc { fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result { (**self).resize(old_alloc, new_len) } + fn reset(&self) { + (**self).reset() + } } /// The owner of a mapped buffer, ensuring its lifetime. 
@@ -540,9 +548,10 @@ struct Inner { } /// Two tier buffer pool with small and large slabs. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct BufferPool { - inner: AtomicRefCell>, + // TODO: Use Rc instead, relax Sync + Send bounds + inner: Arc>>, } impl BufferPool { @@ -550,16 +559,9 @@ impl BufferPool { pub fn new(base_addr: u64, region_len: usize) -> Result { let inner = Inner::::new(base_addr, region_len)?; Ok(Self { - inner: inner.into(), + inner: Arc::new(inner.into()), }) } - - /// Reset the pool to initial state - pub fn reset(&self) { - let mut inner = self.inner.borrow_mut(); - inner.lower.reset(); - inner.upper.reset(); - } } #[cfg(all(test, loom))] @@ -672,6 +674,12 @@ impl BufferProvider for BufferPool { fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result { self.inner.borrow_mut().resize(old_alloc, new_len) } + + fn reset(&self) { + let mut inner = self.inner.borrow_mut(); + inner.lower.reset(); + inner.upper.reset(); + } } #[cfg(all(test, loom))] diff --git a/src/hyperlight_common/src/virtq/producer.rs b/src/hyperlight_common/src/virtq/producer.rs index 28c5dbf3a..5e6a7edf1 100644 --- a/src/hyperlight_common/src/virtq/producer.rs +++ b/src/hyperlight_common/src/virtq/producer.rs @@ -282,6 +282,13 @@ where *slot = Some(inflight); let should_notify = self.inner.should_notify_since(cursor_before)?; + + // TODO(virtq): for now simulate current outb behavior of only + // notifying on bidirectional (request/response) entries. + // Eventually this should be decoupled from the buffer layout + // and driven entirely by event suppression rules. + let should_notify = should_notify && matches!(inflight, Inflight::ReadWrite { .. }); + if should_notify { self.notifier.notify(QueueStats { num_free: self.inner.num_free(), @@ -292,6 +299,17 @@ where Ok(Token(id)) } + /// Signal backpressure to the consumer. + /// + /// Bypasses event suppression. Call this when submit fails with a backpressure error and the consumer needs to drain. 
+ #[inline] + pub fn notify_backpressure(&self) { + self.notifier.notify(QueueStats { + num_free: self.inner.num_free(), + num_inflight: self.inner.num_inflight(), + }); + } + /// Get the current used cursor position. /// /// Useful for setting up descriptor-based event suppression. @@ -330,10 +348,20 @@ where Ok(()) } - /// Reset ring and inflight state to initial values. - /// Does not reset the buffer pool; call pool.reset() separately if needed. + /// Reset ring, inflight, and pool state to initial values. + /// + /// # Safety + /// + /// All [`RecvCompletion`]s (and their backing [`Bytes`]) from + /// previous `poll()` calls must have been dropped before calling + /// this. Outstanding completions hold pool allocations via + /// `BufferOwner`; resetting the pool while they exist would cause + /// double-free on drop. + /// + /// TODO(virtq): properly restore state after snapshot instead of just resetting everything pub fn reset(&mut self) { self.inner.reset(); + self.pool.reset(); self.inflight.fill(None); } } @@ -343,14 +371,14 @@ where /// If dropped without building, no resources are leaked (allocations are /// deferred to [`build`](Self::build)). #[must_use = "call .build() to create a SendEntry"] -pub struct ChainBuilder { +pub struct ChainBuilder { mem: M, pool: P, entry_cap: Option, cqe_cap: Option, } -impl ChainBuilder { +impl ChainBuilder { fn new(mem: M, pool: P) -> Self { Self { mem, diff --git a/src/hyperlight_common/src/virtq/recycle_pool.rs b/src/hyperlight_common/src/virtq/recycle_pool.rs new file mode 100644 index 000000000..4bcf9978a --- /dev/null +++ b/src/hyperlight_common/src/virtq/recycle_pool.rs @@ -0,0 +1,120 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! A simple fixed-size buffer recycler for H2G prefill entries. +//! +//! Unlike [`super::BufferPool`] which uses a bitmap allocator, this +//! holds a fixed set of same-sized buffer addresses in a free list. +//! Alloc and dealloc are O(1). Intended for H2G writable buffers +//! that are pre-allocated once and recycled after each use. + +use alloc::sync::Arc; + +use atomic_refcell::AtomicRefCell; +use smallvec::SmallVec; + +use super::{AllocError, Allocation, BufferProvider}; + +/// A recycling buffer provider with fixed-size slots. +#[derive(Clone)] +pub struct RecyclePool { + inner: Arc>, +} + +struct RecyclePoolInner { + base_addr: u64, + slot_size: usize, + count: usize, + free: SmallVec<[u64; 64]>, +} + +impl RecyclePool { + /// Create a new recycling pool by carving `base..base+region_len` into slots of `slot_size` bytes. + pub fn new(base_addr: u64, region_len: usize, slot_size: usize) -> Result { + if slot_size == 0 { + return Err(AllocError::InvalidArg); + } + + let count = region_len / slot_size; + if count == 0 { + return Err(AllocError::EmptyRegion); + } + + let mut free = SmallVec::with_capacity(count); + for i in 0..count { + free.push(base_addr + (i * slot_size) as u64); + } + + let inner = AtomicRefCell::new(RecyclePoolInner { + base_addr, + slot_size, + count, + free, + }); + + Ok(Self { + inner: inner.into(), + }) + } + + /// Number of free slots. 
+ pub fn num_free(&self) -> usize { + self.inner.borrow().free.len() + } +} + +impl BufferProvider for RecyclePool { + fn alloc(&self, len: usize) -> Result { + let mut inner = self.inner.borrow_mut(); + if len > inner.slot_size { + return Err(AllocError::OutOfMemory); + } + + let addr = inner.free.pop().ok_or(AllocError::OutOfMemory)?; + + Ok(Allocation { + addr, + len: inner.slot_size, + }) + } + + fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError> { + let mut inner = self.inner.borrow_mut(); + inner.free.push(alloc.addr); + Ok(()) + } + + fn resize(&self, old: Allocation, new_len: usize) -> Result { + let inner = self.inner.borrow(); + if new_len > inner.slot_size { + return Err(AllocError::OutOfMemory); + } + Ok(old) + } + + fn reset(&self) { + let mut inner = self.inner.borrow_mut(); + let base = inner.base_addr; + let slot = inner.slot_size; + let count = inner.count; + + inner.free.clear(); + + for i in 0..count { + inner.free.push(base + (i * slot) as u64); + } + } +} diff --git a/src/hyperlight_guest/src/virtq/context.rs b/src/hyperlight_guest/src/virtq/context.rs index ea846cae7..53d90068c 100644 --- a/src/hyperlight_guest/src/virtq/context.rs +++ b/src/hyperlight_guest/src/virtq/context.rs @@ -16,9 +16,7 @@ limitations under the License. //! Guest virtqueue context. 
-use alloc::sync::Arc; use alloc::vec::Vec; -use core::num::NonZeroU16; use core::sync::atomic::AtomicU16; use core::sync::atomic::Ordering::Relaxed; @@ -28,8 +26,10 @@ use hyperlight_common::flatbuffer_wrappers::function_types::{ FunctionCallResult, ParameterValue, ReturnType, ReturnValue, }; use hyperlight_common::flatbuffer_wrappers::util::estimate_flatbuffer_capacity; +use hyperlight_common::mem::PAGE_SIZE_USIZE; use hyperlight_common::outb::OutBAction; use hyperlight_common::virtq::msg::{MsgKind, VirtqMsgHeader}; +use hyperlight_common::virtq::recycle_pool::RecyclePool; use hyperlight_common::virtq::{BufferPool, Layout, Notifier, QueueStats, VirtqProducer}; use super::GuestMemOps; @@ -50,41 +50,58 @@ impl Notifier for GuestNotifier { } /// Type alias for the guest-side G2H producer. -pub type G2hProducer = VirtqProducer>; +pub type G2hProducer = VirtqProducer; + +/// Type alias for the guest-side H2G producer (uses fixed-size RecyclePool slots). +pub type H2gProducer = VirtqProducer; + +/// Configuration for one queue passed to [`GuestContext::new`]. +pub struct QueueConfig { + /// Ring descriptor layout in shared memory. + pub layout: Layout, + /// Base GVA of the buffer pool region. + pub pool_gva: u64, + /// Number of pages in the buffer pool. + pub pool_pages: usize, +} /// Virtqueue runtime state for guest-host communication. pub struct GuestContext { - g2h_pool: Arc, g2h_producer: G2hProducer, + h2g_producer: H2gProducer, generation: u64, } impl GuestContext { - /// Create a new context with a G2H queue. - /// - /// # Safety - /// - /// `ring_gva` must point to valid, zeroed ring memory. - /// `pool_gva` must point to valid, zeroed memory. 
- pub unsafe fn new( - ring_gva: u64, - num_descs: u16, - pool_gva: u64, - pool_size: usize, - generation: u64, - ) -> Self { - let pool = Arc::new( - BufferPool::new(pool_gva, pool_size).expect("failed to create G2H buffer pool"), - ); - let nz = NonZeroU16::new(num_descs).expect("G2H queue depth must be non-zero"); - let layout = unsafe { Layout::from_base(ring_gva, nz) }.expect("invalid G2H ring layout"); - let producer = VirtqProducer::new(layout, GuestMemOps, GuestNotifier, pool.clone()); + /// Create a new context with G2H and H2G queues. + pub fn new(g2h: QueueConfig, h2g: QueueConfig, generation: u64) -> Self { + let size = g2h.pool_pages * PAGE_SIZE_USIZE; + let g2h_pool = + BufferPool::new(g2h.pool_gva, size).expect("failed to create G2H buffer pool"); + let g2h_producer = + VirtqProducer::new(g2h.layout, GuestMemOps, GuestNotifier, g2h_pool.clone()); - Self { - g2h_pool: pool, - g2h_producer: producer, + // Each H2G prefill entry is a single descriptor with one contiguous buffer: one + // fixed-size buffer per descriptor, large payloads split across multiple independent + // completions. + // + // TODO(virtq): consider smaller slot_size (e.g. pool_size / desc_count) to maximize + // prefilled entries for host-side call batching. + let size = h2g.pool_pages * PAGE_SIZE_USIZE; + let slot = PAGE_SIZE_USIZE; + let h2g_pool = + RecyclePool::new(h2g.pool_gva, size, slot).expect("failed to create H2G recycle pool"); + let h2g_producer = + VirtqProducer::new(h2g.layout, GuestMemOps, GuestNotifier, h2g_pool.clone()); + + let mut ctx = Self { + g2h_producer, + h2g_producer, generation, - } + }; + + ctx.prefill_h2g(); + ctx } /// Call a host function via the G2H virtqueue. @@ -146,10 +163,95 @@ impl GuestContext { Ok(ret) } + /// Pre-fill the H2G queue with completion-only descriptors so the host + /// can write incoming call payloads into them. 
+ fn prefill_h2g(&mut self) { + loop { + let entry = match self + .h2g_producer + .chain() + .completion(PAGE_SIZE_USIZE) + .build() + { + Ok(e) => e, + Err(_) => break, + }; + if self.h2g_producer.submit(entry).is_err() { + break; + } + } + } + + /// Receive a host-to-guest function call from the H2G queue. + pub fn recv_h2g_call(&mut self) -> Result { + let Some(completion) = self.h2g_producer.poll()? else { + bail!("H2G: no pending call"); + }; + + let data = &completion.data; + if data.len() < VirtqMsgHeader::SIZE { + bail!("H2G: completion too short for header"); + } + + let hdr: &VirtqMsgHeader = bytemuck::from_bytes(&data[..VirtqMsgHeader::SIZE]); + + if hdr.kind != MsgKind::Request as u8 { + bail!("H2G: unexpected message kind"); + } + + let payload_end = VirtqMsgHeader::SIZE + hdr.payload_len as usize; + if payload_end > data.len() { + bail!("H2G: payload length exceeds completion data"); + } + + let payload = &data[VirtqMsgHeader::SIZE..payload_end]; + let fc = FunctionCall::try_from(payload)?; + Ok(fc) + } + + /// Send the result of a host-to-guest call back to the host via the + /// G2H queue, then refill one H2G descriptor slot. + pub fn send_h2g_result(&mut self, payload: &[u8]) -> Result<()> { + // Build a Response message on the G2H queue + let reqid = REQUEST_ID.fetch_add(1, Relaxed); + let hdr = VirtqMsgHeader::new(MsgKind::Response, reqid, payload.len() as u32); + let hdr_bytes = bytemuck::bytes_of(&hdr); + + let entry_len = VirtqMsgHeader::SIZE + payload.len(); + let mut entry = self.g2h_producer.chain().entry(entry_len).build()?; + + entry.write_all(hdr_bytes)?; + entry.write_all(payload)?; + self.g2h_producer.submit(entry)?; + + // Refill one H2G completion slot + if let Ok(e) = self + .h2g_producer + .chain() + .completion(PAGE_SIZE_USIZE) + .build() + { + let _ = self.h2g_producer.submit(e); + } + + Ok(()) + } + + /// Drain any pending G2H completions (discard them). 
+ /// + /// This is called before checking for H2G calls so that the host + /// can reclaim G2H response buffers. + pub fn drain_g2h_completions(&mut self) { + while let Ok(Some(_)) = self.g2h_producer.poll() {} + } + /// Reset ring and pool state after snapshot restore. pub(super) fn reset(&mut self, new_generation: u64) { self.g2h_producer.reset(); - self.g2h_pool.reset(); + // H2G state is NOT reset. The guest's inflight and cursors + // survived via CoW and are already correct. The host's + // restore_h2g_prefill() wrote matching descriptors to the + // zeroed ring memory. Both sides are in sync. self.generation = new_generation; } diff --git a/src/hyperlight_guest/src/virtq/mod.rs b/src/hyperlight_guest/src/virtq/mod.rs index 6aafc75e1..d3f592985 100644 --- a/src/hyperlight_guest/src/virtq/mod.rs +++ b/src/hyperlight_guest/src/virtq/mod.rs @@ -81,11 +81,13 @@ pub fn set_global_context(ctx: GuestContext) { /// Reset the global context if a snapshot restore was detected. /// Compares the virtq generation counter in scratch-top metadata. -pub fn reset_global_context() { +pub fn maybe_reset_global_context() { if !is_initialized() { return; } + let current_gen = read_gen(); + with_context(|ctx| { if current_gen != ctx.generation() { ctx.reset(current_gen); @@ -93,7 +95,7 @@ pub fn reset_global_context() { }); } -/// Read the current virtqueue generation from scratch-top metadata. +/// Read the current snapshot generation from scratch-top metadata. 
fn read_gen() -> u64 { unsafe { *scratch_top_ptr::(SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET) } } diff --git a/src/hyperlight_guest_bin/src/guest_function/call.rs b/src/hyperlight_guest_bin/src/guest_function/call.rs index 5db880f8a..d898dd5d3 100644 --- a/src/hyperlight_guest_bin/src/guest_function/call.rs +++ b/src/hyperlight_guest_bin/src/guest_function/call.rs @@ -26,7 +26,7 @@ use hyperlight_guest::error::{HyperlightGuestError, Result}; use hyperlight_guest::virtq; use tracing::instrument; -use crate::{GUEST_HANDLE, REGISTERED_GUEST_FUNCTIONS}; +use crate::REGISTERED_GUEST_FUNCTIONS; core::arch::global_asm!( ".weak guest_dispatch_function", @@ -99,34 +99,32 @@ pub(crate) fn internal_dispatch_function() { tracing::span!(tracing::Level::INFO, "internal_dispatch_function").entered() }; - let handle = unsafe { GUEST_HANDLE }; - // After snapshot restore, the ring memory is zeroed but the // producer's cursors are stale. Check once per dispatch entry. - virtq::reset_global_context(); + virtq::maybe_reset_global_context(); + virtq::with_context(|ctx| ctx.drain_g2h_completions()); - let function_call = handle - .try_pop_shared_input_data_into::() - .expect("Function call deserialization failed"); + let function_call = virtq::with_context(|ctx| { + ctx.recv_h2g_call() + .expect("H2G: expected a host-to-guest call") + }); let res = call_guest_function(function_call); - match res { - Ok(bytes) => { - handle - .push_shared_output_data(bytes.as_slice()) - .expect("Failed to serialize function call result"); - } + let res_bytes = match res { + Ok(bytes) => bytes, Err(err) => { let guest_error = Err(GuestError::new(err.kind, err.message)); let fcr = FunctionCallResult::new(guest_error); let mut builder = FlatBufferBuilder::new(); - let data = fcr.encode(&mut builder); - handle - .push_shared_output_data(data) - .expect("Failed to serialize function call result"); + fcr.encode(&mut builder).to_vec() } - } + }; + + virtq::with_context(|ctx| { + ctx.send_h2g_result(&res_bytes) + 
.expect("H2G: failed to send result"); + }); // All this tracing logic shall be done right before the call to `hlt` which is done after this // function returns diff --git a/src/hyperlight_guest_bin/src/virtq/mod.rs b/src/hyperlight_guest_bin/src/virtq/mod.rs index 997044dde..2621f5dd2 100644 --- a/src/hyperlight_guest_bin/src/virtq/mod.rs +++ b/src/hyperlight_guest_bin/src/virtq/mod.rs @@ -16,15 +16,18 @@ limitations under the License. //! Guest-side virtqueue initialization. +use core::num::NonZeroU16; + use hyperlight_common::layout::{ - SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET, SCRATCH_TOP_G2H_RING_GVA_OFFSET, - SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET, SCRATCH_TOP_H2G_RING_GVA_OFFSET, - SCRATCH_TOP_VIRTQ_GENERATION_OFFSET, SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET, scratch_top_ptr, + SCRATCH_TOP_G2H_POOL_PAGES_OFFSET, SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET, + SCRATCH_TOP_G2H_RING_GVA_OFFSET, SCRATCH_TOP_H2G_POOL_GVA_OFFSET, + SCRATCH_TOP_H2G_POOL_PAGES_OFFSET, SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET, + SCRATCH_TOP_H2G_RING_GVA_OFFSET, SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET, scratch_top_ptr, }; use hyperlight_common::mem::PAGE_SIZE_USIZE; use hyperlight_common::virtq::Layout as VirtqLayout; use hyperlight_guest::prim_alloc::alloc_phys_pages; -use hyperlight_guest::virtq::context::GuestContext; +use hyperlight_guest::virtq::context::{GuestContext, QueueConfig}; use crate::paging::phys_to_virt; @@ -34,12 +37,12 @@ pub(crate) fn init_virtqueues() { let g2h_depth = unsafe { *scratch_top_ptr::(SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET) }; let h2g_gva = unsafe { *scratch_top_ptr::(SCRATCH_TOP_H2G_RING_GVA_OFFSET) }; let h2g_depth = unsafe { *scratch_top_ptr::(SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET) }; - let pool_pages = unsafe { *scratch_top_ptr::(SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET) } as u64; - let generation = unsafe { *scratch_top_ptr::(SCRATCH_TOP_VIRTQ_GENERATION_OFFSET) }; + let g2h_pages = unsafe { *scratch_top_ptr::(SCRATCH_TOP_G2H_POOL_PAGES_OFFSET) } as usize; + let h2g_pages = unsafe { 
*scratch_top_ptr::(SCRATCH_TOP_H2G_POOL_PAGES_OFFSET) } as usize; + let generation = unsafe { *scratch_top_ptr::(SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET) }; - assert!(g2h_depth > 0 && h2g_depth > 0); + assert!(g2h_depth > 0 && h2g_depth > 0 && g2h_pages > 0 && h2g_pages > 0); assert!(g2h_gva != 0 && h2g_gva != 0); - assert!(pool_pages > 0); // Zero ring memory let g2h_ring_size = VirtqLayout::query_size(g2h_depth as usize); @@ -48,16 +51,41 @@ pub(crate) fn init_virtqueues() { let h2g_ring_size = VirtqLayout::query_size(h2g_depth as usize); unsafe { core::ptr::write_bytes(h2g_gva as *mut u8, 0, h2g_ring_size) }; - // Allocate buffer pool from physical pages - let pool_gpa = unsafe { alloc_phys_pages(pool_pages) }; - let pool_ptr = phys_to_virt(pool_gpa).expect("failed to map pool pages"); - let pool_gva = pool_ptr as u64; - let pool_size = pool_pages as usize * PAGE_SIZE_USIZE; - unsafe { core::ptr::write_bytes(pool_ptr, 0, pool_size) }; + // Build ring layouts + let nz = NonZeroU16::new(g2h_depth).expect("G2H depth zero"); + let g2h_layout = unsafe { VirtqLayout::from_base(g2h_gva, nz) }.expect("invalid layout"); + + let nz = NonZeroU16::new(h2g_depth).expect("H2G depth zero"); + let h2g_layout = unsafe { VirtqLayout::from_base(h2g_gva, nz) }.expect("invalid layout"); + + // Allocate buffer pools + let g2h_pool_gva = alloc_pool(g2h_pages); + let h2g_pool_gva = alloc_pool(h2g_pages); - // Create and install global context - let ctx = unsafe { GuestContext::new(g2h_gva, g2h_depth, pool_gva, pool_size, generation.into()) }; + // Publish H2G pool GVA so the host can prefill after restore + unsafe { *scratch_top_ptr::(SCRATCH_TOP_H2G_POOL_GVA_OFFSET) = h2g_pool_gva }; + + let ctx = GuestContext::new( + QueueConfig { + layout: g2h_layout, + pool_gva: g2h_pool_gva, + pool_pages: g2h_pages, + }, + QueueConfig { + layout: h2g_layout, + pool_gva: h2g_pool_gva, + pool_pages: h2g_pages, + }, + generation, + ); hyperlight_guest::virtq::set_global_context(ctx); +} - let _ = 
(h2g_gva, h2g_depth); +/// Allocate and zero `n` physical pages, returning the GVA. +fn alloc_pool(n: usize) -> u64 { + let gpa = unsafe { alloc_phys_pages(n as u64) }; + let ptr = phys_to_virt(gpa).expect("failed to map pool pages"); + let size = n as usize * PAGE_SIZE_USIZE; + unsafe { core::ptr::write_bytes(ptr, 0, size) }; + ptr as u64 } diff --git a/src/hyperlight_host/src/mem/layout.rs b/src/hyperlight_host/src/mem/layout.rs index 183cdc2af..2b821a965 100644 --- a/src/hyperlight_host/src/mem/layout.rs +++ b/src/hyperlight_host/src/mem/layout.rs @@ -494,6 +494,7 @@ impl SandboxMemoryLayout { } /// Get the size of the G2H ring in bytes. + #[allow(dead_code)] fn get_g2h_ring_size(&self) -> usize { hyperlight_common::virtq::Layout::query_size( self.sandbox_memory_config.get_g2h_queue_depth(), diff --git a/src/hyperlight_host/src/mem/mgr.rs b/src/hyperlight_host/src/mem/mgr.rs index b35077a16..dbe052f8b 100644 --- a/src/hyperlight_host/src/mem/mgr.rs +++ b/src/hyperlight_host/src/mem/mgr.rs @@ -21,24 +21,12 @@ use flatbuffers::FlatBufferBuilder; use hyperlight_common::flatbuffer_wrappers::function_call::{ FunctionCall, validate_guest_function_call_buffer, }; - -use super::virtq_mem::HostMemOps; - -/// No-op notifier for host-side consumer. -/// The host resumes the VM to notify the guest, not via the ring. -#[derive(Clone, Copy)] -pub(crate) struct HostNotifier; - -impl hyperlight_common::virtq::Notifier for HostNotifier { - fn notify(&self, _stats: hyperlight_common::virtq::QueueStats) {} -} - -/// Type alias for the host-side G2H virtqueue consumer. 
-pub(crate) type G2hConsumer = hyperlight_common::virtq::VirtqConsumer; use hyperlight_common::flatbuffer_wrappers::function_types::FunctionCallResult; use hyperlight_common::flatbuffer_wrappers::guest_log_data::GuestLogData; -use hyperlight_common::virtq::Layout as VirtqLayout; -use hyperlight_common::vmem::{self, PAGE_TABLE_SIZE, PageTableEntry, PhysAddr}; +use hyperlight_common::mem::PAGE_SIZE_USIZE; +use hyperlight_common::virtq::msg::{MsgKind, VirtqMsgHeader}; +use hyperlight_common::virtq::{self, Layout as VirtqLayout}; +use hyperlight_common::vmem::{self, PAGE_TABLE_SIZE}; #[cfg(all(feature = "crashdump", not(feature = "i686-guest")))] use hyperlight_common::vmem::{BasicMapping, MappingKind}; use tracing::{Span, instrument}; @@ -47,6 +35,7 @@ use super::layout::SandboxMemoryLayout; use super::shared_mem::{ ExclusiveSharedMemory, GuestSharedMemory, HostSharedMemory, ReadonlySharedMemory, SharedMemory, }; +use super::virtq_mem::HostMemOps; use crate::hypervisor::regs::CommonSpecialRegisters; use crate::mem::memory_region::MemoryRegion; #[cfg(crashdump)] @@ -54,6 +43,20 @@ use crate::mem::memory_region::{CrashDumpRegion, MemoryRegionFlags, MemoryRegion use crate::sandbox::snapshot::{NextAction, Snapshot}; use crate::{Result, new_error}; +/// Type alias for the host-side G2H virtqueue consumer. +pub(crate) type G2hConsumer = virtq::VirtqConsumer; +/// Type alias for the host-side H2G virtqueue consumer. +pub(crate) type H2gConsumer = virtq::VirtqConsumer; + +/// No-op notifier for host-side consumer. +/// The host resumes the VM to notify the guest, not via the ring. 
+#[derive(Clone, Copy)] +pub(crate) struct HostNotifier; + +impl virtq::Notifier for HostNotifier { + fn notify(&self, _stats: virtq::QueueStats) {} +} + #[cfg(all(feature = "crashdump", not(feature = "i686-guest")))] fn mapping_kind_to_flags(kind: &MappingKind) -> (MemoryRegionFlags, MemoryRegionType) { match kind { @@ -171,9 +174,13 @@ pub(crate) struct SandboxMemoryManager { pub(crate) snapshot_count: u64, /// G2H virtqueue consumer, created after sandbox init. pub(crate) g2h_consumer: Option, + /// H2G virtqueue consumer, created after sandbox init. + pub(crate) h2g_consumer: Option, + /// Saved H2G pool GVA for prefilling after snapshot restore. + pub(crate) h2g_pool_gva: Option, } -impl Clone for SandboxMemoryManager { +impl Clone for SandboxMemoryManager { fn clone(&self) -> Self { Self { shared_mem: self.shared_mem.clone(), @@ -183,7 +190,9 @@ impl Clone for SandboxMemoryManager { mapped_rgns: self.mapped_rgns, abort_buffer: self.abort_buffer.clone(), snapshot_count: self.snapshot_count, - g2h_consumer: None, // consumer is not cloned; re-init if needed + g2h_consumer: None, + h2g_consumer: None, + h2g_pool_gva: self.h2g_pool_gva, } } } @@ -322,6 +331,8 @@ where abort_buffer: Vec::new(), snapshot_count: 0, g2h_consumer: None, + h2g_consumer: None, + h2g_pool_gva: None, } } @@ -393,6 +404,8 @@ impl SandboxMemoryManager { abort_buffer: self.abort_buffer, snapshot_count: self.snapshot_count, g2h_consumer: None, + h2g_consumer: None, + h2g_pool_gva: None, }; let guest_mgr = SandboxMemoryManager { shared_mem: gshm, @@ -403,9 +416,12 @@ impl SandboxMemoryManager { abort_buffer: Vec::new(), // Guest doesn't need abort buffer snapshot_count: self.snapshot_count, g2h_consumer: None, + h2g_consumer: None, + h2g_pool_gva: None, }; host_mgr.update_scratch_bookkeeping()?; host_mgr.init_g2h_consumer()?; + host_mgr.init_h2g_consumer()?; Ok((host_mgr, guest_mgr)) } } @@ -499,6 +515,7 @@ impl SandboxMemoryManager { /// Writes a guest function call to memory 
#[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] + #[allow(dead_code)] pub(crate) fn write_guest_function_call(&mut self, buffer: &[u8]) -> Result<()> { validate_guest_function_call_buffer(buffer).map_err(|e| { new_error!( @@ -517,6 +534,7 @@ impl SandboxMemoryManager { /// Reads a function call result from memory. /// A function call result can be either an error or a successful return value. + #[allow(dead_code)] #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] pub(crate) fn get_guest_function_call_result(&mut self) -> Result { self.scratch_mem.try_pop_buffer_into::( @@ -603,6 +621,8 @@ impl SandboxMemoryManager { self.update_scratch_bookkeeping()?; self.init_g2h_consumer()?; + self.init_h2g_consumer()?; + self.restore_h2g_prefill()?; Ok((gsnapshot, gscratch)) } @@ -668,14 +688,13 @@ impl SandboxMemoryManager { self.layout.sandbox_memory_config.get_h2g_queue_depth() as u16, )?; self.scratch_mem.write::( - scratch_size - SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET as usize, - self.layout.sandbox_memory_config.get_virtq_pool_pages() as u16, + scratch_size - SCRATCH_TOP_G2H_POOL_PAGES_OFFSET as usize, + self.layout.sandbox_memory_config.get_g2h_pool_pages() as u16, + )?; + self.scratch_mem.write::( + scratch_size - SCRATCH_TOP_H2G_POOL_PAGES_OFFSET as usize, + self.layout.sandbox_memory_config.get_h2g_pool_pages() as u16, )?; - // Increment generation so the guest detects stale ring state. - let gen_offset = scratch_size - SCRATCH_TOP_VIRTQ_GENERATION_OFFSET as usize; - let gen_val: u16 = self.scratch_mem.read(gen_offset).unwrap_or(0); - self.scratch_mem - .write::(gen_offset, gen_val.wrapping_add(1))?; // Copy page tables from `shared_mem` into scratch. PT bytes // are appended to the snapshot blob at build time and live @@ -923,7 +942,7 @@ impl SandboxMemoryManager { } /// Compute the G2H virtqueue Layout from scratch region addresses. 
- pub(crate) fn g2h_virtq_layout(&self) -> Result { + pub(crate) fn g2h_virtq_layout(&self) -> Result { let base = self.layout.get_g2h_ring_gva(); let depth = self.layout.sandbox_memory_config.get_g2h_queue_depth() as u16; @@ -934,7 +953,7 @@ impl SandboxMemoryManager { } /// Compute the H2G virtqueue Layout from scratch region addresses. - pub(crate) fn h2g_virtq_layout(&self) -> Result { + pub(crate) fn h2g_virtq_layout(&self) -> Result { let base = self.layout.get_h2g_ring_gva(); let depth = self.layout.sandbox_memory_config.get_h2g_queue_depth() as u16; @@ -962,13 +981,152 @@ impl SandboxMemoryManager { None => { let layout = self.g2h_virtq_layout()?; let mem_ops = self.host_mem_ops(); - let consumer = - hyperlight_common::virtq::VirtqConsumer::new(layout, mem_ops, HostNotifier); + let consumer = virtq::VirtqConsumer::new(layout, mem_ops, HostNotifier); self.g2h_consumer = Some(consumer); } } Ok(()) } + + /// Initialize the H2G virtqueue consumer. + /// + /// Must be called after scratch bookkeeping is written. Avail suppression is set to Disable + /// so guest prefill/refill operations do not trigger VM exits. + pub(crate) fn init_h2g_consumer(&mut self) -> Result<()> { + match &mut self.h2g_consumer { + Some(consumer) => { + consumer.reset(); + consumer + .set_avail_suppression(virtq::SuppressionKind::Disable) + .map_err(|e| new_error!("H2G avail suppression: {:?}", e))?; + } + None => { + let layout = self.h2g_virtq_layout()?; + let mem_ops = self.host_mem_ops(); + let mut consumer = virtq::VirtqConsumer::new(layout, mem_ops, HostNotifier); + consumer + .set_avail_suppression(virtq::SuppressionKind::Disable) + .map_err(|e| new_error!("H2G avail suppression: {:?}", e))?; + self.h2g_consumer = Some(consumer); + } + } + Ok(()) + } + + /// Prefill the H2G ring with writable descriptors after snapshot restore. + /// + /// Uses a temporary `RingProducer` to write descriptors into the H2G ring + /// so the host consumer can poll them. 
The guest's `restore_from_ring` + /// will later reconstruct its inflight state from these descriptors. + pub(crate) fn restore_h2g_prefill(&mut self) -> Result<()> { + let pool_gva = match self.h2g_pool_gva { + Some(gva) => gva, + None => return Ok(()), + }; + + let layout = self.h2g_virtq_layout()?; + let mem_ops = self.host_mem_ops(); + let h2g_depth = self.layout.sandbox_memory_config.get_h2g_queue_depth(); + + // Pool size from config + let slot_size = PAGE_SIZE_USIZE; + let pool_size = self.layout.sandbox_memory_config.get_h2g_pool_pages() * PAGE_SIZE_USIZE; + let slot_count = pool_size / slot_size; + + let mut producer = virtq::RingProducer::new(layout, mem_ops); + let prefill_count = core::cmp::min(slot_count, h2g_depth); + + // Write descriptors in reverse order to match the guest's LIFO + // allocation pattern (RecyclePool::alloc pops from the end of + // the free list, so the first prefill gets the highest address). + for i in (0..prefill_count).rev() { + let addr = pool_gva + (i * slot_size) as u64; + producer + .submit_one(addr, slot_size as u32, true) + .map_err(|e| new_error!("H2G prefill submit: {:?}", e))?; + } + + Ok(()) + } + + /// Write a guest function call into the H2G virtqueue. + /// + /// Polls the H2G consumer for a prefilled entry from the guest, + /// writes `VirtqMsgHeader::Request` followed by `buffer` into the + /// writable completion, and completes the entry. + pub(crate) fn write_guest_function_call_virtq(&mut self, buffer: &[u8]) -> Result<()> { + let consumer = self + .h2g_consumer + .as_mut() + .ok_or_else(|| new_error!("H2G consumer not initialized"))?; + + let (entry, completion) = consumer + .poll(8192) + .map_err(|e| new_error!("H2G poll: {:?}", e))? 
+ .ok_or_else(|| new_error!("H2G: no prefilled entry available"))?; + + // Consume the entry data - this should be empty + drop(entry); + + let header = VirtqMsgHeader::new(MsgKind::Request, 0, buffer.len() as u32); + + let virtq::SendCompletion::Writable(mut wc) = completion else { + return Err(new_error!( + "H2G: expected writable completion, got non-writable (ring corruption)" + )); + }; + + wc.write_all(bytemuck::bytes_of(&header)) + .map_err(|e| new_error!("H2G write header: {:?}", e))?; + wc.write_all(buffer) + .map_err(|e| new_error!("H2G write payload: {:?}", e))?; + + consumer + .complete(wc.into()) + .map_err(|e| new_error!("H2G complete: {:?}", e))?; + + Ok(()) + } + + /// Read the H2G result from G2H after the guest halts. + /// + /// The guest submitted the Response on G2H with a `Response`-kind header followed by the serialized result payload. + pub(crate) fn read_h2g_result_from_g2h(&mut self) -> Result { + let consumer = self + .g2h_consumer + .as_mut() + .ok_or_else(|| new_error!("G2H consumer not initialized"))?; + + let Some((entry, completion)) = consumer + .poll(8192) + .map_err(|e| new_error!("G2H poll for H2G result: {:?}", e))?
+ else { + return Err(new_error!("G2H: no H2G result entry after halt")); + }; + + let entry_data = entry.data(); + if entry_data.len() < VirtqMsgHeader::SIZE { + return Err(new_error!("G2H: result entry too short")); + } + + let hdr: &VirtqMsgHeader = bytemuck::from_bytes(&entry_data[..VirtqMsgHeader::SIZE]); + if hdr.kind != MsgKind::Response as u8 { + return Err(new_error!( + "G2H: expected Response after halt, got kind={}", + hdr.kind + )); + } + + let payload = &entry_data[VirtqMsgHeader::SIZE..]; + let fcr = FunctionCallResult::try_from(payload) + .map_err(|e| new_error!("G2H: malformed FunctionCallResult: {}", e))?; + + consumer + .complete(completion) + .map_err(|e| new_error!("G2H complete: {:?}", e))?; + + Ok(fcr) + } } #[cfg(test)] diff --git a/src/hyperlight_host/src/sandbox/config.rs b/src/hyperlight_host/src/sandbox/config.rs index a329e5fd5..b3e5fd6d3 100644 --- a/src/hyperlight_host/src/sandbox/config.rs +++ b/src/hyperlight_host/src/sandbox/config.rs @@ -80,9 +80,14 @@ pub struct SandboxConfiguration { /// Number of descriptors for the host-to-guest virtqueue. Must be a power of 2. /// Default: 32 h2g_queue_depth: usize, - /// Number of physical pages to allocate for each virtqueue's buffer pool. + /// Number of physical pages for the G2H (guest-to-host) buffer pool. + /// If not set, derived from `output_data_size` for backward compatibility. /// Default: 8 pages (32KB). - virtq_pool_pages: usize, + g2h_pool_pages: Option, + /// Number of physical pages for the H2G (host-to-guest) buffer pool. + /// If not set, derived from `input_data_size` for backward compatibility. + /// Default: 4 pages (16KB). 
+ h2g_pool_pages: Option, } impl SandboxConfiguration { @@ -106,8 +111,10 @@ impl SandboxConfiguration { pub const DEFAULT_G2H_QUEUE_DEPTH: usize = 64; /// The default H2G virtqueue depth (number of descriptors, must be power of 2) pub const DEFAULT_H2G_QUEUE_DEPTH: usize = 32; - /// The default number of physical pages per virtqueue buffer pool - pub const DEFAULT_VIRTQ_POOL_PAGES: usize = 8; + /// The default number of G2H buffer pool pages + pub const DEFAULT_G2H_POOL_PAGES: usize = 8; + /// The default number of H2G buffer pool pages + pub const DEFAULT_H2G_POOL_PAGES: usize = 4; #[allow(clippy::too_many_arguments)] /// Create a new configuration for a sandbox with the given sizes. @@ -131,7 +138,8 @@ impl SandboxConfiguration { interrupt_vcpu_sigrtmin_offset, g2h_queue_depth: Self::DEFAULT_G2H_QUEUE_DEPTH, h2g_queue_depth: Self::DEFAULT_H2G_QUEUE_DEPTH, - virtq_pool_pages: Self::DEFAULT_VIRTQ_POOL_PAGES, + g2h_pool_pages: None, + h2g_pool_pages: None, #[cfg(gdb)] guest_debug_info, #[cfg(crashdump)] @@ -139,15 +147,21 @@ impl SandboxConfiguration { } } - /// Set the size of the memory buffer that is made available for input to the guest - /// the minimum value is MIN_INPUT_SIZE + /// Set the size of the legacy input data buffer (host-to-guest). + /// + /// Deprecated: use [`set_h2g_pool_pages`](Self::set_h2g_pool_pages) instead. + /// When `h2g_pool_pages` is not set, the H2G pool size is derived + /// from this value for backward compatibility. #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub fn set_input_data_size(&mut self, input_data_size: usize) { self.input_data_size = max(input_data_size, Self::MIN_INPUT_SIZE); } - /// Set the size of the memory buffer that is made available for output from the guest - /// the minimum value is MIN_OUTPUT_SIZE + /// Set the size of the legacy output data buffer (guest-to-host). + /// + /// Deprecated: use [`set_g2h_pool_pages`](Self::set_g2h_pool_pages) instead. 
+ /// When `g2h_pool_pages` is not set, the G2H pool size is derived + /// from this value for backward compatibility. #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub fn set_output_data_size(&mut self, output_data_size: usize) { self.output_data_size = max(output_data_size, Self::MIN_OUTPUT_SIZE); @@ -228,33 +242,65 @@ impl SandboxConfiguration { } /// Get the G2H virtqueue depth (number of descriptors). + #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub fn get_g2h_queue_depth(&self) -> usize { self.g2h_queue_depth } /// Get the H2G virtqueue depth (number of descriptors). + #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub fn get_h2g_queue_depth(&self) -> usize { self.h2g_queue_depth } - /// Get the number of physical pages per virtqueue buffer pool. - pub fn get_virtq_pool_pages(&self) -> usize { - self.virtq_pool_pages - } - /// Set the G2H virtqueue depth (number of descriptors, must be power of 2). + #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub fn set_g2h_queue_depth(&mut self, depth: usize) { self.g2h_queue_depth = depth; } /// Set the H2G virtqueue depth (number of descriptors, must be power of 2). + #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub fn set_h2g_queue_depth(&mut self, depth: usize) { self.h2g_queue_depth = depth; } - /// Set the number of physical pages per virtqueue buffer pool. - pub fn set_virtq_pool_pages(&mut self, pages: usize) { - self.virtq_pool_pages = pages; + /// Get the number of G2H buffer pool pages. + /// Falls back to deriving from `output_data_size` if not explicitly set + /// (output = guest-to-host direction). 
+ #[instrument(skip_all, parent = Span::current(), level= "Trace")] + pub fn get_g2h_pool_pages(&self) -> usize { + self.g2h_pool_pages.unwrap_or_else(|| { + let pages = self + .output_data_size + .div_ceil(hyperlight_common::mem::PAGE_SIZE_USIZE); + pages.max(Self::DEFAULT_G2H_POOL_PAGES) + }) + } + + /// Get the number of H2G buffer pool pages. + /// Falls back to deriving from `input_data_size` if not explicitly set + /// (input = host-to-guest direction). + #[instrument(skip_all, parent = Span::current(), level= "Trace")] + pub fn get_h2g_pool_pages(&self) -> usize { + self.h2g_pool_pages.unwrap_or_else(|| { + let pages = self + .input_data_size + .div_ceil(hyperlight_common::mem::PAGE_SIZE_USIZE); + pages.max(Self::DEFAULT_H2G_POOL_PAGES) + }) + } + + /// Set the number of G2H buffer pool pages. + #[instrument(skip_all, parent = Span::current(), level= "Trace")] + pub fn set_g2h_pool_pages(&mut self, pages: usize) { + self.g2h_pool_pages = Some(pages); + } + + /// Set the number of H2G buffer pool pages. 
+ #[instrument(skip_all, parent = Span::current(), level= "Trace")] + pub fn set_h2g_pool_pages(&mut self, pages: usize) { + self.h2g_pool_pages = Some(pages); } /// Set the size of the scratch regiong diff --git a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs index 8b3cf8db2..4a097358c 100644 --- a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs +++ b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs @@ -737,7 +737,7 @@ impl MultiUseSandbox { let mut builder = FlatBufferBuilder::with_capacity(estimated_capacity); let buffer = fc.encode(&mut builder); - self.mem_mgr.write_guest_function_call(buffer)?; + self.mem_mgr.write_guest_function_call_virtq(buffer)?; let dispatch_res = self.vm.dispatch_call_from_host( &mut self.mem_mgr, @@ -754,7 +754,7 @@ impl MultiUseSandbox { return Err(error); } - let guest_result = self.mem_mgr.get_guest_function_call_result()?.into_inner(); + let guest_result = self.mem_mgr.read_h2g_result_from_g2h()?.into_inner(); match guest_result { Ok(val) => Ok(val), diff --git a/src/hyperlight_host/src/sandbox/outb.rs b/src/hyperlight_host/src/sandbox/outb.rs index aa40bec3d..0e11409ad 100644 --- a/src/hyperlight_host/src/sandbox/outb.rs +++ b/src/hyperlight_host/src/sandbox/outb.rs @@ -21,6 +21,8 @@ use hyperlight_common::flatbuffer_wrappers::function_types::{FunctionCallResult, use hyperlight_common::flatbuffer_wrappers::guest_error::{ErrorCode, GuestError}; use hyperlight_common::flatbuffer_wrappers::guest_log_data::GuestLogData; use hyperlight_common::outb::{Exception, OutBAction}; +use hyperlight_common::virtq::msg::{MsgKind, VirtqMsgHeader}; +use hyperlight_common::virtq::{self}; use log::{Level, Record}; use tracing::{Span, instrument}; use tracing_log::format_trace; @@ -186,29 +188,39 @@ fn outb_virtq_call( mem_mgr: &mut SandboxMemoryManager, host_funcs: &Arc>, ) -> Result<(), HandleOutbError> { - use hyperlight_common::virtq::msg::{MsgKind, 
VirtqMsgHeader}; - let consumer = mem_mgr.g2h_consumer.as_mut().ok_or_else(|| { HandleOutbError::ReadHostFunctionCall("G2H consumer not initialized".into()) })?; - let (entry, completion) = consumer + let Some((entry, completion)) = consumer .poll(8192) .map_err(|e| HandleOutbError::ReadHostFunctionCall(format!("G2H poll: {e}")))? - .ok_or_else(|| HandleOutbError::ReadHostFunctionCall("G2H poll: no entry".into()))?; + else { + // No G2H entry - can happen when guest H2G prefill + // triggers VirtqNotify before suppression is set. + return Ok(()); + }; - // Parse: skip VirtqMsgHeader, deserialize FunctionCall from remainder let entry_data = entry.data(); if entry_data.len() < VirtqMsgHeader::SIZE { return Err(HandleOutbError::ReadHostFunctionCall( "G2H entry too short".into(), )); } + let hdr: VirtqMsgHeader = *bytemuck::from_bytes(&entry_data[..VirtqMsgHeader::SIZE]); let payload = &entry_data[VirtqMsgHeader::SIZE..]; + + // TODO(virtq): Only Requests (host function callbacks) arrive via outb. 
+ if hdr.kind != MsgKind::Request as u8 { + return Err(HandleOutbError::ReadHostFunctionCall(format!( + "G2H: expected Request via outb, got kind={}", + hdr.kind + ))); + } + let call = FunctionCall::try_from(payload) .map_err(|e| HandleOutbError::ReadHostFunctionCall(e.to_string()))?; - // Dispatch the host function (same as CallFunction path) let name = call.function_name.clone(); let args: Vec = call.parameters.unwrap_or(vec![]); let res = host_funcs @@ -226,22 +238,19 @@ fn outb_virtq_call( let resp_header_bytes = bytemuck::bytes_of(&resp_header); // Write response into the completion buffer - match completion { - hyperlight_common::virtq::SendCompletion::Writable(mut wc) => { - wc.write_all(resp_header_bytes) - .map_err(|e| HandleOutbError::WriteHostFunctionResponse(format!("{e}")))?; - wc.write_all(result_payload) - .map_err(|e| HandleOutbError::WriteHostFunctionResponse(format!("{e}")))?; - consumer - .complete(wc.into()) - .map_err(|e| HandleOutbError::WriteHostFunctionResponse(format!("{e}")))?; - } - hyperlight_common::virtq::SendCompletion::Ack(ack) => { - consumer - .complete(ack.into()) - .map_err(|e| HandleOutbError::WriteHostFunctionResponse(format!("{e}")))?; - } - } + let virtq::SendCompletion::Writable(mut wc) = completion else { + return Err(HandleOutbError::WriteHostFunctionResponse( + "G2H: expected writable completion, got ack (ring corruption)".into(), + )); + }; + + wc.write_all(resp_header_bytes) + .map_err(|e| HandleOutbError::WriteHostFunctionResponse(format!("{e}")))?; + wc.write_all(result_payload) + .map_err(|e| HandleOutbError::WriteHostFunctionResponse(format!("{e}")))?; + consumer + .complete(wc.into()) + .map_err(|e| HandleOutbError::WriteHostFunctionResponse(format!("{e}")))?; Ok(()) } diff --git a/src/hyperlight_host/src/sandbox/uninitialized_evolve.rs b/src/hyperlight_host/src/sandbox/uninitialized_evolve.rs index 7f0cc1c0d..32ec1a2b6 100644 --- a/src/hyperlight_host/src/sandbox/uninitialized_evolve.rs +++ 
b/src/hyperlight_host/src/sandbox/uninitialized_evolve.rs @@ -16,6 +16,7 @@ limitations under the License. #[cfg(gdb)] use std::sync::{Arc, Mutex}; +use hyperlight_common::layout::SCRATCH_TOP_H2G_POOL_GVA_OFFSET; use rand::RngExt; use tracing::{Span, instrument}; @@ -26,7 +27,7 @@ use crate::hypervisor::hyperlight_vm::{HyperlightVm, HyperlightVmError}; use crate::mem::exe::LoadInfo; use crate::mem::mgr::SandboxMemoryManager; use crate::mem::ptr::RawPtr; -use crate::mem::shared_mem::GuestSharedMemory; +use crate::mem::shared_mem::{GuestSharedMemory, SharedMemory}; #[cfg(gdb)] use crate::sandbox::config::DebugInfo; #[cfg(feature = "mem_profile")] @@ -131,6 +132,18 @@ pub(super) fn evolve_impl_multi_use(u_sbox: UninitializedSandbox) -> Result(offset) + && gva != 0 + { + hshm.h2g_pool_gva = Some(gva); + } + } + #[cfg(gdb)] let dbg_mem_wrapper = Arc::new(Mutex::new(hshm.clone())); diff --git a/src/hyperlight_host/tests/integration_test.rs b/src/hyperlight_host/tests/integration_test.rs index 9e7fe2c91..928b63466 100644 --- a/src/hyperlight_host/tests/integration_test.rs +++ b/src/hyperlight_host/tests/integration_test.rs @@ -546,7 +546,8 @@ fn guest_malloc_abort() { cfg.set_heap_size(heap_size); cfg.set_g2h_queue_depth(2); cfg.set_h2g_queue_depth(2); - cfg.set_virtq_pool_pages(2); + cfg.set_g2h_pool_pages(3); + cfg.set_h2g_pool_pages(1); with_rust_sandbox_cfg(cfg, |mut sbox2| { let err = sbox2 .call::( @@ -625,7 +626,8 @@ fn guest_panic_no_alloc() { cfg.set_heap_size(heap_size); cfg.set_g2h_queue_depth(2); cfg.set_h2g_queue_depth(2); - cfg.set_virtq_pool_pages(2); + cfg.set_g2h_pool_pages(3); + cfg.set_h2g_pool_pages(1); with_rust_sandbox_cfg(cfg, |mut sbox| { let res = sbox .call::( @@ -1686,7 +1688,7 @@ fn exception_handler_installation_and_validation() { #[test] fn fill_heap_and_cause_exception() { let mut cfg = SandboxConfiguration::default(); - cfg.set_virtq_pool_pages(2); + cfg.set_scratch_size(0x60000); with_rust_sandbox_cfg(cfg, |mut sandbox| { let result = 
sandbox.call::<()>("FillHeapAndCauseException", ()); From b9f805e3f62018544656e3c91143a85487547275 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Fri, 3 Apr 2026 14:16:11 +0200 Subject: [PATCH 08/31] feat(virtq): cleanup send + sync bounds Signed-off-by: Tomasz Andrzejak --- Cargo.lock | 7 - src/hyperlight_common/Cargo.toml | 1 - src/hyperlight_common/src/virtq/buffer.rs | 158 +++++++++ src/hyperlight_common/src/virtq/mod.rs | 5 +- src/hyperlight_common/src/virtq/pool.rs | 312 ++++++++---------- src/hyperlight_common/src/virtq/producer.rs | 8 +- .../src/virtq/recycle_pool.rs | 120 ------- src/hyperlight_guest/src/virtq/context.rs | 5 +- src/tests/rust_guests/dummyguest/Cargo.lock | 7 - src/tests/rust_guests/simpleguest/Cargo.lock | 7 - src/tests/rust_guests/witguest/Cargo.lock | 7 - 11 files changed, 299 insertions(+), 338 deletions(-) create mode 100644 src/hyperlight_common/src/virtq/buffer.rs delete mode 100644 src/hyperlight_common/src/virtq/recycle_pool.rs diff --git a/Cargo.lock b/Cargo.lock index a80ae0959..58f446f61 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -181,12 +181,6 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" -[[package]] -name = "atomic_refcell" -version = "0.1.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21e4227379beff4205943696e6c3e0cd809bacdf3f0edd6e3dd153e2269571a4" - [[package]] name = "autocfg" version = "1.5.0" @@ -1495,7 +1489,6 @@ version = "0.15.0" dependencies = [ "anyhow", "arbitrary", - "atomic_refcell", "bitflags 2.11.1", "bytemuck", "bytes", diff --git a/src/hyperlight_common/Cargo.toml b/src/hyperlight_common/Cargo.toml index 32a78026f..1b688cd54 100644 --- a/src/hyperlight_common/Cargo.toml +++ b/src/hyperlight_common/Cargo.toml @@ -17,7 +17,6 @@ workspace = true [dependencies] arbitrary = {version = "1.4.2", optional = true, features = ["derive"]} anyhow = { version = 
"1.0.102", default-features = false } -atomic_refcell = "0.1.13" bitflags = "2.10.0" bytemuck = { version = "1.24", features = ["derive"] } bytes = { version = "1", default-features = false } diff --git a/src/hyperlight_common/src/virtq/buffer.rs b/src/hyperlight_common/src/virtq/buffer.rs new file mode 100644 index 000000000..238775c6d --- /dev/null +++ b/src/hyperlight_common/src/virtq/buffer.rs @@ -0,0 +1,158 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Buffer allocation traits and shared types for virtqueue buffer management. + +use alloc::rc::Rc; +use alloc::sync::Arc; + +use thiserror::Error; + +use super::access::MemOps; + +#[derive(Debug, Error, Copy, Clone)] +pub enum AllocError { + #[error("Invalid region addr {0}")] + InvalidAlign(u64), + #[error("Invalid free addr {0} and size {1}")] + InvalidFree(u64, usize), + #[error("Invalid argument")] + InvalidArg, + #[error("Empty region")] + EmptyRegion, + #[error("Out of memory")] + OutOfMemory, + #[error("Overflow")] + Overflow, +} + +/// Allocation result +#[derive(Debug, Clone, Copy)] +pub struct Allocation { + /// Starting address of the allocation + pub addr: u64, + /// Length of the allocation in bytes rounded up to slab size + pub len: usize, +} + +/// Trait for buffer providers. +pub trait BufferProvider { + /// Allocate at least `len` bytes. + fn alloc(&self, len: usize) -> Result; + + /// Free a previously allocated block. 
+ fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError>; + + /// Resize by trying in-place grow; otherwise reserve a new block and free old. + fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result; + + /// Reset the pool to initial state. + fn reset(&self) {} +} + +impl BufferProvider for Rc { + fn alloc(&self, len: usize) -> Result { + (**self).alloc(len) + } + fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError> { + (**self).dealloc(alloc) + } + fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result { + (**self).resize(old_alloc, new_len) + } + fn reset(&self) { + (**self).reset() + } +} + +impl BufferProvider for Arc { + fn alloc(&self, len: usize) -> Result { + (**self).alloc(len) + } + fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError> { + (**self).dealloc(alloc) + } + fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result { + (**self).resize(old_alloc, new_len) + } + fn reset(&self) { + (**self).reset() + } +} + +/// The owner of a mapped buffer, ensuring its lifetime. +/// +/// Holds a pool allocation and provides direct access to the underlying +/// shared memory via [`MemOps::as_slice`]. Implements `AsRef<[u8]>` so it +/// can be used with [`Bytes::from_owner`](bytes::Bytes::from_owner) for +/// zero-copy `Bytes` backed by shared memory. +/// +/// When dropped, the allocation is returned to the pool. +#[derive(Debug, Clone)] +pub struct BufferOwner { + pub(crate) pool: P, + pub(crate) mem: M, + pub(crate) alloc: Allocation, + pub(crate) written: usize, +} + +impl Drop for BufferOwner { + fn drop(&mut self) { + let _ = self.pool.dealloc(self.alloc); + } +} + +impl AsRef<[u8]> for BufferOwner { + fn as_ref(&self) -> &[u8] { + let len = self.written.min(self.alloc.len); + // Safety: BufferOwner keeps both the pool allocation and the M + // alive, so the memory region is valid. Protocol-level descriptor + // ownership transfer guarantees no concurrent writes. 
+ match unsafe { self.mem.as_slice(self.alloc.addr, len) } { + Ok(slice) => slice, + Err(_) => &[], + } + } +} + +/// A guard that runs a cleanup function when dropped, unless dismissed. +pub struct AllocGuard(Option<(Allocation, F)>); + +impl AllocGuard { + pub fn new(alloc: Allocation, cleanup: F) -> Self { + Self(Some((alloc, cleanup))) + } + + pub fn release(mut self) -> Allocation { + self.0.take().unwrap().0 + } +} + +impl core::ops::Deref for AllocGuard { + type Target = Allocation; + + fn deref(&self) -> &Allocation { + &self.0.as_ref().unwrap().0 + } +} + +impl Drop for AllocGuard { + fn drop(&mut self) { + if let Some((alloc, cleanup)) = self.0.take() { + cleanup(alloc) + } + } +} diff --git a/src/hyperlight_common/src/virtq/mod.rs b/src/hyperlight_common/src/virtq/mod.rs index 9803977b0..09d69268c 100644 --- a/src/hyperlight_common/src/virtq/mod.rs +++ b/src/hyperlight_common/src/virtq/mod.rs @@ -151,18 +151,19 @@ limitations under the License. //! ``` mod access; +mod buffer; mod consumer; mod desc; mod event; pub mod msg; mod pool; mod producer; -pub mod recycle_pool; mod ring; use core::num::NonZeroU16; pub use access::*; +pub use buffer::*; pub use consumer::*; pub use desc::*; pub use event::*; @@ -992,7 +993,7 @@ mod fuzz { } } - unsafe impl MemOps for Arc { + unsafe impl MemOps for LoomMem { type Error = MemErr; fn read(&self, addr: u64, dst: &mut [u8]) -> Result<(), Self::Error> { diff --git a/src/hyperlight_common/src/virtq/pool.rs b/src/hyperlight_common/src/virtq/pool.rs index cf0915fdf..42325a56c 100644 --- a/src/hyperlight_common/src/virtq/pool.rs +++ b/src/hyperlight_common/src/virtq/pool.rs @@ -13,50 +13,24 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -//! Simple bitmap-based allocator for virtio buffer management. +//! Buffer pool implementations for virtqueue buffer management. //! -//! 
This module provides two layers: +//! This module provides concrete buffer allocators: //! -//! - [`Slab`] - a fixed-size region allocator with a power-of-two slot size `N`, -//! backed by a flat bitmap (`FixedBitSet`). -//! - [`BufferPool`] - a two-tier pool that composes two slabs: one with small -//! slots (e.g. 256 bytes) for control messages / small descriptors, and one -//! with page-sized slots (e.g. 4 KiB) for data buffers. +//! - [`BufferPool`] - a two-tier bitmap pool with small and large slabs, +//! intended for G2H descriptors where allocation sizes vary. +//! - [`RecyclePool`] - a fixed-size free-list recycler for H2G prefill +//! entries where all buffers are the same size. //! -//! # Design and algorithm +//! Both implement [`BufferProvider`] from the [`super::buffer`] module. +//! +//! # BufferPool design //! //! The core allocation strategy is a bitmap allocator that performs a linear //! search over the bitmap, but implemented via `fixedbitset`'s SIMD iteration -//! over zero bits. This is conceptually simpler than tree-based allocators -//! (e.g. linked lists or bitmaps representing a tree as in -//! ), yet for "moderate" region sizes it can -//! be faster in practice: -//! -//! - `FixedBitSet::zeroes()` and related methods use word/SIMD operations to -//! skip over runs of set bits, so the linear search is over words rather than -//! individual bits. -//! - We scan for a contiguous run of free bits corresponding to the required -//! number of slots; no auxiliary tree structure is maintained. -//! -//! The tree-based approach (bitmap encoding a tree and doing a binary search -//! in O(log(n)) time) is a natural next step if larger regions or stricter worst -//! case bounds are required; switching to such a representation should be -//! relatively straightforward since all allocation paths go through a single -//! `find_slots` function. -//! -//! # Locality characteristics +//! over zero bits. //! -//! 
The allocator tends to preserve spatial locality: -//! -//! - It searches from low indices upward, returning the first run of free -//! slots large enough for the request. Slots are merged if necessary. -//! - Freed runs are cached in `last_free_run` and reused eagerly, which -//! introduces a mild LIFO behavior for recently freed blocks. -//! - As a result, consecutive allocations are likely to end up in nearby slots, -//! which keeps virtqueue descriptors, control buffers, and data buffers -//! clustered in memory and helps cache performance. -//! -//! # Two-tier buffer pool +//! # Two-tier layout //! //! [`BufferPool`] divides the underlying region into two slabs with different //! slot sizes: @@ -66,159 +40,40 @@ limitations under the License. //! small structures. Small allocations first try this tier. //! - The upper tier (`Slab`, default `U = 4096`) uses page sized slots //! and is intended for larger data buffers. -//! -//! The split of the region is currently fixed at a constant fraction -//! (`LOWER_FRACTION`) for the lower slab and the remainder for the upper slab. -//! -//! Allocation policy: -//! -//! - Requests `<= L` bytes are first attempted in the lower slab; on -//! `OutOfMemory` they fall back to the upper slab. -//! - Larger requests go directly to the upper slab. -//! - [`BufferPool::resize`] will try to grow or shrink in place within the -//! owning slab (`Slab::resize`) but will never move allocations between -//! slabs. 
- -use alloc::sync::Arc; + +use alloc::rc::Rc; +use core::cell::RefCell; use core::cmp::Ordering; +use core::ops::Deref; -use atomic_refcell::AtomicRefCell; use fixedbitset::FixedBitSet; -use thiserror::Error; - -use super::access::MemOps; - -#[derive(Debug, Error, Copy, Clone)] -pub enum AllocError { - #[error("Invalid region addr {0}")] - InvalidAlign(u64), - #[error("Invalid free addr {0} and size {1}")] - InvalidFree(u64, usize), - #[error("Invalid argument")] - InvalidArg, - #[error("Empty region")] - EmptyRegion, - #[error("Out of memory")] - OutOfMemory, - #[error("Overflow")] - Overflow, -} - -/// Allocation result -#[derive(Debug, Clone, Copy)] -pub struct Allocation { - /// Starting address of the allocation - pub addr: u64, - /// Length of the allocation in bytes rounded up to slab size - pub len: usize, -} - -/// Trait for buffer providers. -pub trait BufferProvider { - /// Allocate at least `len` bytes. - fn alloc(&self, len: usize) -> Result; - - /// Free a previously allocated block. - fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError>; - - /// Resize by trying in-place grow; otherwise reserve a new block and free old. - fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result; - - /// Reset the pool to initial state. 
- fn reset(&self) {} -} - -impl BufferProvider for alloc::rc::Rc { - fn alloc(&self, len: usize) -> Result { - (**self).alloc(len) - } - fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError> { - (**self).dealloc(alloc) - } - fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result { - (**self).resize(old_alloc, new_len) - } - fn reset(&self) { - (**self).reset() - } -} +use smallvec::SmallVec; -impl BufferProvider for Arc { - fn alloc(&self, len: usize) -> Result { - (**self).alloc(len) - } - fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError> { - (**self).dealloc(alloc) - } - fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result { - (**self).resize(old_alloc, new_len) - } - fn reset(&self) { - (**self).reset() - } -} +use super::buffer::{AllocError, Allocation, BufferProvider}; -/// The owner of a mapped buffer, ensuring its lifetime. +/// Wrapper asserting `Send + Sync` for single-threaded contexts. /// -/// Holds a pool allocation and provides direct access to the underlying -/// shared memory via [`MemOps::as_slice`]. Implements `AsRef<[u8]>` so it -/// can be used with [`Bytes::from_owner`](bytes::Bytes::from_owner) for -/// zero-copy `Bytes` backed by shared memory. +/// # Safety /// -/// When dropped, the allocation is returned to the pool. -#[derive(Debug, Clone)] -pub struct BufferOwner { - pub(crate) pool: P, - pub(crate) mem: M, - pub(crate) alloc: Allocation, - pub(crate) written: usize, -} - -impl Drop for BufferOwner { - fn drop(&mut self) { - let _ = self.pool.dealloc(self.alloc); - } -} - -impl AsRef<[u8]> for BufferOwner { - fn as_ref(&self) -> &[u8] { - let len = self.written.min(self.alloc.len); - // Safety: BufferOwner keeps both the pool allocation and the M - // alive, so the memory region is valid. Protocol-level descriptor - // ownership transfer guarantees no concurrent writes. 
- match unsafe { self.mem.as_slice(self.alloc.addr, len) } { - Ok(slice) => slice, - Err(_) => &[], - } - } -} - -/// A guard that runs a cleanup function when dropped, unless dismissed. -pub struct AllocGuard(Option<(Allocation, F)>); - -impl AllocGuard { - pub fn new(alloc: Allocation, cleanup: F) -> Self { - Self(Some((alloc, cleanup))) - } - - pub fn release(mut self) -> Allocation { - self.0.take().unwrap().0 - } -} +/// The wrapped value must only be accessed from a single thread. +#[derive(Debug)] +pub(super) struct SyncWrap(pub(super) T); -impl core::ops::Deref for AllocGuard { - type Target = Allocation; +// SAFETY: The wrapped value must only be accessed from a single thread. +unsafe impl Send for SyncWrap {} +// SAFETY: The wrapped value must only be accessed from a single thread. +unsafe impl Sync for SyncWrap {} - fn deref(&self) -> &Allocation { - &self.0.as_ref().unwrap().0 +impl Clone for SyncWrap { + fn clone(&self) -> Self { + Self(self.0.clone()) } } -impl Drop for AllocGuard { - fn drop(&mut self) { - if let Some((alloc, cleanup)) = self.0.take() { - cleanup(alloc) - } +impl Deref for SyncWrap { + type Target = T; + fn deref(&self) -> &T { + &self.0 } } @@ -550,8 +405,7 @@ struct Inner { /// Two tier buffer pool with small and large slabs. #[derive(Debug, Clone)] pub struct BufferPool { - // TODO: Use Rc instead, relax Sync + Send bounds - inner: Arc>>, + inner: SyncWrap>>>, } impl BufferPool { @@ -559,7 +413,7 @@ impl BufferPool { pub fn new(base_addr: u64, region_len: usize) -> Result { let inner = Inner::::new(base_addr, region_len)?; Ok(Self { - inner: Arc::new(inner.into()), + inner: SyncWrap(Rc::new(RefCell::new(inner))), }) } } @@ -700,6 +554,102 @@ impl BufferProvider for BufferPoolSync { } } +struct RecyclePoolInner { + base_addr: u64, + slot_size: usize, + count: usize, + free: SmallVec<[u64; 64]>, +} + +/// A recycling buffer provider with fixed-size slots. 
+/// +/// Unlike [`BufferPool`] which uses a bitmap allocator, this holds a +/// fixed set of same-sized buffer addresses in a free list. Alloc and +/// dealloc are O(1). Intended for H2G writable buffers that are +/// pre-allocated once and recycled after each use. +#[derive(Clone)] +pub struct RecyclePool { + inner: SyncWrap>>, +} + +impl RecyclePool { + /// Create a new recycling pool by carving `base..base+region_len` into slots of `slot_size` bytes. + pub fn new(base_addr: u64, region_len: usize, slot_size: usize) -> Result { + if slot_size == 0 { + return Err(AllocError::InvalidArg); + } + + let count = region_len / slot_size; + if count == 0 { + return Err(AllocError::EmptyRegion); + } + + let mut free = SmallVec::with_capacity(count); + for i in 0..count { + free.push(base_addr + (i * slot_size) as u64); + } + + let inner = RefCell::new(RecyclePoolInner { + base_addr, + slot_size, + count, + free, + }); + + Ok(Self { + inner: SyncWrap(Rc::new(inner)), + }) + } + + /// Number of free slots. 
+ pub fn num_free(&self) -> usize { + self.inner.borrow().free.len() + } +} + +impl BufferProvider for RecyclePool { + fn alloc(&self, len: usize) -> Result { + let mut inner = self.inner.borrow_mut(); + if len > inner.slot_size { + return Err(AllocError::OutOfMemory); + } + + let addr = inner.free.pop().ok_or(AllocError::OutOfMemory)?; + + Ok(Allocation { + addr, + len: inner.slot_size, + }) + } + + fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError> { + let mut inner = self.inner.borrow_mut(); + inner.free.push(alloc.addr); + Ok(()) + } + + fn resize(&self, old: Allocation, new_len: usize) -> Result { + let inner = self.inner.borrow(); + if new_len > inner.slot_size { + return Err(AllocError::OutOfMemory); + } + Ok(old) + } + + fn reset(&self) { + let mut inner = self.inner.borrow_mut(); + let base = inner.base_addr; + let slot = inner.slot_size; + let count = inner.count; + + inner.free.clear(); + + for i in 0..count { + inner.free.push(base + (i * slot) as u64); + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/hyperlight_common/src/virtq/producer.rs b/src/hyperlight_common/src/virtq/producer.rs index 5e6a7edf1..276a2ff3b 100644 --- a/src/hyperlight_common/src/virtq/producer.rs +++ b/src/hyperlight_common/src/virtq/producer.rs @@ -168,8 +168,8 @@ where /// wrote more data than the completion buffer capacity pub fn poll(&mut self) -> Result, VirtqError> where - M: Send + Sync + 'static, - P: Send + Sync + 'static, + M: Send + 'static, + P: Send + 'static, { let used = match self.inner.poll_used() { Ok(u) => u, @@ -234,8 +234,8 @@ where /// ``` pub fn drain(&mut self, mut f: impl FnMut(Token, Bytes)) -> Result<(), VirtqError> where - M: Send + Sync + 'static, - P: Send + Sync + 'static, + M: Send + 'static, + P: Send + 'static, { while let Some(cqe) = self.poll()? 
{ f(cqe.token, cqe.data); diff --git a/src/hyperlight_common/src/virtq/recycle_pool.rs b/src/hyperlight_common/src/virtq/recycle_pool.rs deleted file mode 100644 index 4bcf9978a..000000000 --- a/src/hyperlight_common/src/virtq/recycle_pool.rs +++ /dev/null @@ -1,120 +0,0 @@ -/* -Copyright 2026 The Hyperlight Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -//! A simple fixed-size buffer recycler for H2G prefill entries. -//! -//! Unlike [`super::BufferPool`] which uses a bitmap allocator, this -//! holds a fixed set of same-sized buffer addresses in a free list. -//! Alloc and dealloc are O(1). Intended for H2G writable buffers -//! that are pre-allocated once and recycled after each use. - -use alloc::sync::Arc; - -use atomic_refcell::AtomicRefCell; -use smallvec::SmallVec; - -use super::{AllocError, Allocation, BufferProvider}; - -/// A recycling buffer provider with fixed-size slots. -#[derive(Clone)] -pub struct RecyclePool { - inner: Arc>, -} - -struct RecyclePoolInner { - base_addr: u64, - slot_size: usize, - count: usize, - free: SmallVec<[u64; 64]>, -} - -impl RecyclePool { - /// Create a new recycling pool by carving `base..base+region_len` into slots of `slot_size` bytes. 
- pub fn new(base_addr: u64, region_len: usize, slot_size: usize) -> Result { - if slot_size == 0 { - return Err(AllocError::InvalidArg); - } - - let count = region_len / slot_size; - if count == 0 { - return Err(AllocError::EmptyRegion); - } - - let mut free = SmallVec::with_capacity(count); - for i in 0..count { - free.push(base_addr + (i * slot_size) as u64); - } - - let inner = AtomicRefCell::new(RecyclePoolInner { - base_addr, - slot_size, - count, - free, - }); - - Ok(Self { - inner: inner.into(), - }) - } - - /// Number of free slots. - pub fn num_free(&self) -> usize { - self.inner.borrow().free.len() - } -} - -impl BufferProvider for RecyclePool { - fn alloc(&self, len: usize) -> Result { - let mut inner = self.inner.borrow_mut(); - if len > inner.slot_size { - return Err(AllocError::OutOfMemory); - } - - let addr = inner.free.pop().ok_or(AllocError::OutOfMemory)?; - - Ok(Allocation { - addr, - len: inner.slot_size, - }) - } - - fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError> { - let mut inner = self.inner.borrow_mut(); - inner.free.push(alloc.addr); - Ok(()) - } - - fn resize(&self, old: Allocation, new_len: usize) -> Result { - let inner = self.inner.borrow(); - if new_len > inner.slot_size { - return Err(AllocError::OutOfMemory); - } - Ok(old) - } - - fn reset(&self) { - let mut inner = self.inner.borrow_mut(); - let base = inner.base_addr; - let slot = inner.slot_size; - let count = inner.count; - - inner.free.clear(); - - for i in 0..count { - inner.free.push(base + (i * slot) as u64); - } - } -} diff --git a/src/hyperlight_guest/src/virtq/context.rs b/src/hyperlight_guest/src/virtq/context.rs index 53d90068c..c3373d220 100644 --- a/src/hyperlight_guest/src/virtq/context.rs +++ b/src/hyperlight_guest/src/virtq/context.rs @@ -29,8 +29,9 @@ use hyperlight_common::flatbuffer_wrappers::util::estimate_flatbuffer_capacity; use hyperlight_common::mem::PAGE_SIZE_USIZE; use hyperlight_common::outb::OutBAction; use 
hyperlight_common::virtq::msg::{MsgKind, VirtqMsgHeader}; -use hyperlight_common::virtq::recycle_pool::RecyclePool; -use hyperlight_common::virtq::{BufferPool, Layout, Notifier, QueueStats, VirtqProducer}; +use hyperlight_common::virtq::{ + BufferPool, Layout, Notifier, QueueStats, RecyclePool, VirtqProducer, +}; use super::GuestMemOps; use crate::bail; diff --git a/src/tests/rust_guests/dummyguest/Cargo.lock b/src/tests/rust_guests/dummyguest/Cargo.lock index 0f1efc6bd..060a340fa 100644 --- a/src/tests/rust_guests/dummyguest/Cargo.lock +++ b/src/tests/rust_guests/dummyguest/Cargo.lock @@ -17,12 +17,6 @@ version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" -[[package]] -name = "atomic_refcell" -version = "0.1.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21e4227379beff4205943696e6c3e0cd809bacdf3f0edd6e3dd153e2269571a4" - [[package]] name = "bindgen" version = "0.71.1" @@ -179,7 +173,6 @@ name = "hyperlight-common" version = "0.15.0" dependencies = [ "anyhow", - "atomic_refcell", "bitflags", "bytemuck", "bytes", diff --git a/src/tests/rust_guests/simpleguest/Cargo.lock b/src/tests/rust_guests/simpleguest/Cargo.lock index 48adf7195..7327a0332 100644 --- a/src/tests/rust_guests/simpleguest/Cargo.lock +++ b/src/tests/rust_guests/simpleguest/Cargo.lock @@ -17,12 +17,6 @@ version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" -[[package]] -name = "atomic_refcell" -version = "0.1.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21e4227379beff4205943696e6c3e0cd809bacdf3f0edd6e3dd153e2269571a4" - [[package]] name = "bindgen" version = "0.71.1" @@ -171,7 +165,6 @@ name = "hyperlight-common" version = "0.15.0" dependencies = [ "anyhow", - "atomic_refcell", "bitflags", "bytemuck", "bytes", 
diff --git a/src/tests/rust_guests/witguest/Cargo.lock b/src/tests/rust_guests/witguest/Cargo.lock index cbefd08cf..7e7516cd0 100644 --- a/src/tests/rust_guests/witguest/Cargo.lock +++ b/src/tests/rust_guests/witguest/Cargo.lock @@ -67,12 +67,6 @@ version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" -[[package]] -name = "atomic_refcell" -version = "0.1.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21e4227379beff4205943696e6c3e0cd809bacdf3f0edd6e3dd153e2269571a4" - [[package]] name = "bindgen" version = "0.71.1" @@ -261,7 +255,6 @@ name = "hyperlight-common" version = "0.15.0" dependencies = [ "anyhow", - "atomic_refcell", "bitflags", "bytemuck", "bytes", From 6977382ddb4a039ec2b1875ca96b1b73ca4867ca Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Tue, 7 Apr 2026 13:40:51 +0200 Subject: [PATCH 09/31] feat(virtq): send logs over virtq Signed-off-by: Tomasz Andrzejak --- src/hyperlight_common/src/virtq/buffer.rs | 4 +- src/hyperlight_common/src/virtq/mod.rs | 184 +++++++++++++++++- src/hyperlight_common/src/virtq/msg.rs | 23 +++ src/hyperlight_common/src/virtq/pool.rs | 12 +- src/hyperlight_common/src/virtq/producer.rs | 40 ++++ .../src/guest_handle/host_comm.rs | 12 +- src/hyperlight_guest/src/virtq/context.rs | 179 ++++++++++++----- src/hyperlight_host/src/mem/mgr.rs | 66 ++++--- src/hyperlight_host/src/sandbox/outb.rs | 136 ++++++++----- src/hyperlight_host/tests/common/mod.rs | 10 + .../tests/sandbox_host_tests.rs | 181 ++++++++++++++++- src/tests/rust_guests/simpleguest/src/main.rs | 7 + 12 files changed, 715 insertions(+), 139 deletions(-) diff --git a/src/hyperlight_common/src/virtq/buffer.rs b/src/hyperlight_common/src/virtq/buffer.rs index 238775c6d..b41708b03 100644 --- a/src/hyperlight_common/src/virtq/buffer.rs +++ b/src/hyperlight_common/src/virtq/buffer.rs @@ -33,7 +33,9 @@ pub enum AllocError { 
InvalidArg, #[error("Empty region")] EmptyRegion, - #[error("Out of memory")] + #[error("No space available")] + NoSpace, + #[error("Requested size exceeds pool capacity")] OutOfMemory, #[error("Overflow")] Overflow, diff --git a/src/hyperlight_common/src/virtq/mod.rs b/src/hyperlight_common/src/virtq/mod.rs index 09d69268c..7d2aa7f0a 100644 --- a/src/hyperlight_common/src/virtq/mod.rs +++ b/src/hyperlight_common/src/virtq/mod.rs @@ -181,9 +181,13 @@ pub trait Notifier { #[derive(Error, Debug)] pub enum VirtqError { #[error("Ring error: {0}")] - RingError(#[from] RingError), + RingError(RingError), #[error("Allocation error: {0}")] - Alloc(#[from] AllocError), + Alloc(AllocError), + #[error("Ring or pool temporarily full")] + Backpressure, + #[error("Allocation exceeds pool capacity")] + OutOfMemory, #[error("Invalid token")] BadToken, #[error("Invalid chain received")] @@ -202,6 +206,33 @@ pub enum VirtqError { NoReadableBuffer, } +impl VirtqError { + /// Check if this error is transient or unrecoverable. + #[inline(always)] + pub fn is_transient(&self) -> bool { + matches!(self, Self::Backpressure) + } +} + +impl From for VirtqError { + fn from(e: RingError) -> Self { + match e { + RingError::WouldBlock => Self::Backpressure, + other => Self::RingError(other), + } + } +} + +impl From for VirtqError { + fn from(e: AllocError) -> Self { + match e { + AllocError::NoSpace => Self::Backpressure, + AllocError::OutOfMemory => Self::OutOfMemory, + other => Self::Alloc(other), + } + } +} + /// Layout of a packed virtqueue ring in shared memory. 
/// /// Describes the memory addresses for the descriptor table and event suppression @@ -461,7 +492,7 @@ pub(crate) mod test_utils { let addr = self.next.fetch_add(len as u64, Ordering::Relaxed); let end = addr + len as u64; if end > self.base + self.size as u64 { - return Err(AllocError::OutOfMemory); + return Err(AllocError::NoSpace); } Ok(Allocation { addr, len }) } @@ -829,6 +860,153 @@ mod tests { assert_eq!(&expected_first.1[..], b"resp1"); assert_eq!(&expected_second.1[..], b"resp2"); } + + /// Helper: submit a ReadOnly entry (entry data, no completion). + fn send_readonly( + producer: &mut VirtqProducer, + entry_data: &[u8], + ) -> Token { + let mut se = producer.chain().entry(entry_data.len()).build().unwrap(); + se.write_all(entry_data).unwrap(); + producer.submit(se).unwrap() + } + + #[test] + fn test_reclaim_frees_ring_slots() { + let ring = make_ring(4); + let (mut producer, mut consumer, _) = make_test_producer(&ring); + + // Fill the ring with ReadOnly entries + send_readonly(&mut producer, b"a"); + send_readonly(&mut producer, b"b"); + send_readonly(&mut producer, b"c"); + send_readonly(&mut producer, b"d"); + + // Ring is now full - next submit should fail with Backpressure + let mut se = producer.chain().entry(1).build().unwrap(); + se.write_all(b"e").unwrap(); + let res = producer.submit(se); + assert!( + matches!(res, Err(VirtqError::Backpressure)), + "expected Backpressure from full ring" + ); + + // Consumer acks all entries + while let Some((_, completion)) = consumer.poll(1024).unwrap() { + consumer.complete(completion).unwrap(); + } + + // Reclaim should free ring slots without losing data + let count = producer.reclaim().unwrap(); + assert_eq!(count, 4, "expected 4 reclaimed entries"); + + // Ring should have space now + send_readonly(&mut producer, b"e"); + } + + #[test] + fn test_reclaim_buffers_rw_completions() { + let ring = make_ring(4); + let (mut producer, mut consumer, _) = make_test_producer(&ring); + + // Submit a ReadWrite 
entry + let tok = send_readwrite(&mut producer, b"request", 64); + + // Consumer processes and writes response + let (_, completion) = consumer.poll(1024).unwrap().unwrap(); + let SendCompletion::Writable(mut wc) = completion else { + panic!("expected writable"); + }; + wc.write_all(b"response-data").unwrap(); + consumer.complete(wc.into()).unwrap(); + + // Reclaim buffers the completion (doesn't discard it) + let count = producer.reclaim().unwrap(); + assert_eq!(count, 1); + + // poll() should return the buffered completion + let cqe = producer.poll().unwrap().unwrap(); + assert_eq!(cqe.token, tok); + assert_eq!(&cqe.data[..], b"response-data"); + } + + #[test] + fn test_reclaim_then_poll_preserves_order() { + let ring = make_ring(8); + let (mut producer, mut consumer, _) = make_test_producer(&ring); + + // Submit 3 entries: RO, RW, RO + let tok_ro1 = send_readonly(&mut producer, b"log1"); + let tok_rw = send_readwrite(&mut producer, b"call", 64); + let tok_ro2 = send_readonly(&mut producer, b"log2"); + + // Consumer processes all 3 + let (_, c1) = consumer.poll(1024).unwrap().unwrap(); + consumer.complete(c1).unwrap(); // ack RO + + let (_, c2) = consumer.poll(1024).unwrap().unwrap(); + let SendCompletion::Writable(mut wc) = c2 else { + panic!("expected writable"); + }; + wc.write_all(b"result").unwrap(); + consumer.complete(wc.into()).unwrap(); // complete RW + + let (_, c3) = consumer.poll(1024).unwrap().unwrap(); + consumer.complete(c3).unwrap(); // ack RO + + // Reclaim all 3 + let count = producer.reclaim().unwrap(); + assert_eq!(count, 3); + + // poll() returns them in order + let cqe1 = producer.poll().unwrap().unwrap(); + assert_eq!(cqe1.token, tok_ro1); + assert!(cqe1.data.is_empty()); + + let cqe2 = producer.poll().unwrap().unwrap(); + assert_eq!(cqe2.token, tok_rw); + assert_eq!(&cqe2.data[..], b"result"); + + let cqe3 = producer.poll().unwrap().unwrap(); + assert_eq!(cqe3.token, tok_ro2); + assert!(cqe3.data.is_empty()); + + // No more + 
assert!(producer.poll().unwrap().is_none()); + } + + #[test] + fn test_reclaim_mixed_with_poll() { + let ring = make_ring(8); + let (mut producer, mut consumer, _) = make_test_producer(&ring); + + // Submit and complete 2 entries + send_readonly(&mut producer, b"x"); + let tok_rw = send_readwrite(&mut producer, b"y", 64); + + let (_, c1) = consumer.poll(1024).unwrap().unwrap(); + consumer.complete(c1).unwrap(); + + let (_, c2) = consumer.poll(1024).unwrap().unwrap(); + let SendCompletion::Writable(mut wc) = c2 else { + panic!("expected writable"); + }; + wc.write_all(b"reply").unwrap(); + consumer.complete(wc.into()).unwrap(); + + // poll() consumes first entry directly from ring + let cqe1 = producer.poll().unwrap().unwrap(); + assert!(cqe1.data.is_empty()); + + // reclaim() buffers second entry + let count = producer.reclaim().unwrap(); + assert_eq!(count, 1); + + // poll() returns the buffered one + let cqe2 = producer.poll().unwrap().unwrap(); + assert_eq!(cqe2.token, tok_rw); + assert_eq!(&cqe2.data[..], b"reply"); + } } #[cfg(all(test, loom))] mod fuzz { diff --git a/src/hyperlight_common/src/virtq/msg.rs b/src/hyperlight_common/src/virtq/msg.rs index 9c7f69947..ade59643b 100644 --- a/src/hyperlight_common/src/virtq/msg.rs +++ b/src/hyperlight_common/src/virtq/msg.rs @@ -34,6 +34,24 @@ pub enum MsgKind { StreamEnd = 0x04, /// Cancel a pending request. Cancel = 0x05, + /// A guest log message (GuestLogData payload follows). + Log = 0x06, +} + +impl TryFrom for MsgKind { + type Error = u8; + + fn try_from(value: u8) -> Result { + match value { + 0x01 => Ok(Self::Request), + 0x02 => Ok(Self::Response), + 0x03 => Ok(Self::StreamChunk), + 0x04 => Ok(Self::StreamEnd), + 0x05 => Ok(Self::Cancel), + 0x06 => Ok(Self::Log), + other => Err(other), + } + } } /// Wire header for all virtqueue messages @@ -72,4 +90,9 @@ impl VirtqMsgHeader { payload_len, } } + + /// Parse the kind field into a [`MsgKind`] enum. 
+ pub fn msg_kind(&self) -> Result { + MsgKind::try_from(self.kind) + } } diff --git a/src/hyperlight_common/src/virtq/pool.rs b/src/hyperlight_common/src/virtq/pool.rs index 42325a56c..2e49e27fe 100644 --- a/src/hyperlight_common/src/virtq/pool.rs +++ b/src/hyperlight_common/src/virtq/pool.rs @@ -169,7 +169,7 @@ impl Slab { return Err(AllocError::OutOfMemory); } - let idx = self.find_slots(need_slots).ok_or(AllocError::OutOfMemory)?; + let idx = self.find_slots(need_slots).ok_or(AllocError::NoSpace)?; self.used_slots.insert_range(idx..idx + need_slots); let addr = self.addr_of(idx).ok_or(AllocError::Overflow)?; @@ -463,7 +463,7 @@ impl Inner { if len <= L { match self.lower.alloc(len) { Ok(alloc) => return Ok(alloc), - Err(AllocError::OutOfMemory) => {} + Err(AllocError::NoSpace) => {} Err(e) => return Err(e), } } @@ -614,7 +614,7 @@ impl BufferProvider for RecyclePool { return Err(AllocError::OutOfMemory); } - let addr = inner.free.pop().ok_or(AllocError::OutOfMemory)?; + let addr = inner.free.pop().ok_or(AllocError::NoSpace)?; Ok(Allocation { addr, @@ -727,7 +727,7 @@ mod tests { // Next allocation should fail let result = slab.alloc(256); - assert!(matches!(result, Err(AllocError::OutOfMemory))); + assert!(matches!(result, Err(AllocError::NoSpace))); // Free one and retry slab.dealloc(a2).unwrap(); @@ -1287,7 +1287,7 @@ mod fuzz { assert!(alloc.len >= *size); allocations.push(alloc); } - Err(AllocError::OutOfMemory) => {} + Err(AllocError::NoSpace | AllocError::OutOfMemory) => {} Err(_) => { return false; } @@ -1318,7 +1318,7 @@ mod fuzz { assert!(new_alloc.len >= *new_size); allocations[idx] = new_alloc; } - Err(AllocError::OutOfMemory) => {} + Err(AllocError::NoSpace | AllocError::OutOfMemory) => {} Err(_) => return false, } } diff --git a/src/hyperlight_common/src/virtq/producer.rs b/src/hyperlight_common/src/virtq/producer.rs index 276a2ff3b..eeb96cc7f 100644 --- a/src/hyperlight_common/src/virtq/producer.rs +++ b/src/hyperlight_common/src/virtq/producer.rs 
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ +use alloc::collections::VecDeque; use alloc::vec; use alloc::vec::Vec; @@ -124,6 +125,7 @@ pub struct VirtqProducer { notifier: N, pool: P, inflight: Vec>, + pending: VecDeque, } impl VirtqProducer @@ -149,11 +151,15 @@ where pool, notifier, inflight, + pending: VecDeque::new(), } } /// Poll for a single completion from the device. /// + /// Returns buffered completions from prior [`reclaim`](Self::reclaim) + /// calls first, then checks the ring for new completions. + /// /// Returns `Ok(Some(completion))` if a completion is available, `Ok(None)` if no /// completions are ready (would block), or an error if the device misbehaved. /// @@ -167,6 +173,39 @@ where /// - [`VirtqError::InvalidState`] - Device returned invalid descriptor ID or /// wrote more data than the completion buffer capacity pub fn poll(&mut self) -> Result, VirtqError> + where + M: Send + 'static, + P: Send + 'static, + { + if let Some(cqe) = self.pending.pop_front() { + return Ok(Some(cqe)); + } + self.poll_ring() + } + + /// Reclaim ring slots and pool entries from completed descriptors. + /// + /// Processes all available used entries from the ring: frees entry + /// buffer allocations immediately, and buffers completion data for + /// later retrieval via [`poll`](Self::poll). + /// + /// Use this to free resources under backpressure without losing + /// completion data. Returns the number of entries reclaimed. + pub fn reclaim(&mut self) -> Result + where + M: Send + 'static, + P: Send + 'static, + { + let mut count = 0; + while let Some(cqe) = self.poll_ring()? { + self.pending.push_back(cqe); + count += 1; + } + Ok(count) + } + + /// Poll one completion directly from the ring (bypassing pending buffer). 
+ fn poll_ring(&mut self) -> Result, VirtqError> where M: Send + 'static, P: Send + 'static, @@ -363,6 +402,7 @@ where self.inner.reset(); self.pool.reset(); self.inflight.fill(None); + self.pending.clear(); } } diff --git a/src/hyperlight_guest/src/guest_handle/host_comm.rs b/src/hyperlight_guest/src/guest_handle/host_comm.rs index c72de8a3f..d440852f6 100644 --- a/src/hyperlight_guest/src/guest_handle/host_comm.rs +++ b/src/hyperlight_guest/src/guest_handle/host_comm.rs @@ -162,7 +162,7 @@ impl GuestHandle { source_file: &str, line: u32, ) { - // Closure to send log message to host + // Closure to send log message to host via G2H virtqueue let _send_to_host = || { let guest_log_data = GuestLogData::new( message.to_string(), @@ -177,12 +177,10 @@ impl GuestHandle { .try_into() .expect("Failed to convert GuestLogData to bytes"); - self.push_shared_output_data(&bytes) - .expect("Unable to push log data to shared output data"); - - unsafe { - out32(OutBAction::Log as u16, 0); - } + crate::virtq::with_context(|ctx| { + ctx.emit_log(&bytes) + .expect("Unable to send log data via virtq"); + }); }; #[cfg(all(feature = "trace_guest", target_arch = "x86_64"))] diff --git a/src/hyperlight_guest/src/virtq/context.rs b/src/hyperlight_guest/src/virtq/context.rs index c3373d220..59b71bf14 100644 --- a/src/hyperlight_guest/src/virtq/context.rs +++ b/src/hyperlight_guest/src/virtq/context.rs @@ -17,6 +17,7 @@ limitations under the License. //! Guest virtqueue context. 
use alloc::vec::Vec; +use core::result; use core::sync::atomic::AtomicU16; use core::sync::atomic::Ordering::Relaxed; @@ -30,7 +31,7 @@ use hyperlight_common::mem::PAGE_SIZE_USIZE; use hyperlight_common::outb::OutBAction; use hyperlight_common::virtq::msg::{MsgKind, VirtqMsgHeader}; use hyperlight_common::virtq::{ - BufferPool, Layout, Notifier, QueueStats, RecyclePool, VirtqProducer, + self, BufferPool, Layout, Notifier, QueueStats, RecyclePool, Token, VirtqProducer, }; use super::GuestMemOps; @@ -131,19 +132,33 @@ impl GuestContext { let entry_len = VirtqMsgHeader::SIZE + payload.len(); - let mut entry = self - .g2h_producer - .chain() - .entry(entry_len) - .completion(MAX_RESPONSE_CAP) - .build()?; + let token = match self.try_send_readwrite(hdr_bytes, payload, entry_len) { + Ok(tok) => tok, + Err(e) if e.is_transient() => { + self.g2h_producer.notify_backpressure(); - entry.write_all(hdr_bytes)?; - entry.write_all(payload)?; - self.g2h_producer.submit(entry)?; + if let Err(err) = self.g2h_producer.reclaim() { + bail!("G2H reclaim: {err}"); + } + + let Ok(tok) = self.try_send_readwrite(hdr_bytes, payload, entry_len) else { + bail!("G2H call retry"); + }; - let Some(completion) = self.g2h_producer.poll()? else { - bail!("G2H: no completion received"); + tok + } + Err(e) => bail!("G2H call: {e}"), + }; + + // Poll completions, skipping earlier entries like log acks + // until we find the completion matching our request token. + let completion = loop { + let Some(cqe) = self.g2h_producer.poll()? else { + bail!("G2H: no completion received"); + }; + if cqe.token == token { + break cqe; + } }; let result_bytes = &completion.data; @@ -164,25 +179,6 @@ impl GuestContext { Ok(ret) } - /// Pre-fill the H2G queue with completion-only descriptors so the host - /// can write incoming call payloads into them. 
- fn prefill_h2g(&mut self) { - loop { - let entry = match self - .h2g_producer - .chain() - .completion(PAGE_SIZE_USIZE) - .build() - { - Ok(e) => e, - Err(_) => break, - }; - if self.h2g_producer.submit(entry).is_err() { - break; - } - } - } - /// Receive a host-to-guest function call from the H2G queue. pub fn recv_h2g_call(&mut self) -> Result { let Some(completion) = self.h2g_producer.poll()? else { @@ -196,8 +192,8 @@ impl GuestContext { let hdr: &VirtqMsgHeader = bytemuck::from_bytes(&data[..VirtqMsgHeader::SIZE]); - if hdr.kind != MsgKind::Request as u8 { - bail!("H2G: unexpected message kind"); + if hdr.msg_kind() != Ok(MsgKind::Request) { + bail!("H2G: unexpected message kind: 0x{:02x}", hdr.kind); } let payload_end = VirtqMsgHeader::SIZE + hdr.payload_len as usize; @@ -213,32 +209,83 @@ impl GuestContext { /// Send the result of a host-to-guest call back to the host via the /// G2H queue, then refill one H2G descriptor slot. pub fn send_h2g_result(&mut self, payload: &[u8]) -> Result<()> { - // Build a Response message on the G2H queue - let reqid = REQUEST_ID.fetch_add(1, Relaxed); - let hdr = VirtqMsgHeader::new(MsgKind::Response, reqid, payload.len() as u32); - let hdr_bytes = bytemuck::bytes_of(&hdr); + self.send_g2h_oneshot(MsgKind::Response, payload)?; - let entry_len = VirtqMsgHeader::SIZE + payload.len(); - let mut entry = self.g2h_producer.chain().entry(entry_len).build()?; - - entry.write_all(hdr_bytes)?; - entry.write_all(payload)?; - self.g2h_producer.submit(entry)?; - - // Refill one H2G completion slot - if let Ok(e) = self + // Best-effort refill of one H2G slot. Backpressure is expected + // (pool/ring may be full), other errors are propagated. 
+ match self .h2g_producer .chain() .completion(PAGE_SIZE_USIZE) .build() { - let _ = self.h2g_producer.submit(e); + Ok(e) => match self.h2g_producer.submit(e) { + Ok(_) => {} + Err(virtq::VirtqError::Backpressure) => {} + Err(e) => bail!("H2G refill submit: {e}"), + }, + Err(virtq::VirtqError::Backpressure) => {} + Err(e) => bail!("H2G refill build: {e}"), } Ok(()) } - /// Drain any pending G2H completions (discard them). + /// Pre-fill the H2G queue with completion-only descriptors so the host + /// can write incoming call payloads into them. + fn prefill_h2g(&mut self) { + loop { + let entry = match self + .h2g_producer + .chain() + .completion(PAGE_SIZE_USIZE) + .build() + { + Ok(e) => e, + Err(virtq::VirtqError::Backpressure) => break, + Err(e) => panic!("H2G prefill build: {e}"), + }; + + match self.h2g_producer.submit(entry) { + Ok(_) => {} + Err(virtq::VirtqError::Backpressure) => break, + Err(e) => panic!("H2G prefill submit: {e}"), + } + } + } + + /// Send a one-way message on the G2H queue ReadOnly and no completion. + /// + /// If the pool or ring is full, triggers backpressure, VM exit so + /// the host can drain, then retries once. + fn send_g2h_oneshot(&mut self, kind: MsgKind, payload: &[u8]) -> Result<()> { + let reqid = REQUEST_ID.fetch_add(1, Relaxed); + let hdr = VirtqMsgHeader::new(kind, reqid, payload.len() as u32); + let hdr_bytes = bytemuck::bytes_of(&hdr); + let entry_len = VirtqMsgHeader::SIZE + payload.len(); + + // First attempt + match self.try_send_readonly(hdr_bytes, payload, entry_len) { + Ok(_) => return Ok(()), + Err(virtq::VirtqError::Backpressure) => { + // VM exit so host drains and completes G2H entries. + self.g2h_producer.notify_backpressure(); + } + Err(e) => bail!("G2H oneshot: {e}"), + } + + // Reclaim ring/pool resources from completed entries. 
+ if let Err(e) = self.g2h_producer.reclaim() { + bail!("G2H oneshot retry: {e}"); + } + // Retry after backpressure + match self.try_send_readonly(hdr_bytes, payload, entry_len) { + Ok(_) => Ok(()), + Err(e) => bail!("G2H oneshot retry: {e}"), + } + } + + /// Drain any pending G2H completions. /// /// This is called before checking for H2G calls so that the host /// can reclaim G2H response buffers. @@ -246,6 +293,11 @@ impl GuestContext { while let Ok(Some(_)) = self.g2h_producer.poll() {} } + /// Send a log message via the G2H queue. Fire-and-forget. + pub fn emit_log(&mut self, log_data: &[u8]) -> Result<()> { + self.send_g2h_oneshot(MsgKind::Log, log_data) + } + /// Reset ring and pool state after snapshot restore. pub(super) fn reset(&mut self, new_generation: u64) { self.g2h_producer.reset(); @@ -259,4 +311,35 @@ impl GuestContext { pub(super) fn generation(&self) -> u64 { self.generation } + + fn try_send_readonly( + &mut self, + header: &[u8], + payload: &[u8], + entry_len: usize, + ) -> result::Result { + let mut entry = self.g2h_producer.chain().entry(entry_len).build()?; + + entry.write_all(header)?; + entry.write_all(payload)?; + self.g2h_producer.submit(entry) + } + + fn try_send_readwrite( + &mut self, + header: &[u8], + payload: &[u8], + entry_len: usize, + ) -> result::Result { + let mut entry = self + .g2h_producer + .chain() + .entry(entry_len) + .completion(MAX_RESPONSE_CAP) + .build()?; + + entry.write_all(header)?; + entry.write_all(payload)?; + self.g2h_producer.submit(entry) + } } diff --git a/src/hyperlight_host/src/mem/mgr.rs b/src/hyperlight_host/src/mem/mgr.rs index dbe052f8b..764e72f4c 100644 --- a/src/hyperlight_host/src/mem/mgr.rs +++ b/src/hyperlight_host/src/mem/mgr.rs @@ -544,6 +544,7 @@ impl SandboxMemoryManager { } /// Read guest log data from the `SharedMemory` contained within `self` + #[allow(dead_code)] #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] pub(crate) fn read_guest_log_data(&mut self) 
-> Result { self.scratch_mem.try_pop_buffer_into::( @@ -1097,35 +1098,48 @@ impl SandboxMemoryManager { .as_mut() .ok_or_else(|| new_error!("G2H consumer not initialized"))?; - let Some((entry, completion)) = consumer - .poll(8192) - .map_err(|e| new_error!("G2H poll for H2G result: {:?}", e))? - else { - return Err(new_error!("G2H: no H2G result entry after halt")); - }; - - let entry_data = entry.data(); - if entry_data.len() < VirtqMsgHeader::SIZE { - return Err(new_error!("G2H: result entry too short")); - } - - let hdr: &VirtqMsgHeader = bytemuck::from_bytes(&entry_data[..VirtqMsgHeader::SIZE]); - if hdr.kind != MsgKind::Response as u8 { - return Err(new_error!( - "G2H: expected Response after halt, got kind={}", - hdr.kind - )); - } + // Drain the G2H queue, processing Log entries inline, until we + // find the Response that carries the H2G function call result. + loop { + let maybe_next = consumer + .poll(8192) + .map_err(|e| new_error!("G2H poll for H2G result: {:?}", e))?; - let payload = &entry_data[VirtqMsgHeader::SIZE..]; - let fcr = FunctionCallResult::try_from(payload) - .map_err(|e| new_error!("G2H: malformed FunctionCallResult: {}", e))?; + let Some((entry, completion)) = maybe_next else { + return Err(new_error!("G2H: no H2G result entry after halt")); + }; - consumer - .complete(completion) - .map_err(|e| new_error!("G2H complete: {:?}", e))?; + let entry_data = entry.data(); + if entry_data.len() < VirtqMsgHeader::SIZE { + return Err(new_error!("G2H: result entry too short")); + } - Ok(fcr) + let hdr: &VirtqMsgHeader = bytemuck::from_bytes(&entry_data[..VirtqMsgHeader::SIZE]); + let payload = &entry_data[VirtqMsgHeader::SIZE..]; + + match hdr.msg_kind() { + Ok(MsgKind::Response) => { + let fcr = FunctionCallResult::try_from(payload) + .map_err(|e| new_error!("G2H: malformed FunctionCallResult: {}", e))?; + consumer + .complete(completion) + .map_err(|e| new_error!("G2H complete: {:?}", e))?; + return Ok(fcr); + } + Ok(MsgKind::Log) => { + 
crate::sandbox::outb::emit_guest_log_from_payload(payload); + consumer + .complete(completion) + .map_err(|e| new_error!("G2H complete log: {:?}", e))?; + } + Ok(other) => { + return Err(new_error!("G2H: expected Response or Log, got {:?}", other)); + } + Err(unknown) => { + return Err(new_error!("G2H: unknown message kind: 0x{:02x}", unknown)); + } + } + } } } diff --git a/src/hyperlight_host/src/sandbox/outb.rs b/src/hyperlight_host/src/sandbox/outb.rs index 0e11409ad..b5a20d31c 100644 --- a/src/hyperlight_host/src/sandbox/outb.rs +++ b/src/hyperlight_host/src/sandbox/outb.rs @@ -64,46 +64,67 @@ pub enum HandleOutbError { MemProfile(String), } +#[allow(dead_code)] #[instrument(err(Debug), skip_all, parent = Span::current(), level="Trace")] pub(super) fn outb_log( mgr: &mut SandboxMemoryManager, ) -> Result<(), HandleOutbError> { - // This code will create either a logging record or a tracing record for the GuestLogData depending on if the host has set up a tracing subscriber. - // In theory as we have enabled the log feature in the Cargo.toml for tracing this should happen - // automatically (based on if there is tracing subscriber present) but only works if the event created using macros. (see https://github.com/tokio-rs/tracing/blob/master/tracing/src/macros.rs#L2421 ) - // The reason that we don't want to use the tracing macros is that we want to be able to explicitly - // set the file and line number for the log record which is not possible with macros. - // This is because the file and line number come from the guest not the call site. - let log_data: GuestLogData = mgr .read_guest_log_data() .map_err(|e| HandleOutbError::ReadLogData(e.to_string()))?; - let record_level: Level = (&log_data.level).into(); + emit_guest_log(&log_data); + Ok(()) +} - // Work out if we need to log or trace - // this API is marked as follows but it is the easiest way to work out if we should trace or log +/// Emit a guest log record from a virtqueue payload. 
+/// +/// Deserializes [`GuestLogData`] from the raw bytes and emits either +/// a tracing event or a log record, matching the original `outb_log` +/// behavior. +pub(crate) fn emit_guest_log_from_payload(payload: &[u8]) { + let Ok(log_data) = GuestLogData::try_from(payload) else { + return; + }; + emit_guest_log(&log_data); +} - // Private API for internal use by tracing's macros. - // - // This function is *not* considered part of `tracing`'s public API, and has no - // stability guarantees. If you use it, and it breaks or disappears entirely, - // don't say we didn't warn you. +fn emit_guest_log(log_data: &GuestLogData) { + // This code will create either a logging record or a tracing record + // for the GuestLogData depending on if the host has set up a tracing + // subscriber. + // In theory as we have enabled the log feature in the Cargo.toml for + // tracing this should happen automatically (based on if there is a + // tracing subscriber present) but only works if the event is created + // using macros. + // (see https://github.com/tokio-rs/tracing/blob/master/tracing/src/macros.rs#L2421) + // The reason that we don't want to use the tracing macros is that we + // want to be able to explicitly set the file and line number for the + // log record which is not possible with macros. + // This is because the file and line number come from the guest not + // the call site. + let record_level: Level = (&log_data.level).into(); + + // Work out if we need to log or trace. + // This API is marked as internal but it is the easiest way to work + // out if we should trace or log. 
let should_trace = tracing_core::dispatcher::has_been_set(); let source_file = Some(log_data.source_file.as_str()); let line = Some(log_data.line); let source = Some(log_data.source.as_str()); - // See https://github.com/rust-lang/rust/issues/42253 for the reason this has to be done this way + // See https://github.com/rust-lang/rust/issues/42253 for the reason + // this has to be done this way. if should_trace { - // Create a tracing event for the GuestLogData - // Ideally we would create tracing metadata based on the Guest Log Data - // but tracing derives the metadata at compile time + // Create a tracing event for the GuestLogData. + // Ideally we would create tracing metadata based on the Guest + // Log Data but tracing derives the metadata at compile time. // see https://github.com/tokio-rs/tracing/issues/2419 - // so we leave it up to the subscriber to figure out that there are logging fields present with this data - format_trace( + // So we leave it up to the subscriber to figure out that there + // are logging fields present with this data. + let _ = format_trace( &Record::builder() .args(format_args!("{}", log_data.message)) .level(record_level) @@ -112,8 +133,7 @@ pub(super) fn outb_log( .line(line) .module_path(source) .build(), - ) - .map_err(|e| HandleOutbError::TraceFormat(e.to_string()))?; + ); } else { // Create a log record for the GuestLogData log::logger().log( @@ -127,8 +147,6 @@ pub(super) fn outb_log( .build(), ); } - - Ok(()) } const ABORT_TERMINATOR: u8 = 0xFF; @@ -184,6 +202,8 @@ fn outb_abort( } /// Handle a guest-to-host function call received via the G2H virtqueue. +/// +/// Log entries that arrive before the Request are processed inline. 
fn outb_virtq_call( mem_mgr: &mut SandboxMemoryManager, host_funcs: &Arc>, @@ -192,32 +212,49 @@ fn outb_virtq_call( HandleOutbError::ReadHostFunctionCall("G2H consumer not initialized".into()) })?; - let Some((entry, completion)) = consumer - .poll(8192) - .map_err(|e| HandleOutbError::ReadHostFunctionCall(format!("G2H poll: {e}")))? - else { - // No G2H entry - can happen when guest H2G prefill - // triggers VirtqNotify before suppression is set. - return Ok(()); + // Drain entries, processing Log messages, until we find a Request. + let (entry, completion) = loop { + let Some((entry, completion)) = consumer + .poll(8192) + .map_err(|e| HandleOutbError::ReadHostFunctionCall(format!("G2H poll: {e}")))? + else { + // No G2H entry - backpressure-only notify or prefill notify. + return Ok(()); + }; + + let entry_data = entry.data(); + if entry_data.len() < VirtqMsgHeader::SIZE { + return Err(HandleOutbError::ReadHostFunctionCall( + "G2H entry too short".into(), + )); + } + let hdr: VirtqMsgHeader = *bytemuck::from_bytes(&entry_data[..VirtqMsgHeader::SIZE]); + + match hdr.msg_kind() { + Ok(MsgKind::Log) => { + let payload = &entry_data[VirtqMsgHeader::SIZE..]; + emit_guest_log_from_payload(payload); + let _ = consumer.complete(completion); + continue; + } + Ok(MsgKind::Request) => break (entry, completion), + Ok(other) => { + return Err(HandleOutbError::ReadHostFunctionCall(format!( + "G2H: expected Request via outb, got {:?}", + other + ))); + } + Err(unknown) => { + return Err(HandleOutbError::ReadHostFunctionCall(format!( + "G2H: unknown message kind: 0x{unknown:02x}" + ))); + } + } }; let entry_data = entry.data(); - if entry_data.len() < VirtqMsgHeader::SIZE { - return Err(HandleOutbError::ReadHostFunctionCall( - "G2H entry too short".into(), - )); - } - let hdr: VirtqMsgHeader = *bytemuck::from_bytes(&entry_data[..VirtqMsgHeader::SIZE]); let payload = &entry_data[VirtqMsgHeader::SIZE..]; - // TODO(virtq): Only Requests (host function callbacks) arrive via outb. 
- if hdr.kind != MsgKind::Request as u8 { - return Err(HandleOutbError::ReadHostFunctionCall(format!( - "G2H: expected Request via outb, got kind={}", - hdr.kind - ))); - } - let call = FunctionCall::try_from(payload) .map_err(|e| HandleOutbError::ReadHostFunctionCall(e.to_string()))?; @@ -269,7 +306,12 @@ pub(crate) fn handle_outb( .try_into() .map_err(|e: anyhow::Error| HandleOutbError::InvalidPort(e.to_string()))? { - OutBAction::Log => outb_log(mem_mgr), + OutBAction::Log => { + // Legacy path - logs now arrive via G2H virtqueue + // and are processed inline by outb_virtq_call / + // read_h2g_result_from_g2h. + Ok(()) + } OutBAction::CallFunction => { let call = mem_mgr .get_host_function_call() diff --git a/src/hyperlight_host/tests/common/mod.rs b/src/hyperlight_host/tests/common/mod.rs index d58e60aa6..8b2f6de9f 100644 --- a/src/hyperlight_host/tests/common/mod.rs +++ b/src/hyperlight_host/tests/common/mod.rs @@ -80,6 +80,16 @@ where f(sandbox); } +/// Runs a test with a Rust guest UninitializedSandbox using custom configuration. 
+pub fn with_rust_uninit_sandbox_cfg(cfg: SandboxConfiguration, f: F) +where + F: FnOnce(UninitializedSandbox), +{ + let sandbox = + UninitializedSandbox::new(GuestBinary::FilePath(rust_guest_path()), Some(cfg)).unwrap(); + f(sandbox); +} + // ============================================================================= // C guest helpers // ============================================================================= diff --git a/src/hyperlight_host/tests/sandbox_host_tests.rs b/src/hyperlight_host/tests/sandbox_host_tests.rs index e0daf969b..d6db20ddc 100644 --- a/src/hyperlight_host/tests/sandbox_host_tests.rs +++ b/src/hyperlight_host/tests/sandbox_host_tests.rs @@ -26,7 +26,8 @@ use hyperlight_testing::simple_guest_as_string; pub mod common; // pub to disable dead_code warning use crate::common::{ with_all_sandboxes, with_all_sandboxes_cfg, with_all_sandboxes_with_writer, - with_all_uninit_sandboxes, + with_all_uninit_sandboxes, with_rust_sandbox_cfg, with_rust_uninit_sandbox, + with_rust_uninit_sandbox_cfg, }; #[test] @@ -375,3 +376,181 @@ fn host_function_error() { } }); } + +#[test] +fn virtq_log_delivery() { + use hyperlight_testing::simplelogger::{LOGGER, SimpleLogger}; + + SimpleLogger::initialize_test_logger(); + LOGGER.clear_log_calls(); + + with_rust_uninit_sandbox(|mut sbox| { + sbox.set_max_guest_log_level(tracing_core::LevelFilter::TRACE); + let mut sandbox = sbox.evolve().unwrap(); + + sandbox + .call::<()>("LogMessage", ("virtq log test message".to_string(), 3_i32)) + .unwrap(); + + // Verify the guest log arrived via virtqueue + let count = LOGGER.num_log_calls(); + assert!(count > 0, "expected at least one guest log, got 0"); + + let mut found = false; + for i in 0..count { + if let Some(call) = LOGGER.get_log_call(i) + && call.target == "hyperlight_guest" + && call.args.contains("virtq log test") + { + found = true; + break; + } + } + assert!(found, "expected 'virtq log test' message from guest"); + LOGGER.clear_log_calls(); + }); +} + 
+#[test] +fn virtq_log_with_callback() { + // Verify that log messages interleaved with host callbacks work + with_all_uninit_sandboxes(|mut sandbox| { + let (tx, _rx) = channel(); + sandbox + .register("HostMethod1", move |msg: String| { + let len = msg.len(); + tx.send(msg).unwrap(); + len as i32 + }) + .unwrap(); + let mut sandbox = sandbox.evolve().unwrap(); + + // Echo triggers guest-side logging infrastructure, then returns. + // This validates that log ReadOnly entries interleaved with + // function call ReadWrite entries don't corrupt the G2H queue. + let res: String = sandbox.call("Echo", "test".to_string()).unwrap(); + assert_eq!(res, "test"); + }); +} + +#[test] +fn virtq_log_backpressure() { + use hyperlight_testing::simplelogger::{LOGGER, SimpleLogger}; + + SimpleLogger::initialize_test_logger(); + LOGGER.clear_log_calls(); + + let mut cfg = SandboxConfiguration::default(); + cfg.set_g2h_pool_pages(2); + + with_rust_uninit_sandbox_cfg(cfg, |mut sbox| { + sbox.set_max_guest_log_level(tracing_core::LevelFilter::INFO); + let mut sandbox = sbox.evolve().unwrap(); + + // 50 logs with a 2-page pool should trigger backpressure + sandbox.call::<()>("LogMessageN", 50_i32).unwrap(); + + // Verify sandbox is still functional after backpressure + let res: i32 = sandbox + .call("ThisIsNotARealFunctionButTheNameIsImportant", ()) + .unwrap(); + assert_eq!(res, 99); + + // Verify all 50 log entries were delivered + let guest_count = (0..LOGGER.num_log_calls()) + .filter_map(|i| LOGGER.get_log_call(i)) + .filter(|c| c.target == "hyperlight_guest" && c.args.contains("log entry")) + .count(); + assert_eq!(guest_count, 50, "expected 50 guest logs, got {guest_count}"); + LOGGER.clear_log_calls(); + }); +} + +#[test] +fn virtq_log_backpressure_repeated() { + // Multiple calls that each trigger backpressure, verifying the + // pool recovers correctly each time. 
+ let mut cfg = SandboxConfiguration::default(); + cfg.set_g2h_pool_pages(2); + + with_rust_sandbox_cfg(cfg, |mut sandbox| { + for _ in 0..5 { + sandbox.call::<()>("LogMessageN", 30_i32).unwrap(); + } + }); +} + +#[test] +fn virtq_backpressure_small_ring() { + // Small descriptor table forces ring-level backpressure. + use hyperlight_testing::simplelogger::{LOGGER, SimpleLogger}; + + SimpleLogger::initialize_test_logger(); + LOGGER.clear_log_calls(); + + let mut cfg = SandboxConfiguration::default(); + cfg.set_g2h_queue_depth(4); + + with_rust_uninit_sandbox_cfg(cfg, |mut sbox| { + sbox.set_max_guest_log_level(tracing_core::LevelFilter::INFO); + let mut sandbox = sbox.evolve().unwrap(); + + sandbox.call::<()>("LogMessageN", 20_i32).unwrap(); + + let guest_count = (0..LOGGER.num_log_calls()) + .filter_map(|i| LOGGER.get_log_call(i)) + .filter(|c| c.target == "hyperlight_guest" && c.args.contains("log entry")) + .count(); + assert_eq!(guest_count, 20, "expected 20 guest logs, got {guest_count}"); + LOGGER.clear_log_calls(); + }); +} + +#[test] +fn virtq_backpressure_log_then_callback() { + // Logs fill the G2H ring, then a host callback needs ring space. + // call_host_function handles backpressure by notify + reclaim + retry. + let mut cfg = SandboxConfiguration::default(); + cfg.set_g2h_queue_depth(4); + cfg.set_g2h_pool_pages(2); + + with_rust_uninit_sandbox_cfg(cfg, |mut sbox| { + sbox.set_max_guest_log_level(tracing_core::LevelFilter::INFO); + sbox.register_print(|msg: String| msg.len() as i32).unwrap(); + let mut sandbox = sbox.evolve().unwrap(); + + // PrintOutput logs and calls HostPrint callback. + // With depth=4 the logs may fill the ring, requiring + // call_host_function to handle backpressure before + // submitting the callback entry. 
+ let res: i32 = sandbox.call("PrintOutput", "bp-test".to_string()).unwrap(); + assert_eq!(res, 7); + }); +} + +#[test] +fn virtq_backpressure_no_data_loss() { + // After backpressure recovery, verify multiple function calls + // return correct results (completion data wasn't lost by reclaim). + let mut cfg = SandboxConfiguration::default(); + cfg.set_g2h_pool_pages(2); + cfg.set_g2h_queue_depth(4); + + with_rust_uninit_sandbox_cfg(cfg, |mut sbox| { + sbox.set_max_guest_log_level(tracing_core::LevelFilter::INFO); + let mut sandbox = sbox.evolve().unwrap(); + + // Trigger backpressure with logs + sandbox.call::<()>("LogMessageN", 20_i32).unwrap(); + + // Now verify multiple function calls with return values + let res: String = sandbox.call("Echo", "first".to_string()).unwrap(); + assert_eq!(res, "first"); + + let res: String = sandbox.call("Echo", "second".to_string()).unwrap(); + assert_eq!(res, "second"); + + let res: f64 = sandbox.call("EchoDouble", 1.234_f64).unwrap(); + assert!((res - 1.234).abs() < f64::EPSILON); + }); +} diff --git a/src/tests/rust_guests/simpleguest/src/main.rs b/src/tests/rust_guests/simpleguest/src/main.rs index b6844a716..024f7d1b4 100644 --- a/src/tests/rust_guests/simpleguest/src/main.rs +++ b/src/tests/rust_guests/simpleguest/src/main.rs @@ -479,6 +479,13 @@ fn log_message(message: String, level: i32) { } } +#[guest_function("LogMessageN")] +fn log_message_n(count: i32) { + for i in 0..count { + log::info!("log entry {}", i); + } +} + #[guest_function("TriggerException")] fn trigger_exception() { // trigger an undefined instruction exception From 9ad22cf040613686cb50ec4cc80ad1153df97941 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Tue, 7 Apr 2026 13:55:08 +0200 Subject: [PATCH 10/31] feat(virtq): use virtq for capi ret error Signed-off-by: Tomasz Andrzejak --- src/hyperlight_guest_capi/src/error.rs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/hyperlight_guest_capi/src/error.rs 
b/src/hyperlight_guest_capi/src/error.rs index 03217600e..720911157 100644 --- a/src/hyperlight_guest_capi/src/error.rs +++ b/src/hyperlight_guest_capi/src/error.rs @@ -19,7 +19,6 @@ use core::ffi::{CStr, c_char}; use flatbuffers::FlatBufferBuilder; use hyperlight_common::flatbuffer_wrappers::function_types::FunctionCallResult; use hyperlight_common::flatbuffer_wrappers::guest_error::{ErrorCode, GuestError}; -use hyperlight_guest_bin::GUEST_HANDLE; use crate::alloc::borrow::ToOwned; @@ -35,12 +34,11 @@ pub extern "C" fn hl_set_error(err: ErrorCode, message: *const c_char) { let fcr = FunctionCallResult::new(guest_error); let mut builder = FlatBufferBuilder::new(); let data = fcr.encode(&mut builder); - unsafe { - #[allow(static_mut_refs)] // we are single threaded - GUEST_HANDLE - .push_shared_output_data(data) - .expect("Failed to set error") - } + + hyperlight_guest::virtq::with_context(|ctx| { + ctx.send_h2g_result(data) + .expect("Failed to send error via virtq"); + }); } #[unsafe(no_mangle)] From e3d5c152ece218738c080579fd0ec3c425acd6d3 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Tue, 7 Apr 2026 16:26:06 +0200 Subject: [PATCH 11/31] feat(virtq): remove unused stack based io path Signed-off-by: Tomasz Andrzejak --- .../src/guest_handle/host_comm.rs | 101 ------ src/hyperlight_guest/src/guest_handle/io.rs | 150 -------- src/hyperlight_guest/src/lib.rs | 1 - src/hyperlight_guest/src/virtq/context.rs | 25 ++ src/hyperlight_guest_bin/src/host_comm.rs | 84 +++-- src/hyperlight_guest_capi/src/dispatch.rs | 24 +- src/hyperlight_guest_capi/src/flatbuffer.rs | 22 +- src/hyperlight_host/src/mem/mgr.rs | 93 +---- src/hyperlight_host/src/mem/shared_mem.rs | 326 ------------------ .../src/sandbox/initialized_multi_use.rs | 2 - src/hyperlight_host/src/sandbox/outb.rs | 298 +--------------- src/hyperlight_host/src/testing/log_values.rs | 62 ---- src/hyperlight_host/src/testing/mod.rs | 1 - .../tests/sandbox_host_tests.rs | 102 ++++++ 
src/tests/rust_guests/simpleguest/src/main.rs | 46 +-- 15 files changed, 230 insertions(+), 1107 deletions(-) delete mode 100644 src/hyperlight_guest/src/guest_handle/io.rs delete mode 100644 src/hyperlight_host/src/testing/log_values.rs diff --git a/src/hyperlight_guest/src/guest_handle/host_comm.rs b/src/hyperlight_guest/src/guest_handle/host_comm.rs index d440852f6..10b8e9a7a 100644 --- a/src/hyperlight_guest/src/guest_handle/host_comm.rs +++ b/src/hyperlight_guest/src/guest_handle/host_comm.rs @@ -18,21 +18,13 @@ use alloc::format; use alloc::string::ToString; use alloc::vec::Vec; -use flatbuffers::FlatBufferBuilder; -use hyperlight_common::flatbuffer_wrappers::function_call::{FunctionCall, FunctionCallType}; -use hyperlight_common::flatbuffer_wrappers::function_types::{ - FunctionCallResult, ParameterValue, ReturnType, ReturnValue, -}; use hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode; use hyperlight_common::flatbuffer_wrappers::guest_log_data::GuestLogData; use hyperlight_common::flatbuffer_wrappers::guest_log_level::LogLevel; -use hyperlight_common::flatbuffer_wrappers::util::estimate_flatbuffer_capacity; -use hyperlight_common::outb::OutBAction; use tracing::instrument; use super::handle::GuestHandle; use crate::error::{HyperlightGuestError, Result}; -use crate::exit::out32; impl GuestHandle { /// Get user memory region as bytes. @@ -59,99 +51,6 @@ impl GuestHandle { } } - /// Get a return value from a host function call. - /// This usually requires a host function to be called first using - /// `call_host_function_internal`. - /// - /// When calling `call_host_function`, this function is called - /// internally to get the return value. 
- #[instrument(skip_all, level = "Trace")] - pub fn get_host_return_value>(&self) -> Result { - let inner = self - .try_pop_shared_input_data_into::() - .expect("Unable to deserialize a return value from host") - .into_inner(); - - match inner { - Ok(ret) => T::try_from(ret).map_err(|_| { - let expected = core::any::type_name::(); - HyperlightGuestError::new( - ErrorCode::UnsupportedParameterType, - format!("Host return value could not be converted to expected {expected}",), - ) - }), - Err(e) => Err(HyperlightGuestError { - kind: e.code, - message: e.message, - }), - } - } - - pub fn get_host_return_raw(&self) -> Result { - let inner = self - .try_pop_shared_input_data_into::() - .expect("Unable to deserialize a return value from host") - .into_inner(); - - match inner { - Ok(ret) => Ok(ret), - Err(e) => Err(HyperlightGuestError { - kind: e.code, - message: e.message, - }), - } - } - - /// Call a host function without reading its return value from shared mem. - /// This is used by both the Rust and C APIs to reduce code duplication. - /// - /// Note: The function return value must be obtained by calling - /// `get_host_return_value`. - #[instrument(skip_all, level = "Trace")] - pub fn call_host_function_without_returning_result( - &self, - function_name: &str, - parameters: Option>, - return_type: ReturnType, - ) -> Result<()> { - let estimated_capacity = - estimate_flatbuffer_capacity(function_name, parameters.as_deref().unwrap_or(&[])); - - let host_function_call = FunctionCall::new( - function_name.to_string(), - parameters, - FunctionCallType::Host, - return_type, - ); - - let mut builder = FlatBufferBuilder::with_capacity(estimated_capacity); - - let host_function_call_buffer = host_function_call.encode(&mut builder); - self.push_shared_output_data(host_function_call_buffer)?; - - unsafe { - out32(OutBAction::CallFunction as u16, 0); - } - - Ok(()) - } - - /// Call a host function with the given parameters and return type. 
- /// This function serializes the function call and its parameters, - /// sends it to the host, and then retrieves the return value. - /// - /// The return value is deserialized into the specified type `T`. - #[instrument(skip_all, level = "Info")] - pub fn call_host_function>( - &self, - function_name: &str, - parameters: Option>, - return_type: ReturnType, - ) -> Result { - self.call_host_function_without_returning_result(function_name, parameters, return_type)?; - self.get_host_return_value::() - } - /// Log a message with the specified log level, source, caller, source file, and line number. pub fn log_message( &self, diff --git a/src/hyperlight_guest/src/guest_handle/io.rs b/src/hyperlight_guest/src/guest_handle/io.rs deleted file mode 100644 index 46c1d68f6..000000000 --- a/src/hyperlight_guest/src/guest_handle/io.rs +++ /dev/null @@ -1,150 +0,0 @@ -/* -Copyright 2025 The Hyperlight Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -use alloc::format; -use alloc::string::ToString; -use core::any::type_name; -use core::slice::from_raw_parts_mut; - -use hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode; -use tracing::instrument; - -use super::handle::GuestHandle; -use crate::error::{HyperlightGuestError, Result}; - -impl GuestHandle { - /// Pops the top element from the shared input data buffer and returns it as a T - #[instrument(skip_all, level = "Trace")] - pub fn try_pop_shared_input_data_into(&self) -> Result - where - T: for<'a> TryFrom<&'a [u8]>, - { - let peb_ptr = self.peb().unwrap(); - let input_stack_size = unsafe { (*peb_ptr).input_stack.size as usize }; - let input_stack_ptr = unsafe { (*peb_ptr).input_stack.ptr as *mut u8 }; - - let idb = unsafe { from_raw_parts_mut(input_stack_ptr, input_stack_size) }; - - if idb.is_empty() { - return Err(HyperlightGuestError::new( - ErrorCode::GuestError, - "Got a 0-size buffer in pop_shared_input_data_into".to_string(), - )); - } - - // get relative offset to next free address - let stack_ptr_rel: u64 = - u64::from_le_bytes(idb[..8].try_into().expect("Shared input buffer too small")); - - if stack_ptr_rel as usize > input_stack_size || stack_ptr_rel < 16 { - return Err(HyperlightGuestError::new( - ErrorCode::GuestError, - format!( - "Invalid stack pointer: {} in pop_shared_input_data_into", - stack_ptr_rel - ), - )); - } - - // go back 8 bytes and read. 
This is the offset to the element on top of stack - let last_element_offset_rel = u64::from_le_bytes( - idb[stack_ptr_rel as usize - 8..stack_ptr_rel as usize] - .try_into() - .expect("Invalid stack pointer in pop_shared_input_data_into"), - ); - - let buffer = &idb[last_element_offset_rel as usize..]; - - // convert the buffer to T - let type_t = match T::try_from(buffer) { - Ok(t) => Ok(t), - Err(_e) => { - return Err(HyperlightGuestError::new( - ErrorCode::GuestError, - format!("Unable to convert buffer to {}", type_name::()), - )); - } - }; - - // update the stack pointer to point to the element we just popped of since that is now free - idb[..8].copy_from_slice(&last_element_offset_rel.to_le_bytes()); - - // zero out popped off buffer - idb[last_element_offset_rel as usize..stack_ptr_rel as usize].fill(0); - - type_t - } - - /// Pushes the given data onto the shared output data buffer. - pub fn push_shared_output_data(&self, data: &[u8]) -> Result<()> { - let peb_ptr = self.peb().unwrap(); - let output_stack_size = unsafe { (*peb_ptr).output_stack.size as usize }; - let output_stack_ptr = unsafe { (*peb_ptr).output_stack.ptr as *mut u8 }; - - let odb = unsafe { from_raw_parts_mut(output_stack_ptr, output_stack_size) }; - - if odb.is_empty() { - return Err(HyperlightGuestError::new( - ErrorCode::GuestError, - "Got a 0-size buffer in push_shared_output_data".to_string(), - )); - } - - // get offset to next free address on the stack - let stack_ptr_rel: u64 = - u64::from_le_bytes(odb[..8].try_into().expect("Shared output buffer too small")); - - // check if the stack pointer is within the bounds of the buffer. - // It can be equal to the size, but never greater - // It can never be less than 8. 
An empty buffer's stack pointer is 8 - if stack_ptr_rel as usize > output_stack_size || stack_ptr_rel < 8 { - return Err(HyperlightGuestError::new( - ErrorCode::GuestError, - format!( - "Invalid stack pointer: {} in push_shared_output_data", - stack_ptr_rel - ), - )); - } - - // check if there is enough space in the buffer - let size_required = data.len() + 8; // the data plus the pointer pointing to the data - let size_available = output_stack_size - stack_ptr_rel as usize; - if size_required > size_available { - return Err(HyperlightGuestError::new( - ErrorCode::GuestError, - format!( - "Not enough space in shared output buffer. Required: {}, Available: {}", - size_required, size_available - ), - )); - } - - // write the actual data - odb[stack_ptr_rel as usize..stack_ptr_rel as usize + data.len()].copy_from_slice(data); - - // write the offset to the newly written data, to the top of the stack - let bytes: [u8; 8] = stack_ptr_rel.to_le_bytes(); - odb[stack_ptr_rel as usize + data.len()..stack_ptr_rel as usize + data.len() + 8] - .copy_from_slice(&bytes); - - // update stack pointer to point to next free address - let new_stack_ptr_rel: u64 = (stack_ptr_rel as usize + data.len() + 8) as u64; - odb[0..8].copy_from_slice(&(new_stack_ptr_rel).to_le_bytes()); - - Ok(()) - } -} diff --git a/src/hyperlight_guest/src/lib.rs b/src/hyperlight_guest/src/lib.rs index 9cf64280d..ed2656a6d 100644 --- a/src/hyperlight_guest/src/lib.rs +++ b/src/hyperlight_guest/src/lib.rs @@ -31,5 +31,4 @@ pub mod virtq; pub mod guest_handle { pub mod handle; pub mod host_comm; - pub mod io; } diff --git a/src/hyperlight_guest/src/virtq/context.rs b/src/hyperlight_guest/src/virtq/context.rs index 59b71bf14..44c0a83ca 100644 --- a/src/hyperlight_guest/src/virtq/context.rs +++ b/src/hyperlight_guest/src/virtq/context.rs @@ -72,6 +72,7 @@ pub struct GuestContext { g2h_producer: G2hProducer, h2g_producer: H2gProducer, generation: u64, + last_host_return: Option, } impl GuestContext { @@ -100,6 
+101,7 @@ impl GuestContext { g2h_producer, h2g_producer, generation, + last_host_return: None, }; ctx.prefill_h2g(); @@ -342,4 +344,27 @@ impl GuestContext { entry.write_all(payload)?; self.g2h_producer.submit(entry) } + + /// Stash a host function return value for later retrieval. + /// + /// Used by the C API's two-step calling convention where + /// `hl_call_host_function` and `hl_get_host_return_value_as_*` + /// are separate calls. + pub fn stash_host_return(&mut self, value: ReturnValue) { + self.last_host_return = Some(value); + } + + /// Take the stashed host return value. + /// + /// Panics if no value was stashed or if the type conversion fails. + pub fn take_host_return>(&mut self) -> T { + let rv = self + .last_host_return + .take() + .expect("No host return value available"); + match T::try_from(rv) { + Ok(v) => v, + Err(_) => panic!("Host return value type mismatch"), + } + } } diff --git a/src/hyperlight_guest_bin/src/host_comm.rs b/src/hyperlight_guest_bin/src/host_comm.rs index a2dbf77e6..1fe7f9994 100644 --- a/src/hyperlight_guest_bin/src/host_comm.rs +++ b/src/hyperlight_guest_bin/src/host_comm.rs @@ -14,8 +14,10 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -use alloc::string::ToString; +use alloc::string::{String, ToString}; use alloc::vec::Vec; +use core::ffi::{CStr, c_char}; +use core::mem; use hyperlight_common::flatbuffer_wrappers::function_call::FunctionCall; use hyperlight_common::flatbuffer_wrappers::function_types::{ @@ -25,6 +27,10 @@ use hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode; use hyperlight_common::flatbuffer_wrappers::util::get_flatbuffer_result; use hyperlight_common::func::{ParameterTuple, SupportedReturnType}; use hyperlight_guest::error::{HyperlightGuestError, Result}; +use hyperlight_guest::virtq; + +const BUFFER_SIZE: usize = 1000; +static mut MESSAGE_BUFFER: Vec = Vec::new(); use crate::GUEST_HANDLE; @@ -36,17 +42,7 @@ pub fn call_host_function( where T: TryFrom, { - #[cfg(feature = "virtq")] - { - hyperlight_guest::virtq::with_context(|ctx| { - ctx.call_host_function(function_name, parameters, return_type) - }) - } - #[cfg(not(feature = "virtq"))] - { - let handle = unsafe { GUEST_HANDLE }; - handle.call_host_function::(function_name, parameters, return_type) - } + virtq::with_context(|ctx| ctx.call_host_function(function_name, parameters, return_type)) } pub fn call_host(function_name: impl AsRef, args: impl ParameterTuple) -> Result @@ -56,25 +52,6 @@ where call_host_function::(function_name.as_ref(), Some(args.into_value()), T::TYPE) } -pub fn call_host_function_without_returning_result( - function_name: &str, - parameters: Option>, - return_type: ReturnType, -) -> Result<()> { - let handle = unsafe { GUEST_HANDLE }; - handle.call_host_function_without_returning_result(function_name, parameters, return_type) -} - -pub fn get_host_return_value_raw() -> Result { - let handle = unsafe { GUEST_HANDLE }; - handle.get_host_return_raw() -} - -pub fn get_host_return_value>() -> Result { - let handle = unsafe { GUEST_HANDLE }; - handle.get_host_return_value::() -} - pub fn read_n_bytes_from_user_memory(num: u64) -> Result> { let handle = unsafe { GUEST_HANDLE }; 
handle.read_n_bytes_from_user_memory(num) @@ -85,9 +62,8 @@ pub fn read_n_bytes_from_user_memory(num: u64) -> Result> { /// This function requires memory to be setup to be used. In particular, the /// existence of the input and output memory regions. pub fn print_output_with_host_print(function_call: FunctionCall) -> Result> { - let handle = unsafe { GUEST_HANDLE }; if let ParameterValue::String(message) = function_call.parameters.unwrap().remove(0) { - let res = handle.call_host_function::( + let res = call_host_function::( "HostPrint", Some(Vec::from(&[ParameterValue::String(message)])), ReturnType::Int, @@ -101,3 +77,45 @@ pub fn print_output_with_host_print(function_call: FunctionCall) -> Result( + "HostPrint", + Some(Vec::from(&[ParameterValue::String(str)])), + ReturnType::Int, + ) + .expect("Failed to call HostPrint"); + + // Clear the buffer after sending + message_buffer.clear(); + } +} diff --git a/src/hyperlight_guest_capi/src/dispatch.rs b/src/hyperlight_guest_capi/src/dispatch.rs index e0a8bc34c..245eaf700 100644 --- a/src/hyperlight_guest_capi/src/dispatch.rs +++ b/src/hyperlight_guest_capi/src/dispatch.rs @@ -20,12 +20,14 @@ use alloc::vec::Vec; use core::ffi::{CStr, c_char}; use hyperlight_common::flatbuffer_wrappers::function_call::FunctionCall; -use hyperlight_common::flatbuffer_wrappers::function_types::{ParameterType, ReturnType}; +use hyperlight_common::flatbuffer_wrappers::function_types::{ + ParameterType, ReturnType, ReturnValue, +}; use hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode; use hyperlight_guest::error::{HyperlightGuestError, Result}; +use hyperlight_guest::virtq; use hyperlight_guest_bin::guest_function::definition::GuestFunctionDefinition; use hyperlight_guest_bin::guest_function::register::GuestFunctionRegister; -use hyperlight_guest_bin::host_comm::call_host_function_without_returning_result; use crate::types::{FfiFunctionCall, FfiVec}; static mut REGISTERED_C_GUEST_FUNCTIONS: GuestFunctionRegister = @@ -98,15 
+100,23 @@ pub extern "C" fn hl_register_function_definition( unsafe { (&mut *(&raw mut REGISTERED_C_GUEST_FUNCTIONS)).register(func_def) }; } -/// The caller is responsible for freeing the memory associated with given `FfiFunctionCall`. +/// Call a host function. The return value can be retrieved with +/// `hl_get_host_return_value_as_*` immediately after. #[unsafe(no_mangle)] pub extern "C" fn hl_call_host_function(function_call: &FfiFunctionCall) { let parameters = unsafe { function_call.copy_parameters() }; let func_name = unsafe { function_call.copy_function_name() }; let return_type = unsafe { function_call.copy_return_type() }; - // Use the non-generic internal implementation - // The C API will then call specific getter functions to fetch the properly typed return value - let _ = call_host_function_without_returning_result(&func_name, Some(parameters), return_type) - .expect("Failed to call host function"); + virtq::with_context(|ctx| { + let result: ReturnValue = ctx + .call_host_function(&func_name, Some(parameters), return_type) + .expect("Failed to call host function"); + ctx.stash_host_return(result); + }); +} + +/// Retrieve the return value stashed by the last `hl_call_host_function`. 
+pub(crate) fn take_last_host_return>() -> T { + virtq::with_context(|ctx| ctx.take_host_return::()) } diff --git a/src/hyperlight_guest_capi/src/flatbuffer.rs b/src/hyperlight_guest_capi/src/flatbuffer.rs index ff12400d6..043431e4c 100644 --- a/src/hyperlight_guest_capi/src/flatbuffer.rs +++ b/src/hyperlight_guest_capi/src/flatbuffer.rs @@ -21,8 +21,8 @@ use alloc::vec::Vec; use core::ffi::{CStr, c_char}; use hyperlight_common::flatbuffer_wrappers::util::get_flatbuffer_result; -use hyperlight_guest_bin::host_comm::get_host_return_value; +use crate::dispatch::take_last_host_return; use crate::types::FfiVec; // The reason for the capitalized type in the function names below @@ -106,44 +106,43 @@ pub extern "C" fn hl_flatbuffer_result_from_Bool(value: bool) -> Box { #[unsafe(no_mangle)] pub extern "C" fn hl_get_host_return_value_as_Int() -> i32 { - get_host_return_value().expect("Unable to get host return value as int") + take_last_host_return() } #[unsafe(no_mangle)] pub extern "C" fn hl_get_host_return_value_as_UInt() -> u32 { - get_host_return_value().expect("Unable to get host return value as uint") + take_last_host_return() } // the same for long, ulong #[unsafe(no_mangle)] pub extern "C" fn hl_get_host_return_value_as_Long() -> i64 { - get_host_return_value().expect("Unable to get host return value as long") + take_last_host_return() } #[unsafe(no_mangle)] pub extern "C" fn hl_get_host_return_value_as_ULong() -> u64 { - get_host_return_value().expect("Unable to get host return value as ulong") + take_last_host_return() } #[unsafe(no_mangle)] pub extern "C" fn hl_get_host_return_value_as_Bool() -> bool { - get_host_return_value().expect("Unable to get host return value as bool") + take_last_host_return() } #[unsafe(no_mangle)] pub extern "C" fn hl_get_host_return_value_as_Float() -> f32 { - get_host_return_value().expect("Unable to get host return value as f32") + take_last_host_return() } #[unsafe(no_mangle)] pub extern "C" fn 
hl_get_host_return_value_as_Double() -> f64 { - get_host_return_value().expect("Unable to get host return value as f64") + take_last_host_return() } #[unsafe(no_mangle)] pub extern "C" fn hl_get_host_return_value_as_String() -> *const c_char { - let string_value: String = - get_host_return_value().expect("Unable to get host return value as string"); + let string_value: String = take_last_host_return(); let c_string = CString::new(string_value).expect("Failed to create CString"); c_string.into_raw() @@ -151,8 +150,7 @@ pub extern "C" fn hl_get_host_return_value_as_String() -> *const c_char { #[unsafe(no_mangle)] pub extern "C" fn hl_get_host_return_value_as_VecBytes() -> Box { - let vec_value: Vec = - get_host_return_value().expect("Unable to get host return value as vec bytes"); + let vec_value: Vec = take_last_host_return(); Box::new(unsafe { FfiVec::from_vec(vec_value) }) } diff --git a/src/hyperlight_host/src/mem/mgr.rs b/src/hyperlight_host/src/mem/mgr.rs index 764e72f4c..f869b31be 100644 --- a/src/hyperlight_host/src/mem/mgr.rs +++ b/src/hyperlight_host/src/mem/mgr.rs @@ -17,12 +17,7 @@ limitations under the License. 
use std::mem::offset_of; use std::num::NonZeroU16; -use flatbuffers::FlatBufferBuilder; -use hyperlight_common::flatbuffer_wrappers::function_call::{ - FunctionCall, validate_guest_function_call_buffer, -}; use hyperlight_common::flatbuffer_wrappers::function_types::FunctionCallResult; -use hyperlight_common::flatbuffer_wrappers::guest_log_data::GuestLogData; use hyperlight_common::mem::PAGE_SIZE_USIZE; use hyperlight_common::virtq::msg::{MsgKind, VirtqMsgHeader}; use hyperlight_common::virtq::{self, Layout as VirtqLayout}; @@ -488,92 +483,6 @@ impl SandboxMemoryManager { Ok(()) } - /// Reads a host function call from memory - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - pub(crate) fn get_host_function_call(&mut self) -> Result { - self.scratch_mem.try_pop_buffer_into::( - self.layout.get_output_data_buffer_scratch_host_offset(), - self.layout.sandbox_memory_config.get_output_data_size(), - ) - } - - /// Writes a host function call result to memory - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - pub(crate) fn write_response_from_host_function_call( - &mut self, - res: &FunctionCallResult, - ) -> Result<()> { - let mut builder = FlatBufferBuilder::new(); - let data = res.encode(&mut builder); - - self.scratch_mem.push_buffer( - self.layout.get_input_data_buffer_scratch_host_offset(), - self.layout.sandbox_memory_config.get_input_data_size(), - data, - ) - } - - /// Writes a guest function call to memory - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - #[allow(dead_code)] - pub(crate) fn write_guest_function_call(&mut self, buffer: &[u8]) -> Result<()> { - validate_guest_function_call_buffer(buffer).map_err(|e| { - new_error!( - "Guest function call buffer validation failed: {}", - e.to_string() - ) - })?; - - self.scratch_mem.push_buffer( - self.layout.get_input_data_buffer_scratch_host_offset(), - self.layout.sandbox_memory_config.get_input_data_size(), - buffer, - 
)?; - Ok(()) - } - - /// Reads a function call result from memory. - /// A function call result can be either an error or a successful return value. - #[allow(dead_code)] - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - pub(crate) fn get_guest_function_call_result(&mut self) -> Result { - self.scratch_mem.try_pop_buffer_into::( - self.layout.get_output_data_buffer_scratch_host_offset(), - self.layout.sandbox_memory_config.get_output_data_size(), - ) - } - - /// Read guest log data from the `SharedMemory` contained within `self` - #[allow(dead_code)] - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - pub(crate) fn read_guest_log_data(&mut self) -> Result { - self.scratch_mem.try_pop_buffer_into::( - self.layout.get_output_data_buffer_scratch_host_offset(), - self.layout.sandbox_memory_config.get_output_data_size(), - ) - } - - pub(crate) fn clear_io_buffers(&mut self) { - // Clear the output data buffer - loop { - let Ok(_) = self.scratch_mem.try_pop_buffer_into::>( - self.layout.get_output_data_buffer_scratch_host_offset(), - self.layout.sandbox_memory_config.get_output_data_size(), - ) else { - break; - }; - } - // Clear the input data buffer - loop { - let Ok(_) = self.scratch_mem.try_pop_buffer_into::>( - self.layout.get_input_data_buffer_scratch_host_offset(), - self.layout.sandbox_memory_config.get_input_data_size(), - ) else { - break; - }; - } - } - /// This function restores a memory snapshot from a given snapshot. 
pub(crate) fn restore_snapshot( &mut self, @@ -1127,7 +1036,7 @@ impl SandboxMemoryManager { return Ok(fcr); } Ok(MsgKind::Log) => { - crate::sandbox::outb::emit_guest_log_from_payload(payload); + crate::sandbox::outb::emit_guest_log(payload); consumer .complete(completion) .map_err(|e| new_error!("G2H complete log: {:?}", e))?; diff --git a/src/hyperlight_host/src/mem/shared_mem.rs b/src/hyperlight_host/src/mem/shared_mem.rs index db1b407c7..f8766c347 100644 --- a/src/hyperlight_host/src/mem/shared_mem.rs +++ b/src/hyperlight_host/src/mem/shared_mem.rs @@ -14,7 +14,6 @@ See the License for the specific language governing permissions and limitations under the License. */ -use std::any::type_name; use std::ffi::c_void; use std::io::Error; use std::mem::{align_of, size_of}; @@ -1048,145 +1047,6 @@ impl HostSharedMemory { drop(guard); Ok(()) } - - /// Pushes the given data onto shared memory to the buffer at the given offset. - /// NOTE! buffer_start_offset must point to the beginning of the buffer - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - pub fn push_buffer( - &mut self, - buffer_start_offset: usize, - buffer_size: usize, - data: &[u8], - ) -> Result<()> { - let stack_pointer_rel = self.read::(buffer_start_offset)? as usize; - let buffer_size_u64: u64 = buffer_size.try_into()?; - - if stack_pointer_rel > buffer_size || stack_pointer_rel < 8 { - return Err(new_error!( - "Unable to push data to buffer: Stack pointer is out of bounds. Stack pointer: {}, Buffer size: {}", - stack_pointer_rel, - buffer_size_u64 - )); - } - - let size_required = data.len() + 8; - let size_available = buffer_size - stack_pointer_rel; - - if size_required > size_available { - return Err(new_error!( - "Not enough space in buffer to push data. 
Required: {}, Available: {}", - size_required, - size_available - )); - } - - // get absolute - let stack_pointer_abs = stack_pointer_rel + buffer_start_offset; - - // write the actual data to the top of stack - self.copy_from_slice(data, stack_pointer_abs)?; - - // write the offset to the newly written data, to the top of stack. - // this is used when popping the stack, to know how far back to jump - self.write::(stack_pointer_abs + data.len(), stack_pointer_rel as u64)?; - - // update stack pointer to point to the next free address - self.write::( - buffer_start_offset, - (stack_pointer_rel + data.len() + 8) as u64, - )?; - Ok(()) - } - - /// Pops the given given buffer into a `T` and returns it. - /// NOTE! the data must be a size-prefixed flatbuffer, and - /// buffer_start_offset must point to the beginning of the buffer - pub fn try_pop_buffer_into( - &mut self, - buffer_start_offset: usize, - buffer_size: usize, - ) -> Result - where - T: for<'b> TryFrom<&'b [u8]>, - { - // get the stackpointer - let stack_pointer_rel = self.read::(buffer_start_offset)? as usize; - - if stack_pointer_rel > buffer_size || stack_pointer_rel < 16 { - return Err(new_error!( - "Unable to pop data from buffer: Stack pointer is out of bounds. Stack pointer: {}, Buffer size: {}", - stack_pointer_rel, - buffer_size - )); - } - - // make it absolute - let last_element_offset_abs = stack_pointer_rel + buffer_start_offset; - - // go back 8 bytes to get offset to element on top of stack - let last_element_offset_rel: usize = - self.read::(last_element_offset_abs - 8)? as usize; - - // Validate element offset (guest-writable): must be in [8, stack_pointer_rel - 16] - // to leave room for the 8-byte back-pointer plus at least 8 bytes of element data - // (the minimum for a size-prefixed flatbuffer: 4-byte prefix + 4-byte root offset). 
- if last_element_offset_rel > stack_pointer_rel.saturating_sub(16) - || last_element_offset_rel < 8 - { - return Err(new_error!( - "Corrupt buffer back-pointer: element offset {} is outside valid range [8, {}].", - last_element_offset_rel, - stack_pointer_rel.saturating_sub(16), - )); - } - - // make it absolute - let last_element_offset_abs = last_element_offset_rel + buffer_start_offset; - - // Max bytes the element can span (excluding the 8-byte back-pointer). - let max_element_size = stack_pointer_rel - last_element_offset_rel - 8; - - // Get the size of the flatbuffer buffer from memory - let fb_buffer_size = { - let raw_prefix = self.read::(last_element_offset_abs)?; - // flatbuffer byte arrays are prefixed by 4 bytes indicating - // the remaining size; add 4 for the prefix itself. - let total = raw_prefix.checked_add(4).ok_or_else(|| { - new_error!( - "Corrupt buffer size prefix: value {} overflows when adding 4-byte header.", - raw_prefix - ) - })?; - usize::try_from(total) - }?; - - if fb_buffer_size > max_element_size { - return Err(new_error!( - "Corrupt buffer size prefix: flatbuffer claims {} bytes but the element slot is only {} bytes.", - fb_buffer_size, - max_element_size - )); - } - - let mut result_buffer = vec![0; fb_buffer_size]; - - self.copy_to_slice(&mut result_buffer, last_element_offset_abs)?; - let to_return = T::try_from(result_buffer.as_slice()).map_err(|_e| { - new_error!( - "pop_buffer_into: failed to convert buffer to {}", - type_name::() - ) - })?; - - // update the stack pointer to point to the element we just popped off since that is now free - self.write::(buffer_start_offset, last_element_offset_rel as u64)?; - - // zero out the memory we just popped off - let num_bytes_to_zero = stack_pointer_rel - last_element_offset_rel; - self.fill(0, last_element_offset_abs, num_bytes_to_zero)?; - - Ok(to_return) - } } impl SharedMemory for HostSharedMemory { @@ -1694,192 +1554,6 @@ mod tests { } } - /// Bounds checking for 
`try_pop_buffer_into` against corrupt guest data. - mod try_pop_buffer_bounds { - use super::*; - - #[derive(Debug, PartialEq)] - struct RawBytes(Vec); - - impl TryFrom<&[u8]> for RawBytes { - type Error = String; - fn try_from(value: &[u8]) -> std::result::Result { - Ok(RawBytes(value.to_vec())) - } - } - - /// Create a buffer with stack pointer initialized to 8 (empty). - fn make_buffer(mem_size: usize) -> super::super::HostSharedMemory { - let eshm = ExclusiveSharedMemory::new(mem_size).unwrap(); - let (hshm, _) = eshm.build(); - hshm.write::(0, 8u64).unwrap(); - hshm - } - - #[test] - fn normal_push_pop_roundtrip() { - let mem_size = 4096; - let mut hshm = make_buffer(mem_size); - - // Size-prefixed flatbuffer-like payload: [size: u32 LE][payload] - let payload = b"hello"; - let mut data = Vec::new(); - data.extend_from_slice(&(payload.len() as u32).to_le_bytes()); - data.extend_from_slice(payload); - - hshm.push_buffer(0, mem_size, &data).unwrap(); - let result: RawBytes = hshm.try_pop_buffer_into(0, mem_size).unwrap(); - assert_eq!(result.0, data); - } - - #[test] - fn malicious_flatbuffer_size_prefix() { - let mem_size = 4096; - let mut hshm = make_buffer(mem_size); - - let payload = b"small"; - let mut data = Vec::new(); - data.extend_from_slice(&(payload.len() as u32).to_le_bytes()); - data.extend_from_slice(payload); - hshm.push_buffer(0, mem_size, &data).unwrap(); - - // Corrupt size prefix at element start (offset 8) to near u32::MAX. 
- hshm.write::(8, 0xFFFF_FFFBu32).unwrap(); // +4 = 0xFFFF_FFFF - - let result: Result = hshm.try_pop_buffer_into(0, mem_size); - let err_msg = format!("{}", result.unwrap_err()); - assert!( - err_msg.contains("Corrupt buffer size prefix: flatbuffer claims 4294967295 bytes but the element slot is only 9 bytes"), - "Unexpected error message: {}", - err_msg - ); - } - - #[test] - fn malicious_element_offset_too_small() { - let mem_size = 4096; - let mut hshm = make_buffer(mem_size); - - let payload = b"test"; - let mut data = Vec::new(); - data.extend_from_slice(&(payload.len() as u32).to_le_bytes()); - data.extend_from_slice(payload); - hshm.push_buffer(0, mem_size, &data).unwrap(); - - // Corrupt back-pointer (offset 16) to 0 (before valid range). - hshm.write::(16, 0u64).unwrap(); - - let result: Result = hshm.try_pop_buffer_into(0, mem_size); - let err_msg = format!("{}", result.unwrap_err()); - assert!( - err_msg.contains( - "Corrupt buffer back-pointer: element offset 0 is outside valid range [8, 8]" - ), - "Unexpected error message: {}", - err_msg - ); - } - - #[test] - fn malicious_element_offset_past_stack_pointer() { - let mem_size = 4096; - let mut hshm = make_buffer(mem_size); - - let payload = b"test"; - let mut data = Vec::new(); - data.extend_from_slice(&(payload.len() as u32).to_le_bytes()); - data.extend_from_slice(payload); - hshm.push_buffer(0, mem_size, &data).unwrap(); - - // Corrupt back-pointer (offset 16) to 9999 (past stack pointer 24). 
- hshm.write::(16, 9999u64).unwrap(); - - let result: Result = hshm.try_pop_buffer_into(0, mem_size); - let err_msg = format!("{}", result.unwrap_err()); - assert!( - err_msg.contains( - "Corrupt buffer back-pointer: element offset 9999 is outside valid range [8, 8]" - ), - "Unexpected error message: {}", - err_msg - ); - } - - #[test] - fn malicious_flatbuffer_size_off_by_one() { - let mem_size = 4096; - let mut hshm = make_buffer(mem_size); - - let payload = b"abcd"; - let mut data = Vec::new(); - data.extend_from_slice(&(payload.len() as u32).to_le_bytes()); - data.extend_from_slice(payload); - hshm.push_buffer(0, mem_size, &data).unwrap(); - - // Corrupt size prefix: claim 5 bytes (total 9), exceeding the 8-byte slot. - hshm.write::(8, 5u32).unwrap(); // fb_buffer_size = 5 + 4 = 9 - - let result: Result = hshm.try_pop_buffer_into(0, mem_size); - let err_msg = format!("{}", result.unwrap_err()); - assert!( - err_msg.contains("Corrupt buffer size prefix: flatbuffer claims 9 bytes but the element slot is only 8 bytes"), - "Unexpected error message: {}", - err_msg - ); - } - - /// Back-pointer just below stack_pointer causes underflow in - /// `stack_pointer_rel - last_element_offset_rel - 8`. - #[test] - fn back_pointer_near_stack_pointer_underflow() { - let mem_size = 4096; - let mut hshm = make_buffer(mem_size); - - let payload = b"test"; - let mut data = Vec::new(); - data.extend_from_slice(&(payload.len() as u32).to_le_bytes()); - data.extend_from_slice(payload); - hshm.push_buffer(0, mem_size, &data).unwrap(); - - // stack_pointer_rel = 24. Set back-pointer to 23 (> 24 - 16 = 8, so rejected). 
- hshm.write::(16, 23u64).unwrap(); - - let result: Result = hshm.try_pop_buffer_into(0, mem_size); - let err_msg = format!("{}", result.unwrap_err()); - assert!( - err_msg.contains( - "Corrupt buffer back-pointer: element offset 23 is outside valid range [8, 8]" - ), - "Unexpected error message: {}", - err_msg - ); - } - - /// Size prefix of 0xFFFF_FFFD causes u32 overflow: 0xFFFF_FFFD + 4 wraps. - #[test] - fn size_prefix_u32_overflow() { - let mem_size = 4096; - let mut hshm = make_buffer(mem_size); - - let payload = b"test"; - let mut data = Vec::new(); - data.extend_from_slice(&(payload.len() as u32).to_le_bytes()); - data.extend_from_slice(payload); - hshm.push_buffer(0, mem_size, &data).unwrap(); - - // Write 0xFFFF_FFFD as size prefix: checked_add(4) returns None. - hshm.write::(8, 0xFFFF_FFFDu32).unwrap(); - - let result: Result = hshm.try_pop_buffer_into(0, mem_size); - let err_msg = format!("{}", result.unwrap_err()); - assert!( - err_msg.contains("Corrupt buffer size prefix: value 4294967293 overflows when adding 4-byte header"), - "Unexpected error message: {}", - err_msg - ); - } - } - #[cfg(target_os = "linux")] mod guard_page_crash_test { use crate::mem::shared_mem::{ExclusiveSharedMemory, SharedMemory}; diff --git a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs index 4a097358c..779ccd08a 100644 --- a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs +++ b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs @@ -782,8 +782,6 @@ impl MultiUseSandbox { // - any serialized host function call are zeroed out by us (the host) during deserialization, see `get_host_function_call` // - any serialized host function result is zeroed out by the guest during deserialization, see `get_host_return_value` if let Err(e) = &res { - self.mem_mgr.clear_io_buffers(); - // Determine if we should poison the sandbox. 
self.poisoned |= e.is_poison_error(); } diff --git a/src/hyperlight_host/src/sandbox/outb.rs b/src/hyperlight_host/src/sandbox/outb.rs index b5a20d31c..3fa571db2 100644 --- a/src/hyperlight_host/src/sandbox/outb.rs +++ b/src/hyperlight_host/src/sandbox/outb.rs @@ -64,32 +64,15 @@ pub enum HandleOutbError { MemProfile(String), } -#[allow(dead_code)] -#[instrument(err(Debug), skip_all, parent = Span::current(), level="Trace")] -pub(super) fn outb_log( - mgr: &mut SandboxMemoryManager, -) -> Result<(), HandleOutbError> { - let log_data: GuestLogData = mgr - .read_guest_log_data() - .map_err(|e| HandleOutbError::ReadLogData(e.to_string()))?; - - emit_guest_log(&log_data); - Ok(()) -} - /// Emit a guest log record from a virtqueue payload. /// /// Deserializes [`GuestLogData`] from the raw bytes and emits either -/// a tracing event or a log record, matching the original `outb_log` -/// behavior. -pub(crate) fn emit_guest_log_from_payload(payload: &[u8]) { +/// a tracing event or a log record. +pub(crate) fn emit_guest_log(payload: &[u8]) { let Ok(log_data) = GuestLogData::try_from(payload) else { return; }; - emit_guest_log(&log_data); -} -fn emit_guest_log(log_data: &GuestLogData) { // This code will create either a logging record or a tracing record // for the GuestLogData depending on if the host has set up a tracing // subscriber. @@ -233,7 +216,7 @@ fn outb_virtq_call( match hdr.msg_kind() { Ok(MsgKind::Log) => { let payload = &entry_data[VirtqMsgHeader::SIZE..]; - emit_guest_log_from_payload(payload); + emit_guest_log(payload); let _ = consumer.complete(completion); continue; } @@ -306,30 +289,9 @@ pub(crate) fn handle_outb( .try_into() .map_err(|e: anyhow::Error| HandleOutbError::InvalidPort(e.to_string()))? { - OutBAction::Log => { - // Legacy path - logs now arrive via G2H virtqueue - // and are processed inline by outb_virtq_call / - // read_h2g_result_from_g2h. 
- Ok(()) - } - OutBAction::CallFunction => { - let call = mem_mgr - .get_host_function_call() - .map_err(|e| HandleOutbError::ReadHostFunctionCall(e.to_string()))?; - let name = call.function_name.clone(); - let args: Vec = call.parameters.unwrap_or(vec![]); - let res = host_funcs - .try_lock() - .map_err(|e| HandleOutbError::LockFailed(file!(), line!(), e.to_string()))? - .call_host_function(&name, args) - .map_err(|e| GuestError::new(ErrorCode::HostFunctionError, e.to_string())); - - let func_result = FunctionCallResult::new(res); - - mem_mgr - .write_response_from_host_function_call(&func_result) - .map_err(|e| HandleOutbError::WriteHostFunctionResponse(e.to_string()))?; - + OutBAction::Log | OutBAction::CallFunction => { + // Legacy paths removed - these actions should no longer be + // emitted by the guest. Ignore gracefully. Ok(()) } OutBAction::Abort => outb_abort(mem_mgr, data), @@ -353,251 +315,3 @@ pub(crate) fn handle_outb( OutBAction::TraceMemoryFree => trace_info.handle_trace_mem_free(regs, mem_mgr), } } -#[cfg(test)] -mod tests { - use hyperlight_common::flatbuffer_wrappers::guest_log_level::LogLevel; - use hyperlight_testing::logger::{LOGGER, Logger}; - use hyperlight_testing::simple_guest_as_string; - use log::Level; - use tracing_core::callsite::rebuild_interest_cache; - - use super::outb_log; - use crate::GuestBinary; - use crate::mem::mgr::SandboxMemoryManager; - use crate::sandbox::SandboxConfiguration; - use crate::sandbox::outb::GuestLogData; - use crate::testing::log_values::test_value_as_str; - - fn new_guest_log_data(level: LogLevel) -> GuestLogData { - GuestLogData::new( - "test log".to_string(), - "test source".to_string(), - level, - "test caller".to_string(), - "test source file".to_string(), - 123, - ) - } - - #[test] - #[ignore] - fn test_log_outb_log() { - Logger::initialize_test_logger(); - LOGGER.set_max_level(log::LevelFilter::Off); - - let sandbox_cfg = SandboxConfiguration::default(); - - let new_mgr = || { - let bin = 
GuestBinary::FilePath(simple_guest_as_string().unwrap()); - let snapshot = crate::sandbox::snapshot::Snapshot::from_env(bin, sandbox_cfg).unwrap(); - let mgr = SandboxMemoryManager::from_snapshot(&snapshot).unwrap(); - let (hmgr, _) = mgr.build().unwrap(); - hmgr - }; - { - // We set a logger but there is no guest log data - // in memory, so expect a log operation to fail - let mut mgr = new_mgr(); - assert!(outb_log(&mut mgr).is_err()); - } - { - // Write a log message so outb_log will succeed. - // Since the logger level is set off, expect logs to be no-ops - let mut mgr = new_mgr(); - let log_msg = new_guest_log_data(LogLevel::Information); - - let guest_log_data_buffer: Vec = log_msg.try_into().unwrap(); - let offset = mgr.layout.get_output_data_buffer_scratch_host_offset(); - mgr.scratch_mem - .push_buffer( - offset, - sandbox_cfg.get_output_data_size(), - &guest_log_data_buffer, - ) - .unwrap(); - - let res = outb_log(&mut mgr); - assert!(res.is_ok()); - assert_eq!(0, LOGGER.num_log_calls()); - LOGGER.clear_log_calls(); - } - { - // now, test logging - LOGGER.set_max_level(log::LevelFilter::Trace); - let mut mgr = new_mgr(); - LOGGER.clear_log_calls(); - - // set up the logger and set the log level to the maximum - // possible (Trace) to ensure we're able to test all - // the possible branches of the match in outb_log - - let levels = vec![ - LogLevel::Trace, - LogLevel::Debug, - LogLevel::Information, - LogLevel::Warning, - LogLevel::Error, - LogLevel::Critical, - LogLevel::None, - ]; - for level in levels { - let layout = mgr.layout; - let log_data = new_guest_log_data(level); - - let guest_log_data_buffer: Vec = log_data.clone().try_into().unwrap(); - mgr.scratch_mem - .push_buffer( - layout.get_output_data_buffer_scratch_host_offset(), - sandbox_cfg.get_output_data_size(), - guest_log_data_buffer.as_slice(), - ) - .unwrap(); - - outb_log(&mut mgr).unwrap(); - - LOGGER.test_log_records(|log_calls| { - let expected_level: Level = (&level).into(); - - 
assert!( - log_calls - .iter() - .filter(|log_call| { - log_call.level == expected_level - && log_call.line == Some(log_data.line) - && log_call.args == log_data.message - && log_call.module_path == Some(log_data.source.clone()) - && log_call.file == Some(log_data.source_file.clone()) - }) - .count() - == 1, - "log call did not occur for level {:?}", - level.clone() - ); - }); - } - } - } - - // Tests that outb_log emits traces when a trace subscriber is set - // this test is ignored because it is incompatible with other tests , specifically those which require a logger for tracing - // marking this test as ignored means that running `cargo test` will not run this test but will allow a developer who runs that command - // from their workstation to be successful without needed to know about test interdependencies - // this test will be run explicitly as a part of the CI pipeline - #[ignore] - #[test] - fn test_trace_outb_log() { - Logger::initialize_log_tracer(); - rebuild_interest_cache(); - let subscriber = - hyperlight_testing::tracing_subscriber::TracingSubscriber::new(tracing::Level::TRACE); - let sandbox_cfg = SandboxConfiguration::default(); - tracing::subscriber::with_default(subscriber.clone(), || { - let new_mgr = || { - let bin = GuestBinary::FilePath(simple_guest_as_string().unwrap()); - let snapshot = - crate::sandbox::snapshot::Snapshot::from_env(bin, sandbox_cfg).unwrap(); - let mgr = SandboxMemoryManager::from_snapshot(&snapshot).unwrap(); - let (hmgr, _) = mgr.build().unwrap(); - hmgr - }; - - // as a span does not exist one will be automatically created - // after that there will be an event for each log message - // we are interested only in the events for the log messages that we created - - let levels = vec![ - LogLevel::Trace, - LogLevel::Debug, - LogLevel::Information, - LogLevel::Warning, - LogLevel::Error, - LogLevel::Critical, - LogLevel::None, - ]; - for level in levels { - let mut mgr = new_mgr(); - let layout = mgr.layout; - let 
log_data: GuestLogData = new_guest_log_data(level); - subscriber.clear(); - - let guest_log_data_buffer: Vec = log_data.try_into().unwrap(); - mgr.scratch_mem - .push_buffer( - layout.get_output_data_buffer_scratch_host_offset(), - sandbox_cfg.get_output_data_size(), - guest_log_data_buffer.as_slice(), - ) - .unwrap(); - subscriber.clear(); - outb_log(&mut mgr).unwrap(); - - subscriber.test_trace_records(|spans, events| { - let expected_level = match level { - LogLevel::Trace => "TRACE", - LogLevel::Debug => "DEBUG", - LogLevel::Information => "INFO", - LogLevel::Warning => "WARN", - LogLevel::Error => "ERROR", - LogLevel::Critical => "ERROR", - LogLevel::None => "TRACE", - }; - - // We cannot get the parent span using the `current_span()` method as by the time we get to this point that span has been exited so there is no current span - // We need to make sure that the span that we created is in the spans map instead - // We expect to have created 21 spans at this point. We are only interested in the first one that was created when calling outb_log. 
- - assert!( - spans.len() == 21, - "expected 21 spans, found {}", - spans.len() - ); - - let span_value = spans - .get(&1) - .unwrap() - .as_object() - .unwrap() - .get("span") - .unwrap() - .get("attributes") - .unwrap() - .as_object() - .unwrap() - .get("metadata") - .unwrap() - .as_object() - .unwrap(); - - //test_value_as_str(span_value, "level", "INFO"); - test_value_as_str(span_value, "module_path", "hyperlight_host::sandbox::outb"); - let expected_file = if cfg!(windows) { - "src\\hyperlight_host\\src\\sandbox\\outb.rs" - } else { - "src/hyperlight_host/src/sandbox/outb.rs" - }; - test_value_as_str(span_value, "file", expected_file); - test_value_as_str(span_value, "target", "hyperlight_host::sandbox::outb"); - - let mut count_matching_events = 0; - - for json_value in events { - let event_values = json_value.as_object().unwrap().get("event").unwrap(); - let metadata_values_map = - event_values.get("metadata").unwrap().as_object().unwrap(); - let event_values_map = event_values.as_object().unwrap(); - test_value_as_str(metadata_values_map, "level", expected_level); - test_value_as_str(event_values_map, "log.file", "test source file"); - test_value_as_str(event_values_map, "log.module_path", "test source"); - test_value_as_str(event_values_map, "log.target", "hyperlight_guest"); - count_matching_events += 1; - } - assert!( - count_matching_events == 1, - "trace log call did not occur for level {:?}", - level.clone() - ); - }); - } - }); - } -} diff --git a/src/hyperlight_host/src/testing/log_values.rs b/src/hyperlight_host/src/testing/log_values.rs deleted file mode 100644 index 47f40ae0a..000000000 --- a/src/hyperlight_host/src/testing/log_values.rs +++ /dev/null @@ -1,62 +0,0 @@ -/* -Copyright 2025 The Hyperlight Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -use serde_json::{Map, Value}; - -use crate::{Result, new_error}; - -/// Call `check_value_as_str` and panic if it returned an `Err`. Otherwise, -/// do nothing. -#[track_caller] -pub(crate) fn test_value_as_str(values: &Map, key: &str, expected_value: &str) { - if let Err(e) = check_value_as_str(values, key, expected_value) { - panic!("{e:?}"); - } -} - -/// Check to see if the value in `values` for key `key` matches -/// `expected_value`. If so, return `Ok(())`. Otherwise, return an `Err` -/// indicating the mismatch. -pub(crate) fn check_value_as_str( - values: &Map, - key: &str, - expected_value: &str, -) -> Result<()> { - let value = try_to_string(values, key)?; - if expected_value != value { - return Err(new_error!( - "expected value {} != value {}", - expected_value, - value - )); - } - Ok(()) -} - -/// Fetch the value in `values` with key `key` and, if it existed, convert -/// it to a string. If all those steps succeeded, return an `Ok` with the -/// string value inside. Otherwise, return an `Err`. 
-fn try_to_string<'a>(values: &'a Map, key: &'a str) -> Result<&'a str> { - if let Some(value) = values.get(key) { - if let Some(value_str) = value.as_str() { - Ok(value_str) - } else { - Err(new_error!("value with key {} was not a string", key)) - } - } else { - Err(new_error!("value for key {} was not found", key)) - } -} diff --git a/src/hyperlight_host/src/testing/mod.rs b/src/hyperlight_host/src/testing/mod.rs index 26776b405..503fda1ee 100644 --- a/src/hyperlight_host/src/testing/mod.rs +++ b/src/hyperlight_host/src/testing/mod.rs @@ -13,4 +13,3 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -pub(crate) mod log_values; diff --git a/src/hyperlight_host/tests/sandbox_host_tests.rs b/src/hyperlight_host/tests/sandbox_host_tests.rs index d6db20ddc..9a4b90865 100644 --- a/src/hyperlight_host/tests/sandbox_host_tests.rs +++ b/src/hyperlight_host/tests/sandbox_host_tests.rs @@ -554,3 +554,105 @@ fn virtq_backpressure_no_data_loss() { assert!((res - 1.234).abs() < f64::EPSILON); }); } + +#[test] +fn virtq_log_tracing_delivery() { + // Verify guest logs are emitted as tracing events when a tracing + // subscriber is active, matching the behavior of the old outb_log. + use hyperlight_testing::tracing_subscriber::TracingSubscriber; + + let subscriber = TracingSubscriber::new(tracing::Level::TRACE); + + tracing::subscriber::with_default(subscriber.clone(), || { + with_rust_uninit_sandbox(|mut sbox| { + sbox.set_max_guest_log_level(tracing_core::LevelFilter::INFO); + let mut sandbox = sbox.evolve().unwrap(); + + subscriber.clear(); + + sandbox + .call::<()>("LogMessage", ("tracing delivery test".to_string(), 3_i32)) + .unwrap(); + + // Guest log goes through format_trace which creates tracing + // events with log.target = "hyperlight_guest" as a field. 
+ let events = subscriber.get_events(); + assert!( + !events.is_empty(), + "expected tracing events after guest log call, got none" + ); + }); + }); +} + +#[test] +fn virtq_log_tracing_levels() { + // Verify each guest log level produces tracing events. + use hyperlight_testing::tracing_subscriber::TracingSubscriber; + + let subscriber = TracingSubscriber::new(tracing::Level::TRACE); + + tracing::subscriber::with_default(subscriber.clone(), || { + with_rust_uninit_sandbox(|mut sbox| { + sbox.set_max_guest_log_level(tracing_core::LevelFilter::TRACE); + let mut sandbox = sbox.evolve().unwrap(); + + // Test each level: 1=Trace, 2=Debug, 3=Info, 4=Warn, 5=Error + for level in [1_i32, 2, 3, 4, 5] { + subscriber.clear(); + let msg = format!("level-test-{}", level); + sandbox.call::<()>("LogMessage", (msg, level)).unwrap(); + + let events = subscriber.get_events(); + assert!( + !events.is_empty(), + "expected tracing events for guest log level {}", + level + ); + } + }); + }); +} + +#[test] +fn virtq_invalid_guest_function_returns_error() { + // Calling a non-existent guest function should return a proper + // GuestError, not corrupt data or a hang. This validates that + // the virtq error path (MsgKind::Response with GuestError payload) + // works end-to-end. + with_rust_sandbox_cfg(SandboxConfiguration::default(), |mut sandbox| { + let res = sandbox.call::<()>("ThisFunctionDoesNotExist", ()); + assert!(res.is_err(), "expected error for non-existent function"); + let err = res.unwrap_err(); + assert!( + matches!( + err, + HyperlightError::GuestError( + hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode::GuestFunctionNotFound, + _ + ) + ), + "expected GuestFunctionNotFound, got {:?}", + err + ); + }); +} + +#[test] +fn virtq_large_payload_roundtrip() { + // Verify that larger payloads survive the virtq roundtrip without corruption. 
+ with_rust_sandbox_cfg(SandboxConfiguration::default(), |mut sandbox| { + // 1KB string + let large_msg: String = "X".repeat(1024); + let res: String = sandbox.call("Echo", large_msg.clone()).unwrap(); + assert_eq!(res, large_msg); + + // 1KB byte array + let large_bytes = vec![0xABu8; 1024]; + let res: Vec = sandbox + .call("SetByteArrayToZero", large_bytes.clone()) + .unwrap(); + assert_eq!(res.len(), 1024); + assert!(res.iter().all(|&b| b == 0)); + }); +} diff --git a/src/tests/rust_guests/simpleguest/src/main.rs b/src/tests/rust_guests/simpleguest/src/main.rs index 024f7d1b4..b238be535 100644 --- a/src/tests/rust_guests/simpleguest/src/main.rs +++ b/src/tests/rust_guests/simpleguest/src/main.rs @@ -49,8 +49,7 @@ use hyperlight_guest_bin::exception::arch::{Context, ExceptionInfo}; use hyperlight_guest_bin::guest_function::definition::{GuestFunc, GuestFunctionDefinition}; use hyperlight_guest_bin::guest_function::register::register_function; use hyperlight_guest_bin::host_comm::{ - call_host_function, call_host_function_without_returning_result, get_host_return_value_raw, - print_output_with_host_print, read_n_bytes_from_user_memory, + call_host_function, print_output_with_host_print, read_n_bytes_from_user_memory, }; use hyperlight_guest_bin::memory::malloc; use hyperlight_guest_bin::{GUEST_HANDLE, guest_function, guest_logger, host_function}; @@ -1045,32 +1044,23 @@ fn fuzz_host_function(func: FunctionCall) -> Result> { } }; - // Because we do not know at compile time the actual return type of the host function to be called - // we cannot use the `call_host_function` generic function. 
- // We need to use the `call_host_function_without_returning_result` function that does not retrieve the return - // value - call_host_function_without_returning_result( - &host_func_name, - Some(params), - func.expected_return_type, - ) - .expect("failed to call host function"); - - let host_return = get_host_return_value_raw(); - match host_return { - Ok(return_value) => match return_value { - ReturnValue::Int(i) => Ok(get_flatbuffer_result(i)), - ReturnValue::UInt(i) => Ok(get_flatbuffer_result(i)), - ReturnValue::Long(i) => Ok(get_flatbuffer_result(i)), - ReturnValue::ULong(i) => Ok(get_flatbuffer_result(i)), - ReturnValue::Float(i) => Ok(get_flatbuffer_result(i)), - ReturnValue::Double(i) => Ok(get_flatbuffer_result(i)), - ReturnValue::String(str) => Ok(get_flatbuffer_result(str.as_str())), - ReturnValue::Bool(bool) => Ok(get_flatbuffer_result(bool)), - ReturnValue::Void(()) => Ok(get_flatbuffer_result(())), - ReturnValue::VecBytes(byte) => Ok(get_flatbuffer_result(byte.as_slice())), - }, - Err(e) => Err(e), + // Call the host function with dynamic return type. Since we don't + // know T at compile time, use ReturnValue as the return type and + // match on the result. 
+ let return_value: ReturnValue = + call_host_function(&host_func_name, Some(params), func.expected_return_type)?; + + match return_value { + ReturnValue::Int(i) => Ok(get_flatbuffer_result(i)), + ReturnValue::UInt(i) => Ok(get_flatbuffer_result(i)), + ReturnValue::Long(i) => Ok(get_flatbuffer_result(i)), + ReturnValue::ULong(i) => Ok(get_flatbuffer_result(i)), + ReturnValue::Float(i) => Ok(get_flatbuffer_result(i)), + ReturnValue::Double(i) => Ok(get_flatbuffer_result(i)), + ReturnValue::String(str) => Ok(get_flatbuffer_result(str.as_str())), + ReturnValue::Bool(bool) => Ok(get_flatbuffer_result(bool)), + ReturnValue::Void(()) => Ok(get_flatbuffer_result(())), + ReturnValue::VecBytes(byte) => Ok(get_flatbuffer_result(byte.as_slice())), } } From 395dd12c8ea6c844cbc098642eabb8c1ad84819a Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Tue, 7 Apr 2026 19:13:58 +0200 Subject: [PATCH 12/31] feat(virtq): remove input output regions from ABI Signed-off-by: Tomasz Andrzejak --- fuzz/fuzz_targets/host_call.rs | 6 +- .../src/arch/aarch64/layout.rs | 7 +- .../src/arch/amd64/layout.rs | 12 +- src/hyperlight_common/src/arch/i686/layout.rs | 7 +- src/hyperlight_common/src/layout.rs | 21 +-- src/hyperlight_common/src/mem.rs | 2 - src/hyperlight_guest_capi/src/dispatch.rs | 20 +- src/hyperlight_host/benches/benchmarks.rs | 6 +- .../src/hypervisor/hyperlight_vm/x86_64.rs | 11 +- src/hyperlight_host/src/mem/layout.rs | 128 +------------ src/hyperlight_host/src/mem/mgr.rs | 11 -- src/hyperlight_host/src/sandbox/config.rs | 173 ++++++------------ .../src/sandbox/initialized_multi_use.rs | 8 +- .../src/sandbox/uninitialized.rs | 12 +- src/hyperlight_host/tests/integration_test.rs | 36 ---- .../tests/sandbox_host_tests.rs | 5 +- src/tests/rust_guests/simpleguest/src/main.rs | 49 +---- 17 files changed, 105 insertions(+), 409 deletions(-) diff --git a/fuzz/fuzz_targets/host_call.rs b/fuzz/fuzz_targets/host_call.rs index b0d37cf1a..0559b2bd6 100644 --- 
a/fuzz/fuzz_targets/host_call.rs +++ b/fuzz/fuzz_targets/host_call.rs @@ -33,9 +33,9 @@ static SANDBOX: OnceLock> = OnceLock::new(); fuzz_target!( init: { let mut cfg = SandboxConfiguration::default(); - cfg.set_output_data_size(64 * 1024); // 64 KB output buffer - cfg.set_input_data_size(64 * 1024); // 64 KB input buffer - cfg.set_scratch_size(512 * 1024); // large scratch region to contain those buffers, any data copies, etc. + cfg.set_g2h_pool_pages(16); // 64 KB / 4096 = 16 pages + cfg.set_h2g_pool_pages(16); // 64 KB / 4096 = 16 pages + cfg.set_scratch_size(512 * 1024); // large scratch region let u_sbox = UninitializedSandbox::new( GuestBinary::FilePath(simple_guest_for_fuzzing_as_string().expect("Guest Binary Missing")), Some(cfg) diff --git a/src/hyperlight_common/src/arch/aarch64/layout.rs b/src/hyperlight_common/src/arch/aarch64/layout.rs index 25bd99a1e..9f9c504a6 100644 --- a/src/hyperlight_common/src/arch/aarch64/layout.rs +++ b/src/hyperlight_common/src/arch/aarch64/layout.rs @@ -20,11 +20,6 @@ pub const SNAPSHOT_PT_GVA_MIN: usize = 0xffff_8000_0000_0000; pub const SNAPSHOT_PT_GVA_MAX: usize = 0xffff_80ff_ffff_ffff; pub const MAX_GPA: usize = 0x0000_000f_ffff_ffff; -pub fn min_scratch_size( - _input_data_size: usize, - _output_data_size: usize, - _g2h_num_descs: usize, - _h2g_num_descs: usize, -) -> usize { +pub fn min_scratch_size(_g2h_num_descs: usize, _h2g_num_descs: usize) -> usize { unimplemented!("min_scratch_size") } diff --git a/src/hyperlight_common/src/arch/amd64/layout.rs b/src/hyperlight_common/src/arch/amd64/layout.rs index 4731f21b2..12644de6c 100644 --- a/src/hyperlight_common/src/arch/amd64/layout.rs +++ b/src/hyperlight_common/src/arch/amd64/layout.rs @@ -37,17 +37,11 @@ pub const MAX_GPA: usize = 0x0000_000f_ffff_ffff; /// - A page for the smallest possible non-exception stack /// - (up to) 3 pages for mapping that /// - Two pages for the exception stack and metadata -/// - A page-aligned amount of memory for I/O buffers and 
virtqueue rings -pub fn min_scratch_size( - input_data_size: usize, - output_data_size: usize, - g2h_num_descs: usize, - h2g_num_descs: usize, -) -> usize { +/// - A page-aligned amount of memory for virtqueue rings +pub fn min_scratch_size(g2h_num_descs: usize, h2g_num_descs: usize) -> usize { let g2h_ring_size = crate::virtq::Layout::query_size(g2h_num_descs); let h2g_ring_size = crate::virtq::Layout::query_size(h2g_num_descs); - (input_data_size + output_data_size + g2h_ring_size + h2g_ring_size) - .next_multiple_of(crate::vmem::PAGE_SIZE) + (g2h_ring_size + h2g_ring_size).next_multiple_of(crate::vmem::PAGE_SIZE) + 12 * crate::vmem::PAGE_SIZE } diff --git a/src/hyperlight_common/src/arch/i686/layout.rs b/src/hyperlight_common/src/arch/i686/layout.rs index ff93f1d09..54a651cc2 100644 --- a/src/hyperlight_common/src/arch/i686/layout.rs +++ b/src/hyperlight_common/src/arch/i686/layout.rs @@ -21,11 +21,6 @@ pub const MAX_GVA: usize = 0xffff_ffff; /// regions are large enough to reach that address. pub const MAX_GPA: usize = 0xFEDF_FFFF; -pub fn min_scratch_size( - _input_data_size: usize, - _output_data_size: usize, - _g2h_num_descs: usize, - _h2g_num_descs: usize, -) -> usize { +pub fn min_scratch_size(_g2h_num_descs: usize, _h2g_num_descs: usize) -> usize { crate::vmem::PAGE_SIZE } diff --git a/src/hyperlight_common/src/layout.rs b/src/hyperlight_common/src/layout.rs index 5bb8a2cb5..7448b7d0d 100644 --- a/src/hyperlight_common/src/layout.rs +++ b/src/hyperlight_common/src/layout.rs @@ -83,27 +83,20 @@ pub const fn scratch_top_ptr(offset: u64) -> *mut T { /// Compute the byte offset from the scratch base to the G2H ring. /// -/// TODO(virtq): Remove input/output -pub const fn g2h_ring_scratch_offset(input_data_size: usize, output_data_size: usize) -> usize { - let io_off = input_data_size + output_data_size; - let align = crate::virtq::Descriptor::ALIGN; - - (io_off + align - 1) & !(align - 1) +/// The G2H ring starts at offset 0, aligned to descriptor alignment. 
+pub const fn g2h_ring_scratch_offset() -> usize { + 0 } /// Compute the byte offset from the scratch base to the H2G ring. /// -/// TODO(ring): Remove input/output -pub const fn h2g_ring_scratch_offset( - input_data_size: usize, - output_data_size: usize, - g2h_num_descs: usize, -) -> usize { - let g2h_offset = g2h_ring_scratch_offset(input_data_size, output_data_size); +/// The H2G ring follows immediately after the G2H ring, aligned to +/// descriptor alignment. +pub const fn h2g_ring_scratch_offset(g2h_num_descs: usize) -> usize { let g2h_size = crate::virtq::Layout::query_size(g2h_num_descs); let align = crate::virtq::Descriptor::ALIGN; - (g2h_offset + g2h_size + align - 1) & !(align - 1) + (g2h_size + align - 1) & !(align - 1) } /// Compute the minimum scratch region size needed for a sandbox. diff --git a/src/hyperlight_common/src/mem.rs b/src/hyperlight_common/src/mem.rs index fb850acc8..1cdd65cef 100644 --- a/src/hyperlight_common/src/mem.rs +++ b/src/hyperlight_common/src/mem.rs @@ -68,8 +68,6 @@ impl Default for FileMappingInfo { #[derive(Debug, Clone, Copy)] #[repr(C)] pub struct HyperlightPEB { - pub input_stack: GuestMemoryRegion, - pub output_stack: GuestMemoryRegion, pub init_data: GuestMemoryRegion, pub guest_heap: GuestMemoryRegion, /// File mappings array descriptor. 
diff --git a/src/hyperlight_guest_capi/src/dispatch.rs b/src/hyperlight_guest_capi/src/dispatch.rs index 245eaf700..4fd61df44 100644 --- a/src/hyperlight_guest_capi/src/dispatch.rs +++ b/src/hyperlight_guest_capi/src/dispatch.rs @@ -109,10 +109,22 @@ pub extern "C" fn hl_call_host_function(function_call: &FfiFunctionCall) { let return_type = unsafe { function_call.copy_return_type() }; virtq::with_context(|ctx| { - let result: ReturnValue = ctx - .call_host_function(&func_name, Some(parameters), return_type) - .expect("Failed to call host function"); - ctx.stash_host_return(result); + match ctx.call_host_function::(&func_name, Some(parameters), return_type) { + Ok(result) => ctx.stash_host_return(result), + Err(e) => { + // Host function returned an error. Abort with the error + // message so the host can capture it via the abort buffer. + let msg = alloc::ffi::CString::new(e.message) + .unwrap_or_else(|_| alloc::ffi::CString::new("host error").unwrap()); + + unsafe { + hyperlight_guest::exit::abort_with_code_and_message( + &[e.kind as u8], + msg.as_ptr(), + ); + } + } + } }); } diff --git a/src/hyperlight_host/benches/benchmarks.rs b/src/hyperlight_host/benches/benchmarks.rs index 462e8908d..61e925116 100644 --- a/src/hyperlight_host/benches/benchmarks.rs +++ b/src/hyperlight_host/benches/benchmarks.rs @@ -380,9 +380,9 @@ fn guest_call_benchmark_large_param(c: &mut Criterion) { let large_string = String::from_utf8(large_vec.clone()).unwrap(); let mut config = SandboxConfiguration::default(); - config.set_input_data_size(2 * SIZE + (1024 * 1024)); // 2 * SIZE + 1 MB, to allow 1MB for the rest of the serialized function call + config.set_h2g_pool_pages((2 * SIZE + (1024 * 1024)) / 4096); // pool pages for the large input config.set_heap_size(SIZE as u64 * 15); - config.set_scratch_size(6 * SIZE + 4 * (1024 * 1024)); // Big enough for the IO data regions and enough of the heap to be used + config.set_scratch_size(6 * SIZE + 4 * (1024 * 1024)); // Big enough for any 
data copies, etc. let sandbox = UninitializedSandbox::new( GuestBinary::FilePath(simple_guest_as_string().unwrap()), @@ -465,7 +465,7 @@ fn sample_workloads_benchmark(c: &mut Criterion) { fn bench_24k_in_8k_out(b: &mut criterion::Bencher, guest_path: String) { let mut cfg = SandboxConfiguration::default(); - cfg.set_input_data_size(25 * 1024); + cfg.set_h2g_pool_pages(7); // 25 * 1024 / 4096 ~= 7 pages let mut sandbox = UninitializedSandbox::new(GuestBinary::FilePath(guest_path), Some(cfg)) .unwrap() diff --git a/src/hyperlight_host/src/hypervisor/hyperlight_vm/x86_64.rs b/src/hyperlight_host/src/hypervisor/hyperlight_vm/x86_64.rs index f06c94964..39bb1b56f 100644 --- a/src/hyperlight_host/src/hypervisor/hyperlight_vm/x86_64.rs +++ b/src/hyperlight_host/src/hypervisor/hyperlight_vm/x86_64.rs @@ -2110,17 +2110,18 @@ mod tests { } /// Creates VM with guest code that: dirtys FPU (if flag==0), does FXSAVE to buffer, sets flag=1. - /// Uses output_data region for FXSAVE buffer (like regular guest output), scratch for stack. + /// Uses scratch region after rings for FXSAVE buffer. fn hyperlight_vm_with_mem_mgr_fxsave() -> FxsaveTestContext { use iced_x86::code_asm::*; // Compute fixed addresses for FXSAVE buffer and flag. - // These are in the output_data region which starts at a known offset. - // We use a default SandboxConfiguration to get the same layout as create_test_vm_context. + // We use the page-table area in scratch after rings as a + // convenient 512-byte aligned buffer for FXSAVE. 
let config: SandboxConfiguration = Default::default(); let layout = SandboxMemoryLayout::new(config, 512, 4096, None).unwrap(); - let fxsave_offset = layout.get_output_data_buffer_scratch_host_offset(); - let fxsave_gva = layout.get_output_data_buffer_gva(); + let fxsave_offset = layout.get_pt_base_scratch_offset(); + let fxsave_gva = hyperlight_common::layout::scratch_base_gva(config.get_scratch_size()) + + fxsave_offset as u64; let flag_gva = fxsave_gva + 512; let mut a = CodeAssembler::new(64).unwrap(); diff --git a/src/hyperlight_host/src/mem/layout.rs b/src/hyperlight_host/src/mem/layout.rs index 2b821a965..52bc1af70 100644 --- a/src/hyperlight_host/src/mem/layout.rs +++ b/src/hyperlight_host/src/mem/layout.rs @@ -47,17 +47,12 @@ limitations under the License. //! //! There is also a scratch region at the top of physical memory, //! which is mostly laid out as a large undifferentiated blob of -//! memory, although at present the snapshot process specially -//! privileges the statically allocated input and output data regions: +//! memory: //! //! +-------------------------------------------+ (top of physical memory) //! | Exception Stack, Metadata | //! +-------------------------------------------+ (1 page below) //! | Scratch Memory | -//! +-------------------------------------------+ -//! | Output Data | -//! +-------------------------------------------+ -//! | Input Data | //! +-------------------------------------------+ (scratch size) use std::fmt::Debug; @@ -223,8 +218,6 @@ pub(crate) struct SandboxMemoryLayout { /// The following fields are offsets to the actual PEB struct fields. 
/// They are used when writing the PEB struct itself peb_offset: usize, - peb_input_data_offset: usize, - peb_output_data_offset: usize, peb_init_data_offset: usize, peb_heap_data_offset: usize, #[cfg(feature = "nanvix-unstable")] @@ -267,14 +260,6 @@ impl Debug for SandboxMemoryLayout { .field("PEB Address", &format_args!("{:#x}", self.peb_address)) .field("PEB Offset", &format_args!("{:#x}", self.peb_offset)) .field("Code Size", &format_args!("{:#x}", self.code_size)) - .field( - "Input Data Offset", - &format_args!("{:#x}", self.peb_input_data_offset), - ) - .field( - "Output Data Offset", - &format_args!("{:#x}", self.peb_output_data_offset), - ) .field( "Init Data Offset", &format_args!("{:#x}", self.peb_init_data_offset), @@ -320,9 +305,6 @@ impl SandboxMemoryLayout { /// The base address of the sandbox's memory. pub(crate) const BASE_ADDRESS: usize = 0x1000; - // the offset into a sandbox's input/output buffer where the stack starts - pub(crate) const STACK_POINTER_SIZE_BYTES: u64 = 8; - /// Create a new `SandboxMemoryLayout` with the given /// `SandboxConfiguration`, code size and stack/heap size. #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] @@ -338,8 +320,6 @@ impl SandboxMemoryLayout { return Err(MemoryRequestTooBig(scratch_size, Self::MAX_MEMORY_SIZE)); } let min_scratch_size = hyperlight_common::layout::min_scratch_size( - cfg.get_input_data_size(), - cfg.get_output_data_size(), cfg.get_g2h_queue_depth(), cfg.get_h2g_queue_depth(), ); @@ -350,8 +330,6 @@ impl SandboxMemoryLayout { let guest_code_offset = 0; // The following offsets are to the fields of the PEB struct itself! 
let peb_offset = code_size.next_multiple_of(PAGE_SIZE_USIZE); - let peb_input_data_offset = peb_offset + offset_of!(HyperlightPEB, input_stack); - let peb_output_data_offset = peb_offset + offset_of!(HyperlightPEB, output_stack); let peb_init_data_offset = peb_offset + offset_of!(HyperlightPEB, init_data); let peb_heap_data_offset = peb_offset + offset_of!(HyperlightPEB, guest_heap); #[cfg(feature = "nanvix-unstable")] @@ -386,8 +364,6 @@ impl SandboxMemoryLayout { let mut ret = Self { peb_offset, heap_size, - peb_input_data_offset, - peb_output_data_offset, peb_init_data_offset, peb_heap_data_offset, #[cfg(feature = "nanvix-unstable")] @@ -408,13 +384,6 @@ impl SandboxMemoryLayout { Ok(ret) } - /// Get the offset in guest memory to the output data size - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - pub(super) fn get_output_data_size_offset(&self) -> usize { - // The size field is the first field in the `OutputData` struct - self.peb_output_data_offset - } - /// Get the offset in guest memory to the init data size #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub(super) fn get_init_data_size_offset(&self) -> usize { @@ -427,14 +396,6 @@ impl SandboxMemoryLayout { self.scratch_size } - /// Get the offset in guest memory to the output data pointer. - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_output_data_pointer_offset(&self) -> usize { - // This field is immediately after the output data size field, - // which is a `u64`. - self.get_output_data_size_offset() + size_of::() - } - /// Get the offset in guest memory to the init data pointer. #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub(super) fn get_init_data_pointer_offset(&self) -> usize { @@ -443,54 +404,9 @@ impl SandboxMemoryLayout { self.get_init_data_size_offset() + size_of::() } - /// Get the guest virtual address of the start of output data. 
- #[instrument(skip_all, parent = Span::current(), level= "Trace")] - pub(crate) fn get_output_data_buffer_gva(&self) -> u64 { - hyperlight_common::layout::scratch_base_gva(self.scratch_size) - + self.sandbox_memory_config.get_input_data_size() as u64 - } - - /// Get the offset into the host scratch buffer of the start of - /// the output data. - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - pub(crate) fn get_output_data_buffer_scratch_host_offset(&self) -> usize { - self.sandbox_memory_config.get_input_data_size() - } - - /// Get the offset in guest memory to the input data size. - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - pub(super) fn get_input_data_size_offset(&self) -> usize { - // The input data size is the first field in the input stack's `GuestMemoryRegion` struct - self.peb_input_data_offset - } - - /// Get the offset in guest memory to the input data pointer. - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_input_data_pointer_offset(&self) -> usize { - // The input data pointer is immediately after the input - // data size field in the input data `GuestMemoryRegion` struct which is a `u64`. - self.get_input_data_size_offset() + size_of::() - } - - /// Get the guest virtual address of the start of input data - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_input_data_buffer_gva(&self) -> u64 { - hyperlight_common::layout::scratch_base_gva(self.scratch_size) - } - - /// Get the offset into the host scratch buffer of the start of - /// the input data - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - pub(crate) fn get_input_data_buffer_scratch_host_offset(&self) -> usize { - 0 - } - /// Get the offset into the scratch region of the G2H ring. 
fn get_g2h_ring_scratch_offset(&self) -> usize { - hyperlight_common::layout::g2h_ring_scratch_offset( - self.sandbox_memory_config.get_input_data_size(), - self.sandbox_memory_config.get_output_data_size(), - ) + hyperlight_common::layout::g2h_ring_scratch_offset() } /// Get the size of the G2H ring in bytes. @@ -504,8 +420,6 @@ impl SandboxMemoryLayout { /// Get the offset into the scratch region of the H2G ring. fn get_h2g_ring_scratch_offset(&self) -> usize { hyperlight_common::layout::h2g_ring_scratch_offset( - self.sandbox_memory_config.get_input_data_size(), - self.sandbox_memory_config.get_output_data_size(), self.sandbox_memory_config.get_g2h_queue_depth(), ) } @@ -637,8 +551,6 @@ impl SandboxMemoryLayout { #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub(crate) fn set_pt_size(&mut self, size: usize) -> Result<()> { let min_fixed_scratch = hyperlight_common::layout::min_scratch_size( - self.sandbox_memory_config.get_input_data_size(), - self.sandbox_memory_config.get_output_data_size(), self.sandbox_memory_config.get_g2h_queue_depth(), self.sandbox_memory_config.get_h2g_queue_depth(), ); @@ -800,34 +712,6 @@ impl SandboxMemoryLayout { // Start of setting up the PEB. 
The following are in the order of the PEB fields - // Set up input buffer pointer - write_u64( - mem, - self.get_input_data_size_offset(), - self.sandbox_memory_config - .get_input_data_size() - .try_into()?, - )?; - write_u64( - mem, - self.get_input_data_pointer_offset(), - self.get_input_data_buffer_gva(), - )?; - - // Set up output buffer pointer - write_u64( - mem, - self.get_output_data_size_offset(), - self.sandbox_memory_config - .get_output_data_size() - .try_into()?, - )?; - write_u64( - mem, - self.get_output_data_pointer_offset(), - self.get_output_data_buffer_gva(), - )?; - // Set up init data pointer write_u64( mem, @@ -859,12 +743,7 @@ impl SandboxMemoryLayout { // End of setting up the PEB - // The input and output data regions do not have their layout - // initialised here, because they are in the scratch - // region---they are instead set in - // [`SandboxMemoryManager::update_scratch_bookkeeping`]. - // - // Virtqueue ring layouts are also communicated via scratch-top + // Virtqueue ring layouts are communicated via scratch-top // metadata (queue depths), not the PEB. Both host and guest // compute ring addresses from shared offset functions. @@ -944,7 +823,6 @@ mod tests { let mut cfg = SandboxConfiguration::default(); // scratch_size exceeds 16 GiB limit cfg.set_scratch_size(17 * 1024 * 1024 * 1024); - cfg.set_input_data_size(16 * 1024 * 1024 * 1024); let layout = SandboxMemoryLayout::new(cfg, 4096, 4096, None); assert!(matches!(layout.unwrap_err(), MemoryRequestTooBig(..))); } diff --git a/src/hyperlight_host/src/mem/mgr.rs b/src/hyperlight_host/src/mem/mgr.rs index f869b31be..a26aef4d4 100644 --- a/src/hyperlight_host/src/mem/mgr.rs +++ b/src/hyperlight_host/src/mem/mgr.rs @@ -568,17 +568,6 @@ impl SandboxMemoryManager { self.snapshot_count, )?; - // Initialise the guest input and output data buffers in - // scratch memory. TODO: remove the need for this. 
- self.scratch_mem.write::( - self.layout.get_input_data_buffer_scratch_host_offset(), - SandboxMemoryLayout::STACK_POINTER_SIZE_BYTES, - )?; - self.scratch_mem.write::( - self.layout.get_output_data_buffer_scratch_host_offset(), - SandboxMemoryLayout::STACK_POINTER_SIZE_BYTES, - )?; - // Write virtqueue metadata to scratch-top so the guest can // discover ring locations without reading the PEB. self.update_scratch_bookkeeping_item( diff --git a/src/hyperlight_host/src/sandbox/config.rs b/src/hyperlight_host/src/sandbox/config.rs index b3e5fd6d3..da068b2cd 100644 --- a/src/hyperlight_host/src/sandbox/config.rs +++ b/src/hyperlight_host/src/sandbox/config.rs @@ -14,9 +14,9 @@ See the License for the specific language governing permissions and limitations under the License. */ -use std::cmp::max; use std::time::Duration; +use hyperlight_common::mem::PAGE_SIZE_USIZE; #[cfg(target_os = "linux")] use libc::c_int; use tracing::{Span, instrument}; @@ -44,12 +44,6 @@ pub struct SandboxConfiguration { /// Guest gdb debug port #[cfg(gdb)] guest_debug_info: Option, - /// The size of the memory buffer that is made available for input to the - /// Guest Binary - input_data_size: usize, - /// The size of the memory buffer that is made available for input to the - /// Guest Binary - output_data_size: usize, /// The heap size to use in the guest sandbox. If set to 0, the heap /// size will be determined from the PE file header /// @@ -74,31 +68,29 @@ pub struct SandboxConfiguration { interrupt_vcpu_sigrtmin_offset: u8, /// How much writable memory to offer the guest scratch_size: usize, - /// Number of descriptors for the G2H (guest-to-host) virtqueue. Must be a power of 2. + /// Number of descriptors for the guest-to-host virtqueue. Must be a power of 2. /// Default: 64 sized to 2x H2G depth for deadlock prevention. g2h_queue_depth: usize, /// Number of descriptors for the host-to-guest virtqueue. Must be a power of 2. 
/// Default: 32 h2g_queue_depth: usize, /// Number of physical pages for the G2H (guest-to-host) buffer pool. - /// If not set, derived from `input_data_size` for backward compatibility. - /// Default: 8 pages (32KB). + /// When None, falls back to deprecated `output_data_size` or default. g2h_pool_pages: Option, /// Number of physical pages for the H2G (host-to-guest) buffer pool. - /// If not set, derived from `output_data_size` for backward compatibility. - /// Default: 4 page (16KB). + /// When None, falls back to deprecated `input_data_size` or default. h2g_pool_pages: Option, + /// Deprecated: use `g2h_pool_pages` instead. + /// When set (non-zero), translates to `g2h_pool_pages` if pool pages + /// are not explicitly configured. + output_data_size: usize, + /// Deprecated: use `h2g_pool_pages` instead. + /// When set (non-zero), translates to `h2g_pool_pages` if pool pages + /// are not explicitly configured. + input_data_size: usize, } impl SandboxConfiguration { - /// The default size of input data - pub const DEFAULT_INPUT_SIZE: usize = 0x4000; - /// The minimum size of input data - pub const MIN_INPUT_SIZE: usize = 0x2000; - /// The default size of output data - pub const DEFAULT_OUTPUT_SIZE: usize = 0x4000; - /// The minimum size of output data - pub const MIN_OUTPUT_SIZE: usize = 0x2000; /// The default interrupt retry delay pub const DEFAULT_INTERRUPT_RETRY_DELAY: Duration = Duration::from_micros(500); /// The default signal offset from `SIGRTMIN` used to determine the signal number for interrupting @@ -120,8 +112,6 @@ impl SandboxConfiguration { /// Create a new configuration for a sandbox with the given sizes. 
#[instrument(skip_all, parent = Span::current(), level= "Trace")] fn new( - input_data_size: usize, - output_data_size: usize, heap_size_override: Option, scratch_size: usize, interrupt_retry_delay: Duration, @@ -130,8 +120,6 @@ impl SandboxConfiguration { #[cfg(crashdump)] guest_core_dump: bool, ) -> Self { Self { - input_data_size: max(input_data_size, Self::MIN_INPUT_SIZE), - output_data_size: max(output_data_size, Self::MIN_OUTPUT_SIZE), heap_size_override: heap_size_override.unwrap_or(0), scratch_size, interrupt_retry_delay, @@ -140,6 +128,8 @@ impl SandboxConfiguration { h2g_queue_depth: Self::DEFAULT_H2G_QUEUE_DEPTH, g2h_pool_pages: None, h2g_pool_pages: None, + output_data_size: 0, + input_data_size: 0, #[cfg(gdb)] guest_debug_info, #[cfg(crashdump)] @@ -147,26 +137,6 @@ impl SandboxConfiguration { } } - /// Set the size of the legacy input data buffer (host-to-guest). - /// - /// Deprecated: use [`set_h2g_pool_pages`](Self::set_h2g_pool_pages) instead. - /// When `h2g_pool_pages` is not set, the H2G pool size is derived - /// from this value for backward compatibility. - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - pub fn set_input_data_size(&mut self, input_data_size: usize) { - self.input_data_size = max(input_data_size, Self::MIN_INPUT_SIZE); - } - - /// Set the size of the legacy output data buffer (guest-to-host). - /// - /// Deprecated: use [`set_g2h_pool_pages`](Self::set_g2h_pool_pages) instead. - /// When `g2h_pool_pages` is not set, the G2H pool size is derived - /// from this value for backward compatibility. - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - pub fn set_output_data_size(&mut self, output_data_size: usize) { - self.output_data_size = max(output_data_size, Self::MIN_OUTPUT_SIZE); - } - /// Set the heap size to use in the guest sandbox. 
If set to 0, the heap size will be determined from the PE file header #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub fn set_heap_size(&mut self, heap_size: u64) { @@ -226,16 +196,6 @@ impl SandboxConfiguration { self.guest_debug_info = Some(debug_info); } - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - pub(crate) fn get_input_data_size(&self) -> usize { - self.input_data_size - } - - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - pub(crate) fn get_output_data_size(&self) -> usize { - self.output_data_size - } - #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub(crate) fn get_scratch_size(&self) -> usize { self.scratch_size @@ -266,28 +226,36 @@ impl SandboxConfiguration { } /// Get the number of G2H buffer pool pages. - /// Falls back to deriving from `output_data_size` if not explicitly set - /// (output = guest-to-host direction). + /// + /// Priority: explicit `g2h_pool_pages` > derived from deprecated + /// `output_data_size` > default. #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub fn get_g2h_pool_pages(&self) -> usize { self.g2h_pool_pages.unwrap_or_else(|| { - let pages = self - .output_data_size - .div_ceil(hyperlight_common::mem::PAGE_SIZE_USIZE); - pages.max(Self::DEFAULT_G2H_POOL_PAGES) + if self.output_data_size > 0 { + self.output_data_size + .div_ceil(PAGE_SIZE_USIZE) + .max(Self::DEFAULT_G2H_POOL_PAGES) + } else { + Self::DEFAULT_G2H_POOL_PAGES + } }) } /// Get the number of H2G buffer pool pages. - /// Falls back to deriving from `input_data_size` if not explicitly set - /// (input = host-to-guest direction). + /// + /// Priority: explicit `h2g_pool_pages` > derived from deprecated + /// `input_data_size` > default. 
#[instrument(skip_all, parent = Span::current(), level= "Trace")] pub fn get_h2g_pool_pages(&self) -> usize { self.h2g_pool_pages.unwrap_or_else(|| { - let pages = self - .input_data_size - .div_ceil(hyperlight_common::mem::PAGE_SIZE_USIZE); - pages.max(Self::DEFAULT_H2G_POOL_PAGES) + if self.input_data_size > 0 { + self.input_data_size + .div_ceil(PAGE_SIZE_USIZE) + .max(Self::DEFAULT_H2G_POOL_PAGES) + } else { + Self::DEFAULT_H2G_POOL_PAGES + } }) } @@ -303,6 +271,24 @@ impl SandboxConfiguration { self.h2g_pool_pages = Some(pages); } + /// Deprecated: use [`set_g2h_pool_pages`](Self::set_g2h_pool_pages). + /// + /// Sets the output data size. If `g2h_pool_pages` is not explicitly + /// set, this value is translated to pool pages. + #[deprecated(note = "use set_g2h_pool_pages instead")] + pub fn set_output_data_size(&mut self, size: usize) { + self.output_data_size = size; + } + + /// Deprecated: use [`set_h2g_pool_pages`](Self::set_h2g_pool_pages). + /// + /// Sets the input data size. If `h2g_pool_pages` is not explicitly + /// set, this value is translated to pool pages. 
+ #[deprecated(note = "use set_h2g_pool_pages instead")] + pub fn set_input_data_size(&mut self, size: usize) { + self.input_data_size = size; + } + /// Set the size of the scratch regiong #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub fn set_scratch_size(&mut self, scratch_size: usize) { @@ -339,8 +325,6 @@ impl Default for SandboxConfiguration { #[instrument(skip_all, parent = Span::current(), level= "Trace")] fn default() -> Self { Self::new( - Self::DEFAULT_INPUT_SIZE, - Self::DEFAULT_OUTPUT_SIZE, None, Self::DEFAULT_SCRATCH_SIZE, Self::DEFAULT_INTERRUPT_RETRY_DELAY, @@ -360,12 +344,8 @@ mod tests { #[test] fn overrides() { const HEAP_SIZE_OVERRIDE: u64 = 0x50000; - const INPUT_DATA_SIZE_OVERRIDE: usize = 0x4000; - const OUTPUT_DATA_SIZE_OVERRIDE: usize = 0x4001; const SCRATCH_SIZE_OVERRIDE: usize = 0x60000; - let mut cfg = SandboxConfiguration::new( - INPUT_DATA_SIZE_OVERRIDE, - OUTPUT_DATA_SIZE_OVERRIDE, + let cfg = SandboxConfiguration::new( Some(HEAP_SIZE_OVERRIDE), SCRATCH_SIZE_OVERRIDE, SandboxConfiguration::DEFAULT_INTERRUPT_RETRY_DELAY, @@ -380,38 +360,6 @@ mod tests { let scratch_size = cfg.get_scratch_size(); assert_eq!(HEAP_SIZE_OVERRIDE, heap_size); assert_eq!(SCRATCH_SIZE_OVERRIDE, scratch_size); - - cfg.heap_size_override = 2048; - cfg.scratch_size = 0x40000; - assert_eq!(2048, cfg.heap_size_override); - assert_eq!(0x40000, cfg.scratch_size); - assert_eq!(INPUT_DATA_SIZE_OVERRIDE, cfg.input_data_size); - assert_eq!(OUTPUT_DATA_SIZE_OVERRIDE, cfg.output_data_size); - } - - #[test] - fn min_sizes() { - let mut cfg = SandboxConfiguration::new( - SandboxConfiguration::MIN_INPUT_SIZE - 1, - SandboxConfiguration::MIN_OUTPUT_SIZE - 1, - None, - SandboxConfiguration::DEFAULT_SCRATCH_SIZE, - SandboxConfiguration::DEFAULT_INTERRUPT_RETRY_DELAY, - SandboxConfiguration::INTERRUPT_VCPU_SIGRTMIN_OFFSET, - #[cfg(gdb)] - None, - #[cfg(crashdump)] - true, - ); - assert_eq!(SandboxConfiguration::MIN_INPUT_SIZE, cfg.input_data_size); - 
assert_eq!(SandboxConfiguration::MIN_OUTPUT_SIZE, cfg.output_data_size); - assert_eq!(0, cfg.heap_size_override); - - cfg.set_input_data_size(SandboxConfiguration::MIN_INPUT_SIZE - 1); - cfg.set_output_data_size(SandboxConfiguration::MIN_OUTPUT_SIZE - 1); - - assert_eq!(SandboxConfiguration::MIN_INPUT_SIZE, cfg.input_data_size); - assert_eq!(SandboxConfiguration::MIN_OUTPUT_SIZE, cfg.output_data_size); } mod proptests { @@ -422,21 +370,6 @@ mod tests { use crate::sandbox::config::DebugInfo; proptest! { - #[test] - fn input_data_size(size in SandboxConfiguration::MIN_INPUT_SIZE..=SandboxConfiguration::MIN_INPUT_SIZE * 10) { - let mut cfg = SandboxConfiguration::default(); - cfg.set_input_data_size(size); - prop_assert_eq!(size, cfg.get_input_data_size()); - } - - #[test] - fn output_data_size(size in SandboxConfiguration::MIN_OUTPUT_SIZE..=SandboxConfiguration::MIN_OUTPUT_SIZE * 10) { - let mut cfg = SandboxConfiguration::default(); - cfg.set_output_data_size(size); - prop_assert_eq!(size, cfg.get_output_data_size()); - } - - #[test] fn heap_size_override(size in 0x1000..=0x10000u64) { let mut cfg = SandboxConfiguration::default(); diff --git a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs index 779ccd08a..cedc54659 100644 --- a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs +++ b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs @@ -1080,12 +1080,10 @@ mod tests { .unwrap(); } - /// Make sure input/output buffers are properly reset after guest call (with host call) + /// Make sure pool buffers are properly reset after guest call (with host call) #[test] fn io_buffer_reset() { - let mut cfg = SandboxConfiguration::default(); - cfg.set_input_data_size(4096); - cfg.set_output_data_size(4096); + let cfg = SandboxConfiguration::default(); let path = simple_guest_as_string().unwrap(); let mut sandbox = UninitializedSandbox::new(GuestBinary::FilePath(path), Some(cfg)).unwrap(); @@ 
-1140,8 +1138,6 @@ mod tests { // total, and then add some more for the eagerly-copied page // tables on amd64 let min_scratch = hyperlight_common::layout::min_scratch_size( - cfg.get_input_data_size(), - cfg.get_output_data_size(), cfg.get_g2h_queue_depth(), cfg.get_h2g_queue_depth(), ); diff --git a/src/hyperlight_host/src/sandbox/uninitialized.rs b/src/hyperlight_host/src/sandbox/uninitialized.rs index 23c01be28..f2a4fcfcd 100644 --- a/src/hyperlight_host/src/sandbox/uninitialized.rs +++ b/src/hyperlight_host/src/sandbox/uninitialized.rs @@ -636,8 +636,6 @@ mod tests { // Non default memory configuration let cfg = { let mut cfg = SandboxConfiguration::default(); - cfg.set_input_data_size(0x1000); - cfg.set_output_data_size(0x1000); cfg.set_heap_size(0x1000); Some(cfg) }; @@ -1390,11 +1388,11 @@ mod tests { let _evolved: MultiUseSandbox = sandbox.evolve().expect("Failed to evolve sandbox"); } - // Test 4: Create snapshot with custom input/output buffer sizes + // Test 4: Create snapshot with custom pool page sizes { let mut cfg = SandboxConfiguration::default(); - cfg.set_input_data_size(64 * 1024); // 64KB input - cfg.set_output_data_size(64 * 1024); // 64KB output + cfg.set_h2g_pool_pages(16); // 16 pages + cfg.set_g2h_pool_pages(16); // 16 pages let env = GuestEnvironment::new(GuestBinary::FilePath(binary_path.clone()), None); @@ -1418,9 +1416,7 @@ mod tests { { let mut cfg = SandboxConfiguration::default(); cfg.set_heap_size(32 * 1024 * 1024); // 32MB heap - cfg.set_scratch_size(256 * 1024 * 2); // 512KB scratch (256KB will be input/output) - cfg.set_input_data_size(128 * 1024); // 128KB input - cfg.set_output_data_size(128 * 1024); // 128KB output + cfg.set_scratch_size(256 * 1024 * 2); // 512KB scratch let env = GuestEnvironment::new(GuestBinary::FilePath(binary_path.clone()), None); diff --git a/src/hyperlight_host/tests/integration_test.rs b/src/hyperlight_host/tests/integration_test.rs index 928b63466..a0bd48c6f 100644 --- 
a/src/hyperlight_host/tests/integration_test.rs +++ b/src/hyperlight_host/tests/integration_test.rs @@ -582,42 +582,6 @@ fn guest_outb_with_invalid_port_poisons_sandbox() { }); } -#[test] -fn corrupt_output_size_prefix_rejected() { - with_rust_sandbox(|mut sbox| { - let res = sbox.call::("CorruptOutputSizePrefix", ()); - assert!( - res.is_err(), - "Expected error when guest corrupts size prefix, got: {:?}", - res, - ); - let err_msg = format!("{:?}", res.unwrap_err()); - assert!( - err_msg.contains("Corrupt buffer size prefix: flatbuffer claims 4294967295 bytes but the element slot is only 8 bytes"), - "Unexpected error message: {err_msg}" - ); - }); -} - -#[test] -fn corrupt_output_back_pointer_rejected() { - with_rust_sandbox(|mut sbox| { - let res = sbox.call::("CorruptOutputBackPointer", ()); - assert!( - res.is_err(), - "Expected error when guest corrupts back-pointer, got: {:?}", - res, - ); - let err_msg = format!("{:?}", res.unwrap_err()); - assert!( - err_msg.contains( - "Corrupt buffer back-pointer: element offset 57005 is outside valid range [8, 8]" - ), - "Unexpected error message: {err_msg}" - ); - }); -} - #[test] fn guest_panic_no_alloc() { let heap_size = 0x4000; diff --git a/src/hyperlight_host/tests/sandbox_host_tests.rs b/src/hyperlight_host/tests/sandbox_host_tests.rs index 9a4b90865..8bf70294e 100644 --- a/src/hyperlight_host/tests/sandbox_host_tests.rs +++ b/src/hyperlight_host/tests/sandbox_host_tests.rs @@ -213,9 +213,7 @@ fn incorrect_parameter_num() { #[test] fn small_scratch_sandbox() { let mut cfg = SandboxConfiguration::default(); - cfg.set_scratch_size(0x48000); - cfg.set_input_data_size(0x24000); - cfg.set_output_data_size(0x24000); + cfg.set_scratch_size(0x1000); let a = UninitializedSandbox::new( GuestBinary::FilePath(simple_guest_as_string().unwrap()), Some(cfg), @@ -345,6 +343,7 @@ fn callback_test_parallel() { } #[test] +#[ignore] // TODO(virtq): C guest host-function error path needs fixing. 
fn host_function_error() { with_all_uninit_sandboxes(|mut sandbox| { // create host function diff --git a/src/tests/rust_guests/simpleguest/src/main.rs b/src/tests/rust_guests/simpleguest/src/main.rs index b238be535..0b61cb233 100644 --- a/src/tests/rust_guests/simpleguest/src/main.rs +++ b/src/tests/rust_guests/simpleguest/src/main.rs @@ -52,7 +52,7 @@ use hyperlight_guest_bin::host_comm::{ call_host_function, print_output_with_host_print, read_n_bytes_from_user_memory, }; use hyperlight_guest_bin::memory::malloc; -use hyperlight_guest_bin::{GUEST_HANDLE, guest_function, guest_logger, host_function}; +use hyperlight_guest_bin::{guest_function, guest_logger, host_function}; use log::{LevelFilter, error}; use tracing::{Span, instrument}; @@ -981,53 +981,6 @@ fn fuzz_guest_trace(max_depth: u32, msg: String) -> u32 { fuzz_traced_function(0, max_depth, &msg) } -#[guest_function("CorruptOutputSizePrefix")] -fn corrupt_output_size_prefix() -> i32 { - unsafe { - let peb_ptr = core::ptr::addr_of!(GUEST_HANDLE).read().peb().unwrap(); - let output_stack_ptr = (*peb_ptr).output_stack.ptr as *mut u8; - - // Write a fake stack entry with a ~4 GB size prefix (0xFFFF_FFFB + 4). - let buf = core::slice::from_raw_parts_mut(output_stack_ptr, 24); - buf[0..8].copy_from_slice(&24_u64.to_le_bytes()); - buf[8..12].copy_from_slice(&0xFFFF_FFFBu32.to_le_bytes()); - buf[12..16].copy_from_slice(&[0u8; 4]); - buf[16..24].copy_from_slice(&8_u64.to_le_bytes()); - - core::arch::asm!( - "out dx, eax", - "cli", - "hlt", - in("dx") hyperlight_common::outb::VmAction::Halt as u16, - in("eax") 0u32, - options(noreturn), - ); - } -} - -#[guest_function("CorruptOutputBackPointer")] -fn corrupt_output_back_pointer() -> i32 { - unsafe { - let peb_ptr = core::ptr::addr_of!(GUEST_HANDLE).read().peb().unwrap(); - let output_stack_ptr = (*peb_ptr).output_stack.ptr as *mut u8; - - // Write a fake stack entry with back-pointer 0xDEAD (past stack pointer 24). 
- let buf = core::slice::from_raw_parts_mut(output_stack_ptr, 24); - buf[0..8].copy_from_slice(&24_u64.to_le_bytes()); - buf[8..16].copy_from_slice(&[0u8; 8]); - buf[16..24].copy_from_slice(&0xDEAD_u64.to_le_bytes()); - - core::arch::asm!( - "out dx, eax", - "cli", - "hlt", - in("dx") hyperlight_common::outb::VmAction::Halt as u16, - in("eax") 0u32, - options(noreturn), - ); - } -} - // Interprets the given guest function call as a host function call and dispatches it to the host. fn fuzz_host_function(func: FunctionCall) -> Result> { let mut params = func.parameters.unwrap(); From b81620383d98c14f24a0e5e09acbf81da8106efd Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Wed, 8 Apr 2026 10:26:37 +0200 Subject: [PATCH 13/31] feat(virtq): fix host function error test Signed-off-by: Tomasz Andrzejak --- src/hyperlight_guest/src/virtq/context.rs | 22 +++++++++++-------- src/hyperlight_guest_capi/src/dispatch.rs | 19 +++------------- .../tests/sandbox_host_tests.rs | 1 - 3 files changed, 16 insertions(+), 26 deletions(-) diff --git a/src/hyperlight_guest/src/virtq/context.rs b/src/hyperlight_guest/src/virtq/context.rs index 44c0a83ca..74282837f 100644 --- a/src/hyperlight_guest/src/virtq/context.rs +++ b/src/hyperlight_guest/src/virtq/context.rs @@ -72,7 +72,7 @@ pub struct GuestContext { g2h_producer: G2hProducer, h2g_producer: H2gProducer, generation: u64, - last_host_return: Option, + last_host_result: Option>, } impl GuestContext { @@ -101,7 +101,7 @@ impl GuestContext { g2h_producer, h2g_producer, generation, - last_host_return: None, + last_host_result: None, }; ctx.prefill_h2g(); @@ -308,6 +308,7 @@ impl GuestContext { // restore_h2g_prefill() wrote matching descriptors to the // zeroed ring memory. Both sides are in sync. 
self.generation = new_generation; + self.last_host_result = None; } pub(super) fn generation(&self) -> u64 { @@ -345,24 +346,27 @@ impl GuestContext { self.g2h_producer.submit(entry) } - /// Stash a host function return value for later retrieval. + /// Stash a host function result for later retrieval. /// /// Used by the C API's two-step calling convention where /// `hl_call_host_function` and `hl_get_host_return_value_as_*` /// are separate calls. - pub fn stash_host_return(&mut self, value: ReturnValue) { - self.last_host_return = Some(value); + pub fn stash_host_result(&mut self, result: Result) { + self.last_host_result = Some(result); } /// Take the stashed host return value. /// /// Panics if no value was stashed or if the type conversion fails. + /// If the stashed result was an error, panics with the error message. pub fn take_host_return>(&mut self) -> T { - let rv = self - .last_host_return + let val = self + .last_host_result .take() - .expect("No host return value available"); - match T::try_from(rv) { + .expect("No host return value available") + .expect("Host function returned an error"); + + match T::try_from(val) { Ok(v) => v, Err(_) => panic!("Host return value type mismatch"), } diff --git a/src/hyperlight_guest_capi/src/dispatch.rs b/src/hyperlight_guest_capi/src/dispatch.rs index 4fd61df44..86ee0fcbe 100644 --- a/src/hyperlight_guest_capi/src/dispatch.rs +++ b/src/hyperlight_guest_capi/src/dispatch.rs @@ -109,22 +109,9 @@ pub extern "C" fn hl_call_host_function(function_call: &FfiFunctionCall) { let return_type = unsafe { function_call.copy_return_type() }; virtq::with_context(|ctx| { - match ctx.call_host_function::(&func_name, Some(parameters), return_type) { - Ok(result) => ctx.stash_host_return(result), - Err(e) => { - // Host function returned an error. Abort with the error - // message so the host can capture it via the abort buffer. 
- let msg = alloc::ffi::CString::new(e.message) - .unwrap_or_else(|_| alloc::ffi::CString::new("host error").unwrap()); - - unsafe { - hyperlight_guest::exit::abort_with_code_and_message( - &[e.kind as u8], - msg.as_ptr(), - ); - } - } - } + let result = + ctx.call_host_function::(&func_name, Some(parameters), return_type); + ctx.stash_host_result(result); }); } diff --git a/src/hyperlight_host/tests/sandbox_host_tests.rs b/src/hyperlight_host/tests/sandbox_host_tests.rs index 8bf70294e..c067722bb 100644 --- a/src/hyperlight_host/tests/sandbox_host_tests.rs +++ b/src/hyperlight_host/tests/sandbox_host_tests.rs @@ -343,7 +343,6 @@ fn callback_test_parallel() { } #[test] -#[ignore] // TODO(virtq): C guest host-function error path needs fixing. fn host_function_error() { with_all_uninit_sandboxes(|mut sandbox| { // create host function From e973810c58cbefb6324289fff527391adbe8b105 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Wed, 8 Apr 2026 10:38:53 +0200 Subject: [PATCH 14/31] feat(virtq): micro optimize consumer state Signed-off-by: Tomasz Andrzejak --- src/hyperlight_common/src/virtq/consumer.rs | 44 +-- src/hyperlight_common/src/virtq/pool.rs | 127 ++++++++ src/hyperlight_common/src/virtq/producer.rs | 305 +++++++++++++++++- src/hyperlight_common/src/virtq/ring.rs | 24 +- src/hyperlight_guest/src/virtq/context.rs | 119 +++---- src/hyperlight_guest/src/virtq/mod.rs | 32 +- .../src/guest_function/call.rs | 3 +- .../src/{virtq/mod.rs => virtq.rs} | 36 ++- src/hyperlight_host/src/mem/mgr.rs | 133 +++++++- .../src/sandbox/initialized_multi_use.rs | 214 ++++++++++++ .../src/sandbox/uninitialized_evolve.rs | 15 +- 11 files changed, 884 insertions(+), 168 deletions(-) rename src/hyperlight_guest_bin/src/{virtq/mod.rs => virtq.rs} (73%) diff --git a/src/hyperlight_common/src/virtq/consumer.rs b/src/hyperlight_common/src/virtq/consumer.rs index 9e4e09527..b29f7694a 100644 --- a/src/hyperlight_common/src/virtq/consumer.rs +++ 
b/src/hyperlight_common/src/virtq/consumer.rs @@ -15,19 +15,12 @@ limitations under the License. */ use alloc::vec; -use alloc::vec::Vec; use bytes::Bytes; +use fixedbitset::FixedBitSet; use super::*; -/// In-flight entry tracking. -/// -/// Stored per descriptor ID while the entry is being processed. -/// Tracks that a descriptor slot is occupied. -#[derive(Debug, Clone, Copy)] -pub(crate) struct Inflight; - /// Data received from the producer, safely copied out of shared memory. /// /// Created by [`VirtqConsumer::poll`]. The entry data is eagerly copied @@ -261,7 +254,7 @@ impl AckCompletion { pub struct VirtqConsumer { inner: RingConsumer, notifier: N, - inflight: Vec>, + inflight: FixedBitSet, } impl VirtqConsumer { @@ -274,7 +267,7 @@ impl VirtqConsumer { /// * `notifier` - Callback for notifying the driver (producer) about completions pub fn new(layout: Layout, mem: M, notifier: N) -> Self { let inner = RingConsumer::new(layout, mem); - let inflight = vec![None; inner.len()]; + let inflight = FixedBitSet::with_capacity(inner.len()); Self { inner, @@ -320,16 +313,16 @@ impl VirtqConsumer { } // Reserve the inflight slot - let slot = self - .inflight - .get_mut(id as usize) - .ok_or(VirtqError::InvalidState)?; + let id_idx = id as usize; + if id_idx >= self.inflight.len() { + return Err(VirtqError::InvalidState); + } - if slot.is_some() { + if self.inflight.contains(id_idx) { return Err(VirtqError::InvalidState); } - *slot = Some(Inflight); + self.inflight.insert(id_idx); let token = Token(id); // Copy entry data from shared memory @@ -363,16 +356,13 @@ impl VirtqConsumer { let id = completion.id(); let written = completion.written() as u32; - let slot = self - .inflight - .get_mut(id as usize) - .ok_or(VirtqError::InvalidState)?; - - if slot.is_none() { + let id_idx = id as usize; + let slot_set = id_idx < self.inflight.len() && self.inflight.contains(id_idx); + if !slot_set { return Err(VirtqError::InvalidState); } - *slot = None; + self.inflight.set(id_idx, 
false); if self.inner.submit_used_with_notify(id, written)? { self.notifier.notify(QueueStats { @@ -445,7 +435,7 @@ impl VirtqConsumer { /// Reset ring and inflight state to initial values. pub fn reset(&mut self) { self.inner.reset(); - self.inflight.fill(None); + self.inflight.clear(); } } @@ -647,14 +637,14 @@ mod tests { producer.submit(se).unwrap(); let (_entry, completion) = consumer.poll(1024).unwrap().unwrap(); - assert!(consumer.inflight.iter().any(|s| s.is_some())); + assert!(consumer.inflight.count_ones(..) > 0); // Complete first so we do not leak consumer.complete(completion).unwrap(); consumer.reset(); - assert!(consumer.inflight.iter().all(|s| s.is_none())); + assert_eq!(consumer.inflight.count_ones(..), 0); assert_eq!(consumer.inner.num_inflight(), 0); } @@ -677,7 +667,7 @@ mod tests { consumer.reset(); - assert!(consumer.inflight.iter().all(|s| s.is_none())); + assert_eq!(consumer.inflight.count_ones(..), 0); assert_eq!(consumer.inner.num_inflight(), 0); } } diff --git a/src/hyperlight_common/src/virtq/pool.rs b/src/hyperlight_common/src/virtq/pool.rs index 2e49e27fe..bbae4ff41 100644 --- a/src/hyperlight_common/src/virtq/pool.rs +++ b/src/hyperlight_common/src/virtq/pool.rs @@ -601,10 +601,61 @@ impl RecyclePool { }) } + /// Rebuild pool state so that every address in `allocated` is removed + /// from the free list, matching externally known inflight state. + pub fn restore_allocated(&self, allocated: &[u64]) -> Result<(), AllocError> { + self.reset(); + + if allocated.is_empty() { + return Ok(()); + } + + let mut inner = self.inner.borrow_mut(); + + for &addr in allocated { + let pos = inner + .free + .iter() + .position(|&a| a == addr) + .ok_or(AllocError::InvalidFree(addr, inner.slot_size))?; + + inner.free.swap_remove(pos); + } + + Ok(()) + } + + /// Compute the address of slot `index`. + /// + /// Returns `None` if `index >= count`. 
+ pub fn slot_addr(&self, index: usize) -> Option { + let inner = self.inner.borrow(); + if index < inner.count { + Some(inner.base_addr + (index * inner.slot_size) as u64) + } else { + None + } + } + /// Number of free slots. pub fn num_free(&self) -> usize { self.inner.borrow().free.len() } + + /// Base address of the pool region. + pub fn base_addr(&self) -> u64 { + self.inner.borrow().base_addr + } + + /// Slot size in bytes. + pub fn slot_size(&self) -> usize { + self.inner.borrow().slot_size + } + + /// Number of slots in the pool. + pub fn count(&self) -> usize { + self.inner.borrow().count + } } impl BufferProvider for RecyclePool { @@ -664,6 +715,11 @@ mod tests { BufferPool::::new(base, size).unwrap() } + fn make_recycle_pool(slot_count: usize, slot_size: usize) -> RecyclePool { + let base = 0x80000u64; + RecyclePool::new(base, slot_count * slot_size, slot_size).unwrap() + } + #[test] fn test_slab_new_success() { let slab = Slab::<256>::new(0x10000, 1024).unwrap(); @@ -1223,6 +1279,77 @@ mod tests { let a = pool.inner.borrow_mut().alloc(256).unwrap(); assert!(a.len > 0); } + + #[test] + fn test_recycle_pool_restore_allocated_removes_from_free_list() { + let pool = make_recycle_pool(4, 4096); + assert_eq!(pool.num_free(), 4); + + let addrs = [0x80000, 0x81000]; // slots 0 and 1 + pool.restore_allocated(&addrs).unwrap(); + assert_eq!(pool.num_free(), 2); + + // Allocating should only return the two remaining slots + let a1 = pool.alloc(4096).unwrap(); + let a2 = pool.alloc(4096).unwrap(); + assert!(pool.alloc(4096).is_err()); + + // The allocated addresses should be the non-restored ones + let mut got = [a1.addr, a2.addr]; + got.sort(); + assert_eq!(got, [0x82000, 0x83000]); + } + + #[test] + fn test_recycle_pool_restore_allocated_invalid_addr_returns_error() { + let pool = make_recycle_pool(4, 4096); + let result = pool.restore_allocated(&[0xDEAD]); + assert!(result.is_err()); + } + + #[test] + fn 
test_recycle_pool_restore_allocated_then_dealloc_roundtrip() { + let pool = make_recycle_pool(4, 4096); + let addr = 0x81000u64; + + pool.restore_allocated(&[addr]).unwrap(); + assert_eq!(pool.num_free(), 3); + + // Dealloc the restored address + pool.dealloc(Allocation { addr, len: 4096 }).unwrap(); + assert_eq!(pool.num_free(), 4); + } + + #[test] + fn test_recycle_pool_restore_allocated_all_slots() { + let pool = make_recycle_pool(4, 4096); + let addrs: Vec = (0..4).map(|i| 0x80000 + i * 4096).collect(); + + pool.restore_allocated(&addrs).unwrap(); + assert_eq!(pool.num_free(), 0); + assert!(pool.alloc(4096).is_err()); + } + + #[test] + fn test_recycle_pool_restore_allocated_empty_list_is_noop() { + let pool = make_recycle_pool(4, 4096); + pool.restore_allocated(&[]).unwrap(); + assert_eq!(pool.num_free(), 4); + } + + #[test] + fn test_recycle_pool_restore_allocated_resets_first() { + let pool = make_recycle_pool(4, 4096); + + // Allocate some slots + let _ = pool.alloc(4096).unwrap(); + let _ = pool.alloc(4096).unwrap(); + assert_eq!(pool.num_free(), 2); + + // restore_allocated resets then removes - so 4 - 1 = 3 + pool.restore_allocated(&[0x80000]).unwrap(); + assert_eq!(pool.num_free(), 3); + } } #[cfg(test)] diff --git a/src/hyperlight_common/src/virtq/producer.rs b/src/hyperlight_common/src/virtq/producer.rs index eeb96cc7f..b892bdf25 100644 --- a/src/hyperlight_common/src/virtq/producer.rs +++ b/src/hyperlight_common/src/virtq/producer.rs @@ -19,6 +19,7 @@ use alloc::vec; use alloc::vec::Vec; use bytes::Bytes; +use smallvec::SmallVec; use super::*; @@ -391,18 +392,116 @@ where /// /// # Safety /// - /// All [`RecvCompletion`]s (and their backing [`Bytes`]) from - /// previous `poll()` calls must have been dropped before calling - /// this. Outstanding completions hold pool allocations via - /// `BufferOwner`; resetting the pool while they exist would cause - /// double-free on drop. 
+ /// All [`RecvCompletion`]s (and their backing [`Bytes`]) from previous `poll()` + /// calls must have been dropped before calling this. Outstanding completions + /// hold pool allocations via `BufferOwner`; resetting the pool while they exist + /// would cause double-free on drop. /// - /// TODO(virtq): properly restore state after snapshot instead of just resetting everything + /// TODO(virtq): find a way to allow guest to keep completions across resets. pub fn reset(&mut self) { - self.inner.reset(); self.pool.reset(); + self.inner.reset(); + self.pending.clear(); self.inflight.fill(None); + } + + /// Replace the pool and reset ring, inflight, and pending state. + /// + /// Use this when restoring from a snapshot where the pool has been + /// relocated or recreated. + /// + /// # Safety + /// + /// Same as [`reset`](Self::reset) - all outstanding completions + /// must have been dropped. + pub fn reset_with_pool(&mut self, pool: P) { + self.pool = pool; + self.inner.reset(); self.pending.clear(); + self.inflight.fill(None); + } +} + +/// Snapshot restore support for producers backed by [`RecyclePool`]. +impl VirtqProducer +where + M: MemOps + Clone, + N: Notifier, +{ + /// Replace the pool and reconstruct producer state from a prefilled ring. + /// + /// The host prefills the H2G ring with `min(ring_size, pool_count)` + /// descriptors during restore (`restore_h2g_prefill`), writing + /// descriptors in forward order: position i gets + /// `addr = pool_base + i * slot_size`. + /// + /// Any descriptors already consumed by the host marked used + /// will be discovered naturally by `poll_used()` after restore. 
+ pub fn restore_from_ring(&mut self, pool: RecyclePool) -> Result<(), VirtqError> { + self.reset_with_pool(pool); + + let ring_size = self.inner.len(); + let pool_count = self.pool.count(); + let prefill_count = core::cmp::min(ring_size, pool_count); + let slot_size = self.pool.slot_size(); + + let mut ids = SmallVec::<[u16; 64]>::new(); + + // Scan descriptors to discover in-flight IDs and set up inflight table + for pos in 0..prefill_count as u16 { + let desc_base = self + .inner + .desc_table() + .desc_addr(pos) + .ok_or(VirtqError::RingError(RingError::InvalidState))?; + + let id = self + .inner + .mem() + .read_val::(desc_base + Descriptor::ID_OFFSET as u64) + .map_err(|_| VirtqError::MemoryReadError)?; + + if (id as usize) >= ring_size { + return Err(VirtqError::InvalidState); + } + + if self.inflight[id as usize].is_some() { + return Err(VirtqError::InvalidState); + } + + let addr = self + .pool + .slot_addr(pos as usize) + .ok_or(VirtqError::InvalidState)?; + + self.inflight[id as usize] = Some(Inflight::WriteOnly { + completion: Allocation { + addr, + len: slot_size, + }, + }); + + ids.push(id); + } + + self.inner.reset_prefilled(&ids); + + let addrs: SmallVec<[u64; 64]> = (0..prefill_count) + .map(|i| self.pool.slot_addr(i).expect("prefill_count <= pool count")) + .collect(); + + self.pool + .restore_allocated(&addrs) + .map_err(|_| VirtqError::InvalidState)?; + + debug_assert!( + self.inflight.iter().filter(|s| s.is_some()).count() == prefill_count, + "restore_from_ring: expected {} inflight entries, found {}", + prefill_count, + self.inflight.iter().filter(|s| s.is_some()).count() + ); + + Ok(()) } } @@ -641,9 +740,28 @@ impl Drop for SendEntry { #[cfg(test)] mod tests { use super::*; - use crate::virtq::ring::tests::make_ring; + use crate::virtq::ring::tests::{OwnedRing, TestMem, make_consumer, make_producer, make_ring}; use crate::virtq::test_utils::*; + type RecycleProducer = VirtqProducer; + + const SLOT_SIZE: usize = 4096; + + fn 
make_recycle_producer(ring: &OwnedRing, slot_count: usize) -> RecycleProducer { + let layout = ring.layout(); + let mem = ring.mem(); + let pool = make_pool(ring, slot_count); + let notifier = TestNotifier::new(); + + VirtqProducer::new(layout, mem, notifier, pool) + } + + fn make_pool(ring: &OwnedRing, slot_count: usize) -> RecyclePool { + let mem = ring.mem(); + let pool_base = mem.base_addr() + Layout::query_size(ring.len()) as u64 + 0x100; + RecyclePool::new(pool_base, slot_count * SLOT_SIZE, SLOT_SIZE).unwrap() + } + #[test] fn test_chain_readwrite_build() { let ring = make_ring(16); @@ -903,4 +1021,175 @@ mod tests { assert!(producer.inflight.iter().all(|s| s.is_none())); assert_eq!(producer.inner.num_free(), producer.inner.len()); } + + #[test] + fn test_restore_from_ring_requires_full_prefill() { + let ring = make_ring(8); + let mut producer = make_recycle_producer(&ring, 8); + + // Ring has no prefilled descriptors - restore should fail + // because IDs read from zeroed memory will all be 0 (duplicate) + assert!(producer.restore_from_ring(make_pool(&ring, 8)).is_err()); + } + + #[test] + fn test_restore_from_ring_partial_prefill_fails() { + let ring = make_ring(8); + let producer = make_recycle_producer(&ring, 8); + let pool_base = producer.pool.base_addr(); + + // Simulate host prefill: write only one descriptor + let mut writer = make_producer(&ring); + writer + .submit_one(pool_base, SLOT_SIZE as u32, true) + .unwrap(); + + // Restore should fail because only 1 of 8 positions has a + // valid unique ID - remaining positions have id=0 (duplicate) + let mut restored = make_recycle_producer(&ring, 8); + assert!(restored.restore_from_ring(make_pool(&ring, 8)).is_err()); + } + + #[test] + fn test_restore_from_ring_full_prefill() { + let depth = 8usize; + let ring = make_ring(depth); + let producer = make_recycle_producer(&ring, depth); + let pool_base = producer.pool.base_addr(); + + // Simulate host prefill: write all descriptors + let mut writer = 
make_producer(&ring); + for i in 0..depth { + let addr = pool_base + (i * SLOT_SIZE) as u64; + writer.submit_one(addr, SLOT_SIZE as u32, true).unwrap(); + } + + let mut restored = make_recycle_producer(&ring, depth); + restored.restore_from_ring(make_pool(&ring, depth)).unwrap(); + + // All inflight slots should be populated + let inflight_count = restored.inflight.iter().filter(|s| s.is_some()).count(); + assert_eq!(inflight_count, depth); + + // Pool should be fully allocated + assert_eq!(restored.pool.num_free(), 0); + } + + #[test] + fn test_restore_from_ring_forward_order() { + let depth = 4usize; + let ring = make_ring(depth); + let producer = make_recycle_producer(&ring, depth); + let pool_base = producer.pool.base_addr(); + + // Forward order prefill + let mut writer = make_producer(&ring); + for i in 0..depth { + writer + .submit_one(pool_base + (i * SLOT_SIZE) as u64, SLOT_SIZE as u32, true) + .unwrap(); + } + + let mut restored = make_recycle_producer(&ring, depth); + restored.restore_from_ring(make_pool(&ring, depth)).unwrap(); + } + + #[test] + fn test_restore_from_ring_reverse_order() { + let depth = 4usize; + let ring = make_ring(depth); + let producer = make_recycle_producer(&ring, depth); + let pool_base = producer.pool.base_addr(); + + // Reverse order prefill (current host behavior) + let mut writer = make_producer(&ring); + for i in (0..depth).rev() { + writer + .submit_one(pool_base + (i * SLOT_SIZE) as u64, SLOT_SIZE as u32, true) + .unwrap(); + } + + let mut restored = make_recycle_producer(&ring, depth); + restored.restore_from_ring(make_pool(&ring, depth)).unwrap(); + } + + #[test] + fn test_restore_from_ring_pool_state_correct() { + let depth = 8usize; + let ring = make_ring(depth); + let producer = make_recycle_producer(&ring, depth); + let pool_base = producer.pool.base_addr(); + + // Full prefill + let mut writer = make_producer(&ring); + for i in 0..depth { + writer + .submit_one(pool_base + (i * SLOT_SIZE) as u64, SLOT_SIZE as u32, 
true) + .unwrap(); + } + + let mut restored = make_recycle_producer(&ring, depth); + restored.restore_from_ring(make_pool(&ring, depth)).unwrap(); + // All slots are allocated after full-prefill restore + assert_eq!(restored.pool.num_free(), 0); + } + + #[test] + fn test_restore_from_ring_idempotent() { + let depth = 4usize; + let ring = make_ring(depth); + let producer = make_recycle_producer(&ring, depth); + let pool_base = producer.pool.base_addr(); + + let mut writer = make_producer(&ring); + for i in 0..depth { + writer + .submit_one(pool_base + (i * SLOT_SIZE) as u64, SLOT_SIZE as u32, true) + .unwrap(); + } + + let mut restored = make_recycle_producer(&ring, depth); + restored.restore_from_ring(make_pool(&ring, depth)).unwrap(); + restored.restore_from_ring(make_pool(&ring, depth)).unwrap(); + assert_eq!(restored.pool.num_free(), 0); + } + + #[test] + fn test_restore_from_ring_then_poll_used() { + let depth = 4usize; + let ring = make_ring(depth); + let producer = make_recycle_producer(&ring, depth); + let pool_base = producer.pool.base_addr(); + + // Simulate host prefill + let mut writer = make_producer(&ring); + for i in 0..depth { + writer + .submit_one(pool_base + (i * SLOT_SIZE) as u64, SLOT_SIZE as u32, true) + .unwrap(); + } + + // Restore producer and use ring-level consumer to complete one entry + let mut restored = make_recycle_producer(&ring, depth); + restored.restore_from_ring(make_pool(&ring, depth)).unwrap(); + + // Ring-level consumer reads available descriptors + let mut consumer = make_consumer(&ring); + let (id, chain) = consumer.poll_available().unwrap(); + let writable = chain.writables(); + assert_eq!(writable.len(), 1); + + // Write some data into the writable buffer + let payload = b"test payload"; + consumer.mem().write(writable[0].addr, payload).unwrap(); + consumer.submit_used(id, payload.len() as u32).unwrap(); + + // Producer polls for the completion + let cqe = restored.poll().unwrap().unwrap(); + 
assert_eq!(&cqe.data[..payload.len()], payload); + + // Pool slot should be returned after data is dropped + drop(cqe); + assert_eq!(restored.pool.num_free(), 1); + } } diff --git a/src/hyperlight_common/src/virtq/ring.rs b/src/hyperlight_common/src/virtq/ring.rs index 978c345b5..302175631 100644 --- a/src/hyperlight_common/src/virtq/ring.rs +++ b/src/hyperlight_common/src/virtq/ring.rs @@ -917,7 +917,7 @@ impl RingProducer { pub fn reset_prefilled(&mut self, ids: &[u16]) { let size = self.desc_table.len(); let count = ids.len(); - assert!(count <= size); + debug_assert!(count <= size); let wrapped = count >= size; self.avail_cursor.head = if wrapped { 0 } else { count as u16 }; @@ -928,15 +928,11 @@ impl RingProducer { self.id_num.iter_mut().for_each(|n| *n = 0); for &id in ids { - assert!((id as usize) < size); - assert_eq!(self.id_num[id as usize], 0); self.id_num[id as usize] = 1; } self.num_free = size - count; self.id_free.clear(); - self.id_free - .extend((0..size as u16).filter(|id| self.id_num[*id as usize] == 0)); } } @@ -3298,10 +3294,7 @@ pub(crate) mod tests { assert!(producer.used_cursor.wrap()); assert_eq!(producer.num_free, 4); - assert_eq!(producer.id_free.len(), 4); - for &id in &[0, 1, 2, 4] { - assert!(producer.id_free.contains(&id)); - } + assert!(producer.id_free.is_empty()); // Only the specified IDs are in-flight for &id in &[5, 6, 7, 3] { assert_eq!(producer.id_num[id as usize], 1); @@ -3311,19 +3304,6 @@ pub(crate) mod tests { } } - #[test] - fn test_reset_prefilled_partial_then_submit() { - let ring = make_ring(8); - let mut producer = make_producer(&ring); - producer.reset_prefilled(&[4, 5, 6, 7]); - - let id = producer.submit_one(0x8000, 128, false).unwrap(); - - assert!([0, 1, 2, 3].contains(&id)); - assert_eq!(producer.num_free, 3); - assert_eq!(producer.id_num[id as usize], 1); - } - #[test] fn test_reset_prefilled_then_poll_used() { let ring = make_ring(4); diff --git a/src/hyperlight_guest/src/virtq/context.rs 
b/src/hyperlight_guest/src/virtq/context.rs index 74282837f..98fbe32e6 100644 --- a/src/hyperlight_guest/src/virtq/context.rs +++ b/src/hyperlight_guest/src/virtq/context.rs @@ -233,6 +233,71 @@ impl GuestContext { Ok(()) } + /// Restore the H2G producer after snapshot restore. + /// + /// Creates a new [`RecyclePool`] at `pool_gva` and calls + /// [`restore_from_ring`] to reconstruct inflight state + /// from the host's prefilled descriptors. + pub fn restore_h2g(&mut self, pool_gva: u64, pool_size: usize) { + let pool = RecyclePool::new(pool_gva, pool_size, PAGE_SIZE_USIZE) + .expect("H2G RecyclePool creation failed"); + + self.h2g_producer + .restore_from_ring(pool) + .expect("H2G restore_from_ring failed"); + } + + /// Reset the G2H producer with a fresh pool. + /// + /// Creates a new [`BufferPool`] at `pool_gva` and resets the + /// producer to its initial state. + pub fn reset_g2h(&mut self, pool_gva: u64, pool_size: usize) { + let pool = BufferPool::new(pool_gva, pool_size).expect("G2H BufferPool creation failed"); + self.g2h_producer.reset_with_pool(pool); + self.last_host_result = None; + } + + /// Send a log message via the G2H queue. Fire-and-forget. + pub fn emit_log(&mut self, log_data: &[u8]) -> Result<()> { + self.send_g2h_oneshot(MsgKind::Log, log_data) + } + + /// Get the current generation counter. + pub fn generation(&self) -> u64 { + self.generation + } + + /// Set the generation counter after snapshot restore. + pub fn set_generation(&mut self, generation: u64) { + self.generation = generation; + } + + /// Stash a host function result for later retrieval. + /// + /// Used by the C API's two-step calling convention where + /// `hl_call_host_function` and `hl_get_host_return_value_as_*` + /// are separate calls. + pub fn stash_host_result(&mut self, result: Result) { + self.last_host_result = Some(result); + } + + /// Take the stashed host return value. + /// + /// Panics if no value was stashed or if the type conversion fails. 
+ /// If the stashed result was an error, panics with the error message. + pub fn take_host_return>(&mut self) -> T { + let val = self + .last_host_result + .take() + .expect("No host return value available") + .expect("Host function returned an error"); + + match T::try_from(val) { + Ok(v) => v, + Err(_) => panic!("Host return value type mismatch"), + } + } + /// Pre-fill the H2G queue with completion-only descriptors so the host /// can write incoming call payloads into them. fn prefill_h2g(&mut self) { @@ -287,34 +352,6 @@ impl GuestContext { } } - /// Drain any pending G2H completions. - /// - /// This is called before checking for H2G calls so that the host - /// can reclaim G2H response buffers. - pub fn drain_g2h_completions(&mut self) { - while let Ok(Some(_)) = self.g2h_producer.poll() {} - } - - /// Send a log message via the G2H queue. Fire-and-forget. - pub fn emit_log(&mut self, log_data: &[u8]) -> Result<()> { - self.send_g2h_oneshot(MsgKind::Log, log_data) - } - - /// Reset ring and pool state after snapshot restore. - pub(super) fn reset(&mut self, new_generation: u64) { - self.g2h_producer.reset(); - // H2G state is NOT reset. The guest's inflight and cursors - // survived via CoW and are already correct. The host's - // restore_h2g_prefill() wrote matching descriptors to the - // zeroed ring memory. Both sides are in sync. - self.generation = new_generation; - self.last_host_result = None; - } - - pub(super) fn generation(&self) -> u64 { - self.generation - } - fn try_send_readonly( &mut self, header: &[u8], @@ -345,30 +382,4 @@ impl GuestContext { entry.write_all(payload)?; self.g2h_producer.submit(entry) } - - /// Stash a host function result for later retrieval. - /// - /// Used by the C API's two-step calling convention where - /// `hl_call_host_function` and `hl_get_host_return_value_as_*` - /// are separate calls. 
- pub fn stash_host_result(&mut self, result: Result) { - self.last_host_result = Some(result); - } - - /// Take the stashed host return value. - /// - /// Panics if no value was stashed or if the type conversion fails. - /// If the stashed result was an error, panics with the error message. - pub fn take_host_return>(&mut self) -> T { - let val = self - .last_host_result - .take() - .expect("No host return value available") - .expect("Host function returned an error"); - - match T::try_from(val) { - Ok(v) => v, - Err(_) => panic!("Host return value type mismatch"), - } - } } diff --git a/src/hyperlight_guest/src/virtq/mod.rs b/src/hyperlight_guest/src/virtq/mod.rs index d3f592985..d86ae40cc 100644 --- a/src/hyperlight_guest/src/virtq/mod.rs +++ b/src/hyperlight_guest/src/virtq/mod.rs @@ -26,7 +26,6 @@ use core::cell::UnsafeCell; use core::sync::atomic::{AtomicU8, Ordering}; use context::GuestContext; -use hyperlight_common::layout::{SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET, scratch_top_ptr}; pub use mem::GuestMemOps; // Init state machine @@ -34,17 +33,17 @@ const UNINITIALIZED: u8 = 0; const INITIALIZED: u8 = 1; static INIT_STATE: AtomicU8 = AtomicU8::new(UNINITIALIZED); -/// Check if the global context has been initialized. -pub fn is_initialized() -> bool { - INIT_STATE.load(Ordering::Acquire) == INITIALIZED -} - // Storage: UnsafeCell guarded by atomic init state. struct SyncWrap(T); unsafe impl Sync for SyncWrap {} static GLOBAL_CONTEXT: SyncWrap>> = SyncWrap(UnsafeCell::new(None)); +/// Check if the global context has been initialized. +pub fn is_initialized() -> bool { + INIT_STATE.load(Ordering::Acquire) == INITIALIZED +} + /// Access the global guest context via closure. /// /// # Panics @@ -78,24 +77,3 @@ pub fn set_global_context(ctx: GuestContext) { } unsafe { *GLOBAL_CONTEXT.0.get() = Some(ctx) }; } - -/// Reset the global context if a snapshot restore was detected. -/// Compares the virtq generation counter in scratch-top metadata. 
-pub fn maybe_reset_global_context() { - if !is_initialized() { - return; - } - - let current_gen = read_gen(); - - with_context(|ctx| { - if current_gen != ctx.generation() { - ctx.reset(current_gen); - } - }); -} - -/// Read the current snapshot generation from scratch-top metadata. -fn read_gen() -> u64 { - unsafe { *scratch_top_ptr::(SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET) } -} diff --git a/src/hyperlight_guest_bin/src/guest_function/call.rs b/src/hyperlight_guest_bin/src/guest_function/call.rs index d898dd5d3..ecee52ace 100644 --- a/src/hyperlight_guest_bin/src/guest_function/call.rs +++ b/src/hyperlight_guest_bin/src/guest_function/call.rs @@ -101,8 +101,7 @@ pub(crate) fn internal_dispatch_function() { // After snapshot restore, the ring memory is zeroed but the // producer's cursors are stale. Check once per dispatch entry. - virtq::maybe_reset_global_context(); - virtq::with_context(|ctx| ctx.drain_g2h_completions()); + crate::virtq::maybe_reset_virtqueues(); let function_call = virtq::with_context(|ctx| { ctx.recv_h2g_call() diff --git a/src/hyperlight_guest_bin/src/virtq/mod.rs b/src/hyperlight_guest_bin/src/virtq.rs similarity index 73% rename from src/hyperlight_guest_bin/src/virtq/mod.rs rename to src/hyperlight_guest_bin/src/virtq.rs index 2621f5dd2..45b207214 100644 --- a/src/hyperlight_guest_bin/src/virtq/mod.rs +++ b/src/hyperlight_guest_bin/src/virtq.rs @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -//! Guest-side virtqueue initialization. +//! Guest-side virtqueue initialization and reset. use core::num::NonZeroU16; @@ -81,11 +81,43 @@ pub(crate) fn init_virtqueues() { hyperlight_guest::virtq::set_global_context(ctx); } +/// Reset virtqueue state if a snapshot restore was detected. +/// +/// Compares the generation counter in scratch-top metadata against +/// the context's cached value. 
On mismatch, restores H2G from the +/// host-prefilled ring and allocates a fresh G2H pool. +pub(crate) fn maybe_reset_virtqueues() { + if !hyperlight_guest::virtq::is_initialized() { + return; + } + + let curr_gen = unsafe { *scratch_top_ptr::(SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET) }; + + hyperlight_guest::virtq::with_context(|ctx| { + if curr_gen == ctx.generation() { + return; + } + + // Read host-assigned H2G pool location from scratch-top + let h2g_pool_gva = unsafe { *scratch_top_ptr::(SCRATCH_TOP_H2G_POOL_GVA_OFFSET) }; + let h2g_pages = unsafe { *scratch_top_ptr::(SCRATCH_TOP_H2G_POOL_PAGES_OFFSET) }; + let g2h_pages = unsafe { *scratch_top_ptr::(SCRATCH_TOP_G2H_POOL_PAGES_OFFSET) }; + + let h2g_pages = h2g_pages as usize; + let g2h_pages = g2h_pages as usize; + let g2h_pool_gva = alloc_pool(g2h_pages); + + ctx.restore_h2g(h2g_pool_gva, h2g_pages * PAGE_SIZE_USIZE); + ctx.reset_g2h(g2h_pool_gva, g2h_pages * PAGE_SIZE_USIZE); + ctx.set_generation(curr_gen); + }); +} + /// Allocate and zero `n` physical pages, returning the GVA. fn alloc_pool(n: usize) -> u64 { let gpa = unsafe { alloc_phys_pages(n as u64) }; let ptr = phys_to_virt(gpa).expect("failed to map pool pages"); - let size = n as usize * PAGE_SIZE_USIZE; + let size = n * PAGE_SIZE_USIZE; unsafe { core::ptr::write_bytes(ptr, 0, size) }; ptr as u64 } diff --git a/src/hyperlight_host/src/mem/mgr.rs b/src/hyperlight_host/src/mem/mgr.rs index a26aef4d4..798837bc1 100644 --- a/src/hyperlight_host/src/mem/mgr.rs +++ b/src/hyperlight_host/src/mem/mgr.rs @@ -530,9 +530,15 @@ impl SandboxMemoryManager { self.snapshot_count = snapshot.snapshot_generation(); self.update_scratch_bookkeeping()?; + + // Place the H2G pool at first_free so the bump allocator starts right after it. + // Guest reads this GVA from scratch-top during reset(). 
+ let h2g_pool_gva = self.place_h2g_pool_at_first_free()?; + self.init_g2h_consumer()?; self.init_h2g_consumer()?; - self.restore_h2g_prefill()?; + self.restore_h2g_prefill(h2g_pool_gva)?; + Ok((gsnapshot, gscratch)) } @@ -912,22 +918,39 @@ impl SandboxMemoryManager { Ok(()) } + /// Place the H2G pool at `first_free` during snapshot restore. + /// + /// Writes the pool GVA to scratch-top and advances the bump + /// allocator past the pool so COW page-fault resolution cannot + /// alias pool memory. Returns the computed pool GVA for use by + /// [`restore_h2g_prefill`]. + fn place_h2g_pool_at_first_free(&mut self) -> Result { + use hyperlight_common::layout::*; + + let scratch_size = self.scratch_mem.mem_size(); + let first_free = self.layout.get_first_free_scratch_gpa(); + let base_gpa = scratch_base_gpa(scratch_size); + let base_gva = scratch_base_gva(scratch_size); + let h2g_pool_gva = base_gva + (first_free - base_gpa); + let h2g_pages = self.layout.sandbox_memory_config.get_h2g_pool_pages() as u64; + + self.update_scratch_bookkeeping_item(SCRATCH_TOP_H2G_POOL_GVA_OFFSET, h2g_pool_gva)?; + let allocator = first_free + h2g_pages * PAGE_SIZE_USIZE as u64; + self.update_scratch_bookkeeping_item(SCRATCH_TOP_ALLOCATOR_OFFSET, allocator)?; + + Ok(h2g_pool_gva) + } + /// Prefill the H2G ring with writable descriptors after snapshot restore. /// /// Uses a temporary `RingProducer` to write descriptors into the H2G ring /// so the host consumer can poll them. The guest's `restore_from_ring` /// will later reconstruct its inflight state from these descriptors. 
- pub(crate) fn restore_h2g_prefill(&mut self) -> Result<()> { - let pool_gva = match self.h2g_pool_gva { - Some(gva) => gva, - None => return Ok(()), - }; - + fn restore_h2g_prefill(&mut self, pool_gva: u64) -> Result<()> { let layout = self.h2g_virtq_layout()?; let mem_ops = self.host_mem_ops(); let h2g_depth = self.layout.sandbox_memory_config.get_h2g_queue_depth(); - // Pool size from config let slot_size = PAGE_SIZE_USIZE; let pool_size = self.layout.sandbox_memory_config.get_h2g_pool_pages() * PAGE_SIZE_USIZE; let slot_count = pool_size / slot_size; @@ -935,10 +958,11 @@ impl SandboxMemoryManager { let mut producer = virtq::RingProducer::new(layout, mem_ops); let prefill_count = core::cmp::min(slot_count, h2g_depth); - // Write descriptors in reverse order to match the guest's LIFO - // allocation pattern (RecyclePool::alloc pops from the end of - // the free list, so the first prefill gets the highest address). - for i in (0..prefill_count).rev() { + // Write descriptors in forward order. The guest calls + // restore_from_ring which reconstructs used-descriptor addresses + // as base + position * slot_size, so the iteration order must + // match this formula. + for i in 0..prefill_count { let addr = pool_gva + (i * slot_size) as u64; producer .submit_one(addr, slot_size as u32, true) @@ -1149,4 +1173,89 @@ mod tests { verify_page_tables(name, config); } } + + /// Verify that the H2G pool placed at `first_free` during restore + /// does not overlap with the bump allocator range or the + /// scratch-top metadata region. + /// + /// This guards against the COW-pool GPA overlap bug: if the bump + /// allocator could return GPAs inside the pool region, a COW + /// page-fault would overwrite pool buffer data with stale shared + /// memory content, corrupting virtqueue communication. 
+ fn verify_pool_allocator_no_collision(name: &str, config: SandboxConfiguration) { + let path = simple_guest_as_string().expect("failed to get simple guest path"); + let snapshot = Snapshot::from_env(GuestBinary::FilePath(path), config) + .unwrap_or_else(|e| panic!("{name}: failed to create snapshot: {e}")); + + let layout = snapshot.layout(); + let scratch_size = layout.get_scratch_size(); + let first_free = layout.get_first_free_scratch_gpa(); + let h2g_pages = layout.sandbox_memory_config.get_h2g_pool_pages(); + let scratch_base = hyperlight_common::layout::scratch_base_gpa(scratch_size); + + let pool_start = first_free; + let pool_end = first_free + (h2g_pages * PAGE_TABLE_SIZE) as u64; + let allocator_start = pool_end; + + // The metadata region lives at the very top of scratch. + // SCRATCH_TOP_EXN_STACK_OFFSET (0x50) is the highest offset. + // Two pages are reserved at the top for exception stack and metadata. + let scratch_end = scratch_base + scratch_size as u64; + let metadata_start = scratch_end - 2 * PAGE_TABLE_SIZE as u64; + + assert!( + pool_start >= scratch_base, + "{name}: pool starts before scratch (pool=0x{pool_start:x}, scratch=0x{scratch_base:x})" + ); + + assert!( + pool_end <= metadata_start, + "{name}: pool overlaps metadata (pool_end=0x{pool_end:x}, metadata=0x{metadata_start:x})" + ); + + assert_eq!( + allocator_start, pool_end, + "{name}: allocator should start immediately after pool" + ); + + assert!( + allocator_start < metadata_start, + "{name}: no room for COW allocations (allocator=0x{allocator_start:x}, metadata=0x{metadata_start:x})" + ); + } + + #[test] + fn test_pool_allocator_no_collision() { + let test_cases: Vec<(&str, SandboxConfiguration)> = vec![ + ("default", SandboxConfiguration::default()), + ("large pools", { + let mut cfg = SandboxConfiguration::default(); + cfg.set_h2g_pool_pages(16); + cfg.set_g2h_pool_pages(16); + cfg + }), + ("minimal scratch", { + let mut cfg = SandboxConfiguration::default(); + 
cfg.set_scratch_size(0x20000);
+                cfg
+            }),
+            ("large scratch", {
+                let mut cfg = SandboxConfiguration::default();
+                cfg.set_scratch_size(0x100000);
+                cfg
+            }),
+            ("large heap + large pools", {
+                let mut cfg = SandboxConfiguration::default();
+                cfg.set_heap_size(LARGE_HEAP_SIZE);
+                cfg.set_scratch_size(0x100000);
+                cfg.set_h2g_pool_pages(32);
+                cfg.set_g2h_pool_pages(32);
+                cfg
+            }),
+        ];
+
+        for (name, config) in test_cases {
+            verify_pool_allocator_no_collision(name, config);
+        }
+    }
 }
diff --git a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs
index cedc54659..ab5f33f0e 100644
--- a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs
+++ b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs
@@ -1194,6 +1194,220 @@ mod tests {
         assert_eq!(res, 0);
     }
 
+    /// Many snapshot restore cycles with state-modifying guest calls.
+    #[test]
+    fn restore_stress_no_pool_corruption() {
+        let mut sbox: MultiUseSandbox = {
+            let path = simple_guest_as_string().unwrap();
+            let u_sbox = UninitializedSandbox::new(GuestBinary::FilePath(path), None).unwrap();
+            u_sbox.evolve()
+        }
+        .unwrap();
+
+        let snapshot = sbox.snapshot().unwrap();
+
+        for _ in 0..50 {
+            sbox.restore(snapshot.clone()).unwrap();
+            let _ = sbox.call::<i32>("AddToStatic", 1i32).unwrap();
+
+            let res: i32 = sbox.call("GetStatic", ()).unwrap();
+            assert_eq!(res, 1);
+
+            let res: i32 = sbox.call("AddToStatic", 2i32).unwrap();
+            assert_eq!(res, 3);
+        }
+    }
+
+    /// Stress test: snapshot/restore with G2H queue pressure.
+ #[test] + fn restore_stress_with_host_calls() { + let mut sbox: MultiUseSandbox = { + let path = simple_guest_as_string().unwrap(); + let u_sbox = UninitializedSandbox::new(GuestBinary::FilePath(path), None).unwrap(); + u_sbox.evolve() + } + .unwrap(); + + let snapshot = sbox.snapshot().unwrap(); + + for i in 0..50 { + sbox.restore(snapshot.clone()).unwrap(); + + // Fire-and-forget log oneshots - multiple G2H entries queued + // without waiting for responses + sbox.call::<()>("LogMessageN", 5_i32).unwrap(); + + // G2H round-trip with returned data after logs filled the queue + let echo: String = sbox.call("Echo", "ping".to_string()).unwrap(); + assert_eq!(echo, "ping"); + + // Multiple calls without restore to exercise queue reuse + let res: i32 = sbox.call("AddToStatic", 1i32).unwrap(); + assert_eq!(res, 1); + + let echo2: String = sbox.call("Echo", format!("echo {i}")).unwrap(); + assert_eq!(echo2, format!("echo {i}")); + } + } + + /// Back-to-back restores without any guest call in between. + /// The generation bumps twice but the guest only sees the latest value. 
+    #[test]
+    fn restore_back_to_back() {
+        let mut sbox: MultiUseSandbox = {
+            let path = simple_guest_as_string().unwrap();
+            let u_sbox = UninitializedSandbox::new(GuestBinary::FilePath(path), None).unwrap();
+            u_sbox.evolve()
+        }
+        .unwrap();
+
+        let _ = sbox.call::<i32>("AddToStatic", 42i32).unwrap();
+        let snapshot = sbox.snapshot().unwrap();
+
+        // Two restores in a row, no guest calls between them
+        sbox.restore(snapshot.clone()).unwrap();
+        sbox.restore(snapshot.clone()).unwrap();
+
+        // Guest should see the snapshot state (static = 42)
+        let res: i32 = sbox.call("GetStatic", ()).unwrap();
+        assert_eq!(res, 42);
+
+        // Another round: three restores, then call
+        sbox.restore(snapshot.clone()).unwrap();
+        sbox.restore(snapshot.clone()).unwrap();
+        sbox.restore(snapshot.clone()).unwrap();
+
+        let res: i32 = sbox.call("AddToStatic", 1i32).unwrap();
+        assert_eq!(res, 43);
+    }
+
+    /// Restore after flooding the G2H queue with log oneshots.
+    #[test]
+    fn restore_after_g2h_pressure() {
+        let mut sbox: MultiUseSandbox = {
+            let path = simple_guest_as_string().unwrap();
+            let u_sbox = UninitializedSandbox::new(GuestBinary::FilePath(path), None).unwrap();
+            u_sbox.evolve()
+        }
+        .unwrap();
+
+        let snapshot = sbox.snapshot().unwrap();
+
+        for _ in 0..20 {
+            // Flood G2H with many log oneshots to pressure the queue/pool
+            sbox.call::<()>("LogMessageN", 30_i32).unwrap();
+
+            // Restore after heavy G2H usage
+            sbox.restore(snapshot.clone()).unwrap();
+
+            // Verify queue works cleanly after restore
+            sbox.call::<()>("LogMessageN", 5_i32).unwrap();
+            let echo: String = sbox.call("Echo", "ok".to_string()).unwrap();
+            assert_eq!(echo, "ok");
+        }
+    }
+
+    /// Many calls cycling through all descriptor IDs, then restore.
+    /// Ensures restore handles post-wraparound ring state.
+    #[test]
+    fn restore_after_id_wraparound() {
+        let mut sbox: MultiUseSandbox = {
+            let path = simple_guest_as_string().unwrap();
+            let u_sbox = UninitializedSandbox::new(GuestBinary::FilePath(path), None).unwrap();
+            u_sbox.evolve()
+        }
+        .unwrap();
+
+        let snapshot = sbox.snapshot().unwrap();
+
+        for i in 0..200 {
+            let res: i32 = sbox.call("AddToStatic", 1i32).unwrap();
+            assert_eq!(res, i + 1);
+        }
+
+        // Restore after IDs have wrapped around many times
+        sbox.restore(snapshot.clone()).unwrap();
+
+        let res: i32 = sbox.call("GetStatic", ()).unwrap();
+        assert_eq!(res, 0);
+
+        // Do another round of wraparound + restore
+        for _ in 0..200 {
+            let _ = sbox.call::<i32>("AddToStatic", 1i32).unwrap();
+        }
+        sbox.restore(snapshot.clone()).unwrap();
+
+        let echo: String = sbox.call("Echo", "after wraparound".to_string()).unwrap();
+        assert_eq!(echo, "after wraparound");
+    }
+
+    /// Restore after a guest exception recovers the sandbox.
+    /// The virtqueue must be fully functional after restore despite
+    /// the guest having been in a broken state.
+ #[test] + fn restore_after_guest_error() { + let mut sbox: MultiUseSandbox = { + let path = simple_guest_as_string().unwrap(); + let u_sbox = UninitializedSandbox::new(GuestBinary::FilePath(path), None).unwrap(); + u_sbox.evolve() + } + .unwrap(); + + let snapshot = sbox.snapshot().unwrap(); + + // Normal call first + let res: i32 = sbox.call("AddToStatic", 5i32).unwrap(); + assert_eq!(res, 5); + + // Trigger an exception - guest is now in a broken state + let err = sbox.call::<()>("TriggerException", ()); + assert!(err.is_err()); + + // Restore should recover fully + sbox.restore(snapshot.clone()).unwrap(); + + // Verify everything works after recovery + let res: i32 = sbox.call("GetStatic", ()).unwrap(); + assert_eq!(res, 0); + + let echo: String = sbox.call("Echo", "recovered".to_string()).unwrap(); + assert_eq!(echo, "recovered"); + + sbox.call::<()>("LogMessageN", 5_i32).unwrap(); + let res: i32 = sbox.call("AddToStatic", 1i32).unwrap(); + assert_eq!(res, 1); + } + + /// Snapshot immediately after evolve, restore before any calls. + /// Baseline test: the virtqueue has never been used. 
+ #[test] + fn restore_fresh_snapshot() { + let mut sbox: MultiUseSandbox = { + let path = simple_guest_as_string().unwrap(); + let u_sbox = UninitializedSandbox::new(GuestBinary::FilePath(path), None).unwrap(); + u_sbox.evolve() + } + .unwrap(); + + // Snapshot immediately - no guest calls yet + let snapshot = sbox.snapshot().unwrap(); + + sbox.restore(snapshot.clone()).unwrap(); + + // First-ever guest call after restore + let res: i32 = sbox.call("GetStatic", ()).unwrap(); + assert_eq!(res, 0); + + let echo: String = sbox.call("Echo", "first".to_string()).unwrap(); + assert_eq!(echo, "first"); + + // Restore again and verify + sbox.restore(snapshot.clone()).unwrap(); + sbox.call::<()>("LogMessageN", 10_i32).unwrap(); + let res: i32 = sbox.call("AddToStatic", 7i32).unwrap(); + assert_eq!(res, 7); + } + #[test] fn test_trigger_exception_on_guest() { let usbox = UninitializedSandbox::new( diff --git a/src/hyperlight_host/src/sandbox/uninitialized_evolve.rs b/src/hyperlight_host/src/sandbox/uninitialized_evolve.rs index 32ec1a2b6..7f0cc1c0d 100644 --- a/src/hyperlight_host/src/sandbox/uninitialized_evolve.rs +++ b/src/hyperlight_host/src/sandbox/uninitialized_evolve.rs @@ -16,7 +16,6 @@ limitations under the License. 
#[cfg(gdb)] use std::sync::{Arc, Mutex}; -use hyperlight_common::layout::SCRATCH_TOP_H2G_POOL_GVA_OFFSET; use rand::RngExt; use tracing::{Span, instrument}; @@ -27,7 +26,7 @@ use crate::hypervisor::hyperlight_vm::{HyperlightVm, HyperlightVmError}; use crate::mem::exe::LoadInfo; use crate::mem::mgr::SandboxMemoryManager; use crate::mem::ptr::RawPtr; -use crate::mem::shared_mem::{GuestSharedMemory, SharedMemory}; +use crate::mem::shared_mem::GuestSharedMemory; #[cfg(gdb)] use crate::sandbox::config::DebugInfo; #[cfg(feature = "mem_profile")] @@ -132,18 +131,6 @@ pub(super) fn evolve_impl_multi_use(u_sbox: UninitializedSandbox) -> Result(offset) - && gva != 0 - { - hshm.h2g_pool_gva = Some(gva); - } - } - #[cfg(gdb)] let dbg_mem_wrapper = Arc::new(Mutex::new(hshm.clone())); From 89e144ed2d1cb1e4cb3be6bc2b8cf9b98a5d681a Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Thu, 9 Apr 2026 15:53:20 +0200 Subject: [PATCH 15/31] feat(virtq): add support for multi-descriptor payloads Signed-off-by: Tomasz Andrzejak --- src/hyperlight_common/src/virtq/consumer.rs | 5 +- src/hyperlight_common/src/virtq/mod.rs | 56 ++++++++- src/hyperlight_common/src/virtq/msg.rs | 32 ++++- src/hyperlight_common/src/virtq/pool.rs | 45 ++++++- src/hyperlight_common/src/virtq/producer.rs | 45 ++++--- src/hyperlight_guest/src/virtq/context.rs | 118 +++++++++++------- src/hyperlight_host/src/mem/mgr.rs | 87 +++++++++---- src/hyperlight_host/src/sandbox/outb.rs | 4 +- .../tests/sandbox_host_tests.rs | 76 +++++++++++ 9 files changed, 366 insertions(+), 102 deletions(-) diff --git a/src/hyperlight_common/src/virtq/consumer.rs b/src/hyperlight_common/src/virtq/consumer.rs index b29f7694a..fb11c778e 100644 --- a/src/hyperlight_common/src/virtq/consumer.rs +++ b/src/hyperlight_common/src/virtq/consumer.rs @@ -255,6 +255,7 @@ pub struct VirtqConsumer { inner: RingConsumer, notifier: N, inflight: FixedBitSet, + next_token: u32, } impl VirtqConsumer { @@ -273,6 +274,7 @@ impl VirtqConsumer { inner, 
notifier, inflight, + next_token: 0, } } @@ -323,7 +325,8 @@ impl VirtqConsumer { } self.inflight.insert(id_idx); - let token = Token(id); + let token = Token(self.next_token, id); + self.next_token = self.next_token.wrapping_add(1); // Copy entry data from shared memory let data = entry_elem diff --git a/src/hyperlight_common/src/virtq/mod.rs b/src/hyperlight_common/src/virtq/mod.rs index 7d2aa7f0a..6e8ad82a5 100644 --- a/src/hyperlight_common/src/virtq/mod.rs +++ b/src/hyperlight_common/src/virtq/mod.rs @@ -368,11 +368,11 @@ pub enum SuppressionKind { /// A token representing a sent entry in the virtqueue. /// -/// Tokens uniquely identify in-flight requests and are used to correlate -/// requests with their responses. The token value corresponds to the -/// descriptor ID in the underlying ring. +/// Tokens uniquely identify in-flight requests and are used to correlate requests with their responses. +/// The first element is a monotonically increasing generation counter. The second element is the +/// underlying descriptor ID #[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub struct Token(pub u16); +pub struct Token(pub u32, pub u16); impl From for Allocation { fn from(value: BufferElement) -> Self { @@ -1007,6 +1007,54 @@ mod tests { assert_eq!(cqe2.token, tok_rw); assert_eq!(&cqe2.data[..], b"reply"); } + + /// Regression test: reclaim + submit must not cause token collisions. + /// + /// Before the monotonic generation counter, Token wrapped the descriptor + /// ID which gets recycled. This caused stale pending completions to + /// match newly submitted entries with the same recycled descriptor ID. 
+ #[test] + fn test_reclaim_submit_no_token_collision() { + let ring = make_ring(8); + let (mut producer, mut consumer, _) = make_test_producer(&ring); + + // Submit and complete a ReadOnly entry + let tok_old = send_readonly(&mut producer, b"log"); + + let (_, c) = consumer.poll(1024).unwrap().unwrap(); + consumer.complete(c).unwrap(); + + // Reclaim pushes the completion to pending (token = tok_old) + let count = producer.reclaim().unwrap(); + assert_eq!(count, 1); + + // Submit a new ReadWrite entry - may reuse the same descriptor ID + let tok_new = send_readwrite(&mut producer, b"call", 64); + + // Tokens must differ even if the descriptor ID was recycled + assert_ne!( + tok_old, tok_new, + "tokens must be unique across reclaim/submit cycles" + ); + + // Complete the ReadWrite entry + let (_, c) = consumer.poll(1024).unwrap().unwrap(); + let SendCompletion::Writable(mut wc) = c else { + panic!("expected writable"); + }; + wc.write_all(b"result").unwrap(); + consumer.complete(wc.into()).unwrap(); + + // Poll should return the stale ReadOnly completion first (wrong token) + let cqe1 = producer.poll().unwrap().unwrap(); + assert_eq!(cqe1.token, tok_old); + assert!(cqe1.data.is_empty()); + + // Then the new ReadWrite completion (matching token) + let cqe2 = producer.poll().unwrap().unwrap(); + assert_eq!(cqe2.token, tok_new); + assert_eq!(&cqe2.data[..], b"result"); + } } #[cfg(all(test, loom))] mod fuzz { diff --git a/src/hyperlight_common/src/virtq/msg.rs b/src/hyperlight_common/src/virtq/msg.rs index ade59643b..090c2eb5b 100644 --- a/src/hyperlight_common/src/virtq/msg.rs +++ b/src/hyperlight_common/src/virtq/msg.rs @@ -20,6 +20,8 @@ limitations under the License. //! fixed 8-byte header, enabling message type discrimination and //! request/response correlation. +use bitflags::bitflags; + /// Message types for the virtqueue wire protocol. #[repr(u8)] #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -54,24 +56,33 @@ impl TryFrom for MsgKind { } } +bitflags! 
{ + #[repr(transparent)] + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] + pub struct MsgFlags: u8 { + /// More descriptors follow for this message. + const MORE = 1 << 0; + } +} + /// Wire header for all virtqueue messages #[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)] #[repr(C)] pub struct VirtqMsgHeader { /// Discriminates the message type. pub kind: u8, - /// Per-type flags TODO(ring): add flags type. + /// Per-message flags (see [`MsgFlags`]). pub flags: u8, /// Caller-assigned correlation ID. Responses echo the request's ID. pub req_id: u16, - /// Byte length of the payload following this header. + /// Byte length of the payload following this header in this descriptor. pub payload_len: u32, } impl VirtqMsgHeader { pub const SIZE: usize = core::mem::size_of::(); - /// Create a new message header. + /// Create a new message header with no flags set. pub const fn new(kind: MsgKind, req_id: u16, payload_len: u32) -> Self { Self { kind: kind as u8, @@ -82,10 +93,10 @@ impl VirtqMsgHeader { } /// Create a new header with flags. - pub const fn with_flags(kind: MsgKind, flags: u8, req_id: u16, payload_len: u32) -> Self { + pub const fn with_flags(kind: MsgKind, flags: MsgFlags, req_id: u16, payload_len: u32) -> Self { Self { kind: kind as u8, - flags, + flags: flags.bits(), req_id, payload_len, } @@ -95,4 +106,15 @@ impl VirtqMsgHeader { pub fn msg_kind(&self) -> Result { MsgKind::try_from(self.kind) } + + /// Interpret the raw flags field as [`MsgFlags`]. + pub fn msg_flags(&self) -> MsgFlags { + MsgFlags::from_bits_truncate(self.flags) + } + + /// Returns true if [`MsgFlags::MORE`] is set, indicating more + /// descriptors follow for this message. 
+    pub const fn has_more(&self) -> bool {
+        self.flags & MsgFlags::MORE.bits() != 0
+    }
 }
diff --git a/src/hyperlight_common/src/virtq/pool.rs b/src/hyperlight_common/src/virtq/pool.rs
index bbae4ff41..d60d432bf 100644
--- a/src/hyperlight_common/src/virtq/pool.rs
+++ b/src/hyperlight_common/src/virtq/pool.rs
@@ -150,10 +150,10 @@ impl Slab {
         }
 
         // Fallback to full search
+        let total = self.used_slots.len();
         self.used_slots.zeroes().find(|&next_free| {
-            self.used_slots
-                .count_zeroes(next_free..next_free + slots_num)
-                == slots_num
+            let end = next_free + slots_num;
+            end <= total && self.used_slots.count_zeroes(next_free..end) == slots_num
         })
     }
@@ -416,6 +416,11 @@ impl BufferPool {
             inner: SyncWrap(Rc::new(RefCell::new(inner))),
         })
     }
+
+    /// Upper slab slot size in bytes.
+    pub const fn upper_slot_size() -> usize {
+        U
+    }
 }
 
 #[cfg(all(test, loom))]
@@ -821,6 +826,40 @@ mod tests {
         assert!(matches!(result, Err(AllocError::InvalidFree(_, _))));
     }
 
+    #[test]
+    fn test_slab_multi_slot_alloc_near_end() {
+        let mut slab = make_slab::<256>(1792); // 7 slots
+        let a0 = slab.alloc(256).unwrap();
+        let a1 = slab.alloc(256).unwrap();
+        let _a2 = slab.alloc(256).unwrap();
+        let _a3 = slab.alloc(256).unwrap();
+        let _a4 = slab.alloc(256).unwrap();
+        let _a5 = slab.alloc(256).unwrap();
+        let _a6 = slab.alloc(256).unwrap();
+
+        slab.dealloc(a0).unwrap();
+        slab.dealloc(a1).unwrap();
+
+        // Freeing a0 and a1 leaves a contiguous 2-slot run at indices 0..2;
+        // the multi-slot search must find it, and must not probe past the
+        // end of the bitset when checking candidate runs near the last slot.
+        let run = slab.alloc(300).unwrap(); // needs 2 slots
+        assert_eq!(run.len, 512);
+    }
+
+    #[test]
+    fn test_slab_multi_slot_alloc_no_room_at_end() {
+        // Only the last slot is free but a 2-slot run is requested.
+        // find_slots must not panic when checking beyond the bitset.
+ let mut slab = make_slab::<256>(1792); // 7 slots + let allocs: Vec<_> = (0..7).map(|_| slab.alloc(256).unwrap()).collect(); + // Free only the last slot (index 6) + slab.dealloc(allocs[6]).unwrap(); + + let result = slab.alloc(300); // needs 2 slots, only 1 free + assert!(matches!(result, Err(AllocError::NoSpace))); + } + #[test] fn test_slab_free_invalid_address() { let mut slab = make_slab::<256>(1024); diff --git a/src/hyperlight_common/src/virtq/producer.rs b/src/hyperlight_common/src/virtq/producer.rs index b892bdf25..eda4237d4 100644 --- a/src/hyperlight_common/src/virtq/producer.rs +++ b/src/hyperlight_common/src/virtq/producer.rs @@ -125,7 +125,8 @@ pub struct VirtqProducer { inner: RingProducer, notifier: N, pool: P, - inflight: Vec>, + next_token: u32, + inflight: Vec>, pending: VecDeque, } @@ -152,6 +153,7 @@ where pool, notifier, inflight, + next_token: 0, pending: VecDeque::new(), } } @@ -218,13 +220,20 @@ where }; let id = used.id as usize; - let inf = self + let (token, inf) = self .inflight .get_mut(id) .ok_or(VirtqError::InvalidState)? .take() .ok_or(VirtqError::InvalidState)?; + // the token's descriptor ID must match the ring's + debug_assert_eq!( + token.1, used.id, + "ring returned desc_id={} but inflight slot {} has token with desc_id={}", + used.id, id, token.1, + ); + let written = used.len as usize; // Free entry buffers (request data no longer needed) @@ -250,10 +259,7 @@ where None => Bytes::new(), }; - Ok(Some(RecvCompletion { - token: Token(used.id), - data, - })) + Ok(Some(RecvCompletion { token, data })) } /// Drain all available completions, calling the provided closure for each. 
@@ -310,6 +316,9 @@ where let chain = inflight.try_into_chain(written)?; let id = self.inner.submit_available(&chain)?; + let token = Token(self.next_token, id); + self.next_token = self.next_token.wrapping_add(1); + let slot = self .inflight .get_mut(id as usize) @@ -319,7 +328,7 @@ where return Err(VirtqError::InvalidState); } - *slot = Some(inflight); + *slot = Some((token, inflight)); let should_notify = self.inner.should_notify_since(cursor_before)?; @@ -336,7 +345,7 @@ where }); } - Ok(Token(id)) + Ok(token) } /// Signal backpressure to the consumer. @@ -474,12 +483,18 @@ where .slot_addr(pos as usize) .ok_or(VirtqError::InvalidState)?; - self.inflight[id as usize] = Some(Inflight::WriteOnly { - completion: Allocation { - addr, - len: slot_size, + let token = Token(self.next_token, id); + self.next_token = self.next_token.wrapping_add(1); + + self.inflight[id as usize] = Some(( + token, + Inflight::WriteOnly { + completion: Allocation { + addr, + len: slot_size, + }, }, - }); + )); ids.push(id); } @@ -869,7 +884,7 @@ mod tests { // Ring should still be fully usable let se = producer.chain().entry(64).completion(128).build().unwrap(); let tok = producer.submit(se).unwrap(); - assert!(tok.0 < 16); + assert!(tok.1 < 16); } #[test] @@ -885,7 +900,7 @@ mod tests { // Ring should still be fully usable let se = producer.chain().entry(64).completion(128).build().unwrap(); let tok = producer.submit(se).unwrap(); - assert!(tok.0 < 16); + assert!(tok.1 < 16); } #[test] diff --git a/src/hyperlight_guest/src/virtq/context.rs b/src/hyperlight_guest/src/virtq/context.rs index 98fbe32e6..c3ba6adad 100644 --- a/src/hyperlight_guest/src/virtq/context.rs +++ b/src/hyperlight_guest/src/virtq/context.rs @@ -39,7 +39,6 @@ use crate::bail; use crate::error::Result; static REQUEST_ID: AtomicU16 = AtomicU16::new(0); -const MAX_RESPONSE_CAP: usize = 4096; /// Guest-side notifier that triggers a VM exit via outb. 
#[derive(Clone, Copy)] @@ -69,9 +68,18 @@ pub struct QueueConfig { /// Virtqueue runtime state for guest-host communication. pub struct GuestContext { + /// guest-to-host driver g2h_producer: G2hProducer, + /// host-to-guest driver h2g_producer: H2gProducer, + /// Max writable bytes the host can write into a G2H completion. + /// Derived from the G2H pool upper slab slot size. + g2h_response_cap: usize, + /// H2G slot size in bytes (each prefilled writable descriptor). + h2g_slot_size: usize, + /// snapshot generation counter generation: u64, + /// used by cabi last_host_result: Option>, } @@ -81,30 +89,27 @@ impl GuestContext { let size = g2h.pool_pages * PAGE_SIZE_USIZE; let g2h_pool = BufferPool::new(g2h.pool_gva, size).expect("failed to create G2H buffer pool"); + let g2h_response_cap = BufferPool::<256, 4096>::upper_slot_size(); let g2h_producer = VirtqProducer::new(g2h.layout, GuestMemOps, GuestNotifier, g2h_pool.clone()); - // Each H2G prefill entry is a single descriptor with one contiguous buffer: one - // fixed-size buffer per descriptor, large payloads split across multiple independent - // completions. - // - // TODO(virtq): consider smaller slot_size (e.g. pool_size / desc_count) to maximize - // prefilled entries for host-side call batching. 
let size = h2g.pool_pages * PAGE_SIZE_USIZE; - let slot = PAGE_SIZE_USIZE; - let h2g_pool = - RecyclePool::new(h2g.pool_gva, size, slot).expect("failed to create H2G recycle pool"); + let h2g_slot_size = PAGE_SIZE_USIZE; + let h2g_pool = RecyclePool::new(h2g.pool_gva, size, h2g_slot_size) + .expect("failed to create H2G recycle pool"); let h2g_producer = VirtqProducer::new(h2g.layout, GuestMemOps, GuestNotifier, h2g_pool.clone()); let mut ctx = Self { g2h_producer, h2g_producer, + g2h_response_cap, + h2g_slot_size, generation, last_host_result: None, }; - ctx.prefill_h2g(); + ctx.prefill_h2g().expect("H2G initial prefill failed"); ctx } @@ -164,8 +169,8 @@ impl GuestContext { }; let result_bytes = &completion.data; - if result_bytes.len() > MAX_RESPONSE_CAP { - bail!("G2H: response is too large"); + if result_bytes.len() < VirtqMsgHeader::SIZE { + bail!("G2H: response too short for header"); } let payload_bytes = &result_bytes[VirtqMsgHeader::SIZE..]; @@ -182,12 +187,16 @@ impl GuestContext { } /// Receive a host-to-guest function call from the H2G queue. + /// + /// Each descriptor carries a [`VirtqMsgHeader`] with `payload_len` for + /// that chunk. If [`MsgFlags::MORE`](hyperlight_common::virtq::msg::MsgFlags::MORE) + /// is set, more descriptors follow. pub fn recv_h2g_call(&mut self) -> Result { - let Some(completion) = self.h2g_producer.poll()? else { + let Some(first) = self.h2g_producer.poll()? 
else { bail!("H2G: no pending call"); }; - let data = &completion.data; + let data = &first.data; if data.len() < VirtqMsgHeader::SIZE { bail!("H2G: completion too short for header"); } @@ -198,39 +207,52 @@ impl GuestContext { bail!("H2G: unexpected message kind: 0x{:02x}", hdr.kind); } - let payload_end = VirtqMsgHeader::SIZE + hdr.payload_len as usize; - if payload_end > data.len() { - bail!("H2G: payload length exceeds completion data"); + let chunk_len = hdr.payload_len as usize; + + if !hdr.has_more() { + // Single-descriptor fast path + let payload = &data[VirtqMsgHeader::SIZE..VirtqMsgHeader::SIZE + chunk_len]; + let fc = FunctionCall::try_from(payload)?; + return Ok(fc); } - let payload = &data[VirtqMsgHeader::SIZE..payload_end]; - let fc = FunctionCall::try_from(payload)?; + // Multi-descriptor: accumulate payload until MsgFlags::MORE is cleared + let mut assembled = Vec::with_capacity(chunk_len * 2); + assembled.extend_from_slice(&data[VirtqMsgHeader::SIZE..VirtqMsgHeader::SIZE + chunk_len]); + + loop { + let Some(next) = self.h2g_producer.poll()? else { + bail!("H2G: expected continuation descriptor, none available"); + }; + + let next_data = &next.data; + if next_data.len() < VirtqMsgHeader::SIZE { + bail!("H2G: continuation too short for header"); + } + + let next_hdr: &VirtqMsgHeader = + bytemuck::from_bytes(&next_data[..VirtqMsgHeader::SIZE]); + + let next_chunk = next_hdr.payload_len as usize; + + assembled.extend_from_slice( + &next_data[VirtqMsgHeader::SIZE..VirtqMsgHeader::SIZE + next_chunk], + ); + + if !next_hdr.has_more() { + break; + } + } + + let fc = FunctionCall::try_from(assembled.as_slice())?; Ok(fc) } /// Send the result of a host-to-guest call back to the host via the - /// G2H queue, then refill one H2G descriptor slot. + /// G2H queue, then refill H2G descriptor slots until the ring is full. 
pub fn send_h2g_result(&mut self, payload: &[u8]) -> Result<()> { self.send_g2h_oneshot(MsgKind::Response, payload)?; - - // Best-effort refill of one H2G slot. Backpressure is expected - // (pool/ring may be full), other errors are propagated. - match self - .h2g_producer - .chain() - .completion(PAGE_SIZE_USIZE) - .build() - { - Ok(e) => match self.h2g_producer.submit(e) { - Ok(_) => {} - Err(virtq::VirtqError::Backpressure) => {} - Err(e) => bail!("H2G refill submit: {e}"), - }, - Err(virtq::VirtqError::Backpressure) => {} - Err(e) => bail!("H2G refill build: {e}"), - } - - Ok(()) + self.prefill_h2g() } /// Restore the H2G producer after snapshot restore. @@ -239,7 +261,7 @@ impl GuestContext { /// [`restore_from_ring`] to reconstruct inflight state /// from the host's prefilled descriptors. pub fn restore_h2g(&mut self, pool_gva: u64, pool_size: usize) { - let pool = RecyclePool::new(pool_gva, pool_size, PAGE_SIZE_USIZE) + let pool = RecyclePool::new(pool_gva, pool_size, self.h2g_slot_size) .expect("H2G RecyclePool creation failed"); self.h2g_producer @@ -300,23 +322,23 @@ impl GuestContext { /// Pre-fill the H2G queue with completion-only descriptors so the host /// can write incoming call payloads into them. 
- fn prefill_h2g(&mut self) { + fn prefill_h2g(&mut self) -> Result<()> { loop { let entry = match self .h2g_producer .chain() - .completion(PAGE_SIZE_USIZE) + .completion(self.h2g_slot_size) .build() { Ok(e) => e, - Err(virtq::VirtqError::Backpressure) => break, - Err(e) => panic!("H2G prefill build: {e}"), + Err(e) if e.is_transient() => return Ok(()), + Err(e) => bail!("H2G prefill build: {e}"), }; match self.h2g_producer.submit(entry) { Ok(_) => {} - Err(virtq::VirtqError::Backpressure) => break, - Err(e) => panic!("H2G prefill submit: {e}"), + Err(e) if e.is_transient() => return Ok(()), + Err(e) => bail!("H2G prefill submit: {e}"), } } } @@ -375,7 +397,7 @@ impl GuestContext { .g2h_producer .chain() .entry(entry_len) - .completion(MAX_RESPONSE_CAP) + .completion(self.g2h_response_cap) .build()?; entry.write_all(header)?; diff --git a/src/hyperlight_host/src/mem/mgr.rs b/src/hyperlight_host/src/mem/mgr.rs index 798837bc1..b1aa42d3d 100644 --- a/src/hyperlight_host/src/mem/mgr.rs +++ b/src/hyperlight_host/src/mem/mgr.rs @@ -19,7 +19,7 @@ use std::num::NonZeroU16; use hyperlight_common::flatbuffer_wrappers::function_types::FunctionCallResult; use hyperlight_common::mem::PAGE_SIZE_USIZE; -use hyperlight_common::virtq::msg::{MsgKind, VirtqMsgHeader}; +use hyperlight_common::virtq::msg::{MsgFlags, MsgKind, VirtqMsgHeader}; use hyperlight_common::virtq::{self, Layout as VirtqLayout}; use hyperlight_common::vmem::{self, PAGE_TABLE_SIZE}; #[cfg(all(feature = "crashdump", not(feature = "i686-guest")))] @@ -876,6 +876,20 @@ impl SandboxMemoryManager { HostMemOps::new(&self.scratch_mem, scratch_base_gva) } + /// Total G2H buffer pool size in bytes. + pub(crate) fn g2h_pool_size(&self) -> usize { + self.layout.sandbox_memory_config.get_g2h_pool_pages() * PAGE_SIZE_USIZE + } + + pub(crate) fn h2g_pool_size(&self) -> usize { + self.layout.sandbox_memory_config.get_h2g_pool_pages() * PAGE_SIZE_USIZE + } + + /// H2G slot size in bytes. 
Each prefilled writable descriptor has this capacity. + pub(crate) fn h2g_slot_size(&self) -> usize { + PAGE_SIZE_USIZE + } + /// Initialize the G2H virtqueue consumer. /// Must be called after scratch bookkeeping is written. pub(crate) fn init_g2h_consumer(&mut self) -> Result<()> { @@ -951,7 +965,7 @@ impl SandboxMemoryManager { let mem_ops = self.host_mem_ops(); let h2g_depth = self.layout.sandbox_memory_config.get_h2g_queue_depth(); - let slot_size = PAGE_SIZE_USIZE; + let slot_size = self.h2g_slot_size(); let pool_size = self.layout.sandbox_memory_config.get_h2g_pool_pages() * PAGE_SIZE_USIZE; let slot_count = pool_size / slot_size; @@ -974,39 +988,60 @@ impl SandboxMemoryManager { /// Write a guest function call into the H2G virtqueue. /// - /// Polls the H2G consumer for a prefilled entry from the guest, - /// writes `VirtqMsgHeader::Request` followed by `buffer` into the - /// writable completion, and completes the entry. + /// Large payloads that exceed a single slot are split across multiple descriptors. pub(crate) fn write_guest_function_call_virtq(&mut self, buffer: &[u8]) -> Result<()> { + let h2g_pool_size = self.h2g_pool_size(); + let consumer = self .h2g_consumer .as_mut() .ok_or_else(|| new_error!("H2G consumer not initialized"))?; - let (entry, completion) = consumer - .poll(8192) - .map_err(|e| new_error!("H2G poll: {:?}", e))? - .ok_or_else(|| new_error!("H2G: no prefilled entry available"))?; + let mut offset = 0usize; - // Consume the entry data - this should be empty - drop(entry); + loop { + let remaining = buffer.len() - offset; - let header = VirtqMsgHeader::new(MsgKind::Request, 0, buffer.len() as u32); + let (entry, completion) = consumer + .poll(h2g_pool_size) + .map_err(|e| new_error!("H2G poll: {:?}", e))? 
+ .ok_or_else(|| new_error!("H2G: no prefilled descriptor available"))?; - let virtq::SendCompletion::Writable(mut wc) = completion else { - return Err(new_error!( - "H2G: expected writable completion, got non-writable (ring corruption)" - )); - }; + drop(entry); - wc.write_all(bytemuck::bytes_of(&header)) - .map_err(|e| new_error!("H2G write header: {:?}", e))?; - wc.write_all(buffer) - .map_err(|e| new_error!("H2G write payload: {:?}", e))?; + let virtq::SendCompletion::Writable(mut wc) = completion else { + return Err(new_error!( + "H2G: expected writable completion (ring corruption)" + )); + }; - consumer - .complete(wc.into()) - .map_err(|e| new_error!("H2G complete: {:?}", e))?; + let data_cap = wc.capacity() - VirtqMsgHeader::SIZE; + let chunk_len = remaining.min(data_cap); + let has_more = offset + chunk_len < buffer.len(); + + let flags = if has_more { + MsgFlags::MORE + } else { + MsgFlags::empty() + }; + + let hdr = VirtqMsgHeader::with_flags(MsgKind::Request, flags, 0, chunk_len as u32); + + wc.write_all(bytemuck::bytes_of(&hdr)) + .map_err(|e| new_error!("H2G write header: {:?}", e))?; + wc.write_all(&buffer[offset..offset + chunk_len]) + .map_err(|e| new_error!("H2G write payload: {:?}", e))?; + + consumer + .complete(wc.into()) + .map_err(|e| new_error!("H2G complete: {:?}", e))?; + + offset += chunk_len; + + if !has_more { + break; + } + } Ok(()) } @@ -1015,6 +1050,8 @@ impl SandboxMemoryManager { /// /// The guest submitted the Response on G2H with pub(crate) fn read_h2g_result_from_g2h(&mut self) -> Result { + let g2h_pool_size = self.g2h_pool_size(); + let consumer = self .g2h_consumer .as_mut() @@ -1024,7 +1061,7 @@ impl SandboxMemoryManager { // find the Response that carries the H2G function call result. 
loop { let maybe_next = consumer - .poll(8192) + .poll(g2h_pool_size) .map_err(|e| new_error!("G2H poll for H2G result: {:?}", e))?; let Some((entry, completion)) = maybe_next else { diff --git a/src/hyperlight_host/src/sandbox/outb.rs b/src/hyperlight_host/src/sandbox/outb.rs index 3fa571db2..626236ba4 100644 --- a/src/hyperlight_host/src/sandbox/outb.rs +++ b/src/hyperlight_host/src/sandbox/outb.rs @@ -191,6 +191,8 @@ fn outb_virtq_call( mem_mgr: &mut SandboxMemoryManager, host_funcs: &Arc>, ) -> Result<(), HandleOutbError> { + let g2h_pool_size = mem_mgr.g2h_pool_size(); + let consumer = mem_mgr.g2h_consumer.as_mut().ok_or_else(|| { HandleOutbError::ReadHostFunctionCall("G2H consumer not initialized".into()) })?; @@ -198,7 +200,7 @@ fn outb_virtq_call( // Drain entries, processing Log messages, until we find a Request. let (entry, completion) = loop { let Some((entry, completion)) = consumer - .poll(8192) + .poll(g2h_pool_size) .map_err(|e| HandleOutbError::ReadHostFunctionCall(format!("G2H poll: {e}")))? else { // No G2H entry - backpressure-only notify or prefill notify. diff --git a/src/hyperlight_host/tests/sandbox_host_tests.rs b/src/hyperlight_host/tests/sandbox_host_tests.rs index c067722bb..8b1d435a3 100644 --- a/src/hyperlight_host/tests/sandbox_host_tests.rs +++ b/src/hyperlight_host/tests/sandbox_host_tests.rs @@ -654,3 +654,79 @@ fn virtq_large_payload_roundtrip() { assert!(res.iter().all(|&b| b == 0)); }); } + +#[test] +fn virtq_multi_descriptor_h2g_two_slots() { + // Payload exceeds a single H2G slot (4096 - header), requiring 2 descriptors. + let mut cfg = SandboxConfiguration::default(); + cfg.set_h2g_pool_pages(4); + with_rust_sandbox_cfg(cfg, |mut sandbox| { + let large_msg: String = "A".repeat(4200); + let res: String = sandbox.call("Echo", large_msg.clone()).unwrap(); + assert_eq!(res, large_msg); + }); +} + +#[test] +fn virtq_multi_descriptor_h2g_max_slots() { + // Payload spanning all available H2G pool slots. 
+ let mut cfg = SandboxConfiguration::default(); + cfg.set_h2g_pool_pages(4); + with_rust_sandbox_cfg(cfg, |mut sandbox| { + let large_msg: String = "B".repeat(8200); + let res: String = sandbox.call("Echo", large_msg.clone()).unwrap(); + assert_eq!(res, large_msg); + }); +} + +#[test] +fn virtq_multi_descriptor_h2g_byte_array() { + // Multi-descriptor with byte array arguments to test binary payloads. + let mut cfg = SandboxConfiguration::default(); + cfg.set_h2g_pool_pages(8); + with_rust_sandbox_cfg(cfg, |mut sandbox| { + let large_bytes: Vec = (0..5000).map(|i| (i % 256) as u8).collect(); + let res: Vec = sandbox + .call("SetByteArrayToZero", large_bytes.clone()) + .unwrap(); + assert_eq!(res.len(), 5000); + assert!(res.iter().all(|&b| b == 0)); + }); +} + +#[test] +fn virtq_multi_descriptor_h2g_boundary() { + // Payload exactly at single-slot capacity boundary. + // Header is 8 bytes, so a single slot fits exactly 4088 bytes of payload. + // The FlatBuffer encoding adds overhead, so we test near the boundary + // to verify no off-by-one errors. + let mut cfg = SandboxConfiguration::default(); + cfg.set_h2g_pool_pages(4); + with_rust_sandbox_cfg(cfg, |mut sandbox| { + // This should fit in one descriptor (small overhead) + let msg_under: String = "C".repeat(3900); + let res: String = sandbox.call("Echo", msg_under.clone()).unwrap(); + assert_eq!(res, msg_under); + + // This should just barely spill into a second descriptor + let msg_over: String = "D".repeat(4100); + let res: String = sandbox.call("Echo", msg_over.clone()).unwrap(); + assert_eq!(res, msg_over); + }); +} + +#[test] +fn virtq_multi_descriptor_h2g_repeated_calls() { + // Multiple large calls in sequence to verify H2G refill works correctly + // after multi-descriptor consumption. 
+ let mut cfg = SandboxConfiguration::default(); + cfg.set_h2g_pool_pages(8); + with_rust_sandbox_cfg(cfg, |mut sandbox| { + for i in 0..5 { + let ch = char::from(b'A' + i as u8); + let msg: String = std::iter::repeat_n(ch, 4500).collect(); + let res: String = sandbox.call("Echo", msg.clone()).unwrap(); + assert_eq!(res, msg, "mismatch on call {i}"); + } + }); +} From ba91b9d685aa360a791e1b482f4abffe2a030303 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Thu, 9 Apr 2026 21:32:58 +0200 Subject: [PATCH 16/31] feat(virtq): do not swallow errors Signed-off-by: Tomasz Andrzejak --- src/hyperlight_common/src/virtq/consumer.rs | 12 ++++-- src/hyperlight_host/src/mem/mgr.rs | 7 ++- src/hyperlight_host/src/sandbox/outb.rs | 48 ++++++++++++++------- 3 files changed, 45 insertions(+), 22 deletions(-) diff --git a/src/hyperlight_common/src/virtq/consumer.rs b/src/hyperlight_common/src/virtq/consumer.rs index fb11c778e..d3da1020c 100644 --- a/src/hyperlight_common/src/virtq/consumer.rs +++ b/src/hyperlight_common/src/virtq/consumer.rs @@ -329,10 +329,14 @@ impl VirtqConsumer { self.next_token = self.next_token.wrapping_add(1); // Copy entry data from shared memory - let data = entry_elem - .map(|elem| self.read_element(&elem)) - .transpose()? 
- .unwrap_or_default(); + let data = match entry_elem.map(|elem| self.read_element(&elem)).transpose() { + Ok(d) => d.unwrap_or_default(), + Err(e) => { + // Read failed - clear inflight before propagating + self.inflight.set(id_idx, false); + return Err(e); + } + }; let entry = RecvEntry { token, data }; diff --git a/src/hyperlight_host/src/mem/mgr.rs b/src/hyperlight_host/src/mem/mgr.rs index b1aa42d3d..bd5449fff 100644 --- a/src/hyperlight_host/src/mem/mgr.rs +++ b/src/hyperlight_host/src/mem/mgr.rs @@ -1073,8 +1073,11 @@ impl SandboxMemoryManager { return Err(new_error!("G2H: result entry too short")); } - let hdr: &VirtqMsgHeader = bytemuck::from_bytes(&entry_data[..VirtqMsgHeader::SIZE]); - let payload = &entry_data[VirtqMsgHeader::SIZE..]; + let hdr_size = VirtqMsgHeader::SIZE; + let hdr: &VirtqMsgHeader = bytemuck::from_bytes(&entry_data[..hdr_size]); + let available = entry_data.len() - hdr_size; + let payload_len = (hdr.payload_len as usize).min(available); + let payload = &entry_data[hdr_size..hdr_size + payload_len]; match hdr.msg_kind() { Ok(MsgKind::Response) => { diff --git a/src/hyperlight_host/src/sandbox/outb.rs b/src/hyperlight_host/src/sandbox/outb.rs index 626236ba4..85f3f1aeb 100644 --- a/src/hyperlight_host/src/sandbox/outb.rs +++ b/src/hyperlight_host/src/sandbox/outb.rs @@ -199,27 +199,40 @@ fn outb_virtq_call( // Drain entries, processing Log messages, until we find a Request. let (entry, completion) = loop { - let Some((entry, completion)) = consumer - .poll(g2h_pool_size) - .map_err(|e| HandleOutbError::ReadHostFunctionCall(format!("G2H poll: {e}")))? - else { + let Ok(maybe_next) = consumer.poll(g2h_pool_size) else { + return Err(HandleOutbError::ReadHostFunctionCall( + "G2H poll failed".into(), + )); + }; + + let Some((entry, completion)) = maybe_next else { // No G2H entry - backpressure-only notify or prefill notify. 
return Ok(()); }; + let hdr_size = VirtqMsgHeader::SIZE; let entry_data = entry.data(); - if entry_data.len() < VirtqMsgHeader::SIZE { + + if entry_data.len() < hdr_size { return Err(HandleOutbError::ReadHostFunctionCall( "G2H entry too short".into(), )); } - let hdr: VirtqMsgHeader = *bytemuck::from_bytes(&entry_data[..VirtqMsgHeader::SIZE]); + + let hdr: VirtqMsgHeader = *bytemuck::from_bytes(&entry_data[..hdr_size]); match hdr.msg_kind() { Ok(MsgKind::Log) => { - let payload = &entry_data[VirtqMsgHeader::SIZE..]; + let available = entry_data.len() - hdr_size; + let log_len = (hdr.payload_len as usize).min(available); + let payload = &entry_data[hdr_size..hdr_size + log_len]; + emit_guest_log(payload); - let _ = consumer.complete(completion); + + consumer.complete(completion).map_err(|e| { + HandleOutbError::ReadHostFunctionCall(format!("G2H complete log: {e}")) + })?; + continue; } Ok(MsgKind::Request) => break (entry, completion), @@ -237,8 +250,18 @@ fn outb_virtq_call( } }; + // Validate completion buffer before calling the host function + let virtq::SendCompletion::Writable(mut wc) = completion else { + return Err(HandleOutbError::WriteHostFunctionResponse( + "G2H: expected writable completion, got ack (ring corruption)".into(), + )); + }; + let entry_data = entry.data(); - let payload = &entry_data[VirtqMsgHeader::SIZE..]; + let hdr: VirtqMsgHeader = *bytemuck::from_bytes(&entry_data[..VirtqMsgHeader::SIZE]); + let available = entry_data.len() - VirtqMsgHeader::SIZE; + let payload_len = (hdr.payload_len as usize).min(available); + let payload = &entry_data[VirtqMsgHeader::SIZE..VirtqMsgHeader::SIZE + payload_len]; let call = FunctionCall::try_from(payload) .map_err(|e| HandleOutbError::ReadHostFunctionCall(e.to_string()))?; @@ -259,13 +282,6 @@ fn outb_virtq_call( let resp_header = VirtqMsgHeader::new(MsgKind::Response, 0, result_payload.len() as u32); let resp_header_bytes = bytemuck::bytes_of(&resp_header); - // Write response into the completion buffer - 
let virtq::SendCompletion::Writable(mut wc) = completion else { - return Err(HandleOutbError::WriteHostFunctionResponse( - "G2H: expected writable completion, got ack (ring corruption)".into(), - )); - }; - wc.write_all(resp_header_bytes) .map_err(|e| HandleOutbError::WriteHostFunctionResponse(format!("{e}")))?; wc.write_all(result_payload) From 674616d36ebb47106a3067be70dd125756eeb680 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Fri, 10 Apr 2026 11:38:35 +0200 Subject: [PATCH 17/31] fix(virtq): adjust sizes for benchmarks Signed-off-by: Tomasz Andrzejak --- src/hyperlight_common/src/virtq/mod.rs | 76 +++++++++++-------- src/hyperlight_common/src/virtq/producer.rs | 25 +++++- src/hyperlight_common/src/virtq/ring.rs | 2 +- src/hyperlight_host/benches/benchmarks.rs | 15 ++-- .../src/sandbox/initialized_multi_use.rs | 10 +-- 5 files changed, 82 insertions(+), 46 deletions(-) diff --git a/src/hyperlight_common/src/virtq/mod.rs b/src/hyperlight_common/src/virtq/mod.rs index 6e8ad82a5..66e384c31 100644 --- a/src/hyperlight_common/src/virtq/mod.rs +++ b/src/hyperlight_common/src/virtq/mod.rs @@ -931,14 +931,14 @@ mod tests { } #[test] - fn test_reclaim_then_poll_preserves_order() { + fn test_reclaim_discards_readonly_completions() { let ring = make_ring(8); let (mut producer, mut consumer, _) = make_test_producer(&ring); // Submit 3 entries: RO, RW, RO - let tok_ro1 = send_readonly(&mut producer, b"log1"); + let _tok_ro1 = send_readonly(&mut producer, b"log1"); let tok_rw = send_readwrite(&mut producer, b"call", 64); - let tok_ro2 = send_readonly(&mut producer, b"log2"); + let _tok_ro2 = send_readonly(&mut producer, b"log2"); // Consumer processes all 3 let (_, c1) = consumer.poll(1024).unwrap().unwrap(); @@ -954,24 +954,16 @@ mod tests { let (_, c3) = consumer.poll(1024).unwrap().unwrap(); consumer.complete(c3).unwrap(); // ack RO - // Reclaim all 3 + // Reclaim all 3 - RO completions are discarded, only RW is buffered let count = producer.reclaim().unwrap(); 
assert_eq!(count, 3); - // poll() returns them in order - let cqe1 = producer.poll().unwrap().unwrap(); - assert_eq!(cqe1.token, tok_ro1); - assert!(cqe1.data.is_empty()); - - let cqe2 = producer.poll().unwrap().unwrap(); - assert_eq!(cqe2.token, tok_rw); - assert_eq!(&cqe2.data[..], b"result"); - - let cqe3 = producer.poll().unwrap().unwrap(); - assert_eq!(cqe3.token, tok_ro2); - assert!(cqe3.data.is_empty()); + // poll() returns only the RW completion + let cqe = producer.poll().unwrap().unwrap(); + assert_eq!(cqe.token, tok_rw); + assert_eq!(&cqe.data[..], b"result"); - // No more + // No more - RO completions were discarded assert!(producer.poll().unwrap().is_none()); } @@ -1008,11 +1000,7 @@ mod tests { assert_eq!(&cqe2.data[..], b"reply"); } - /// Regression test: reclaim + submit must not cause token collisions. - /// - /// Before the monotonic generation counter, Token wrapped the descriptor - /// ID which gets recycled. This caused stale pending completions to - /// match newly submitted entries with the same recycled descriptor ID. + /// reclaim + submit must not cause token collisions. 
#[test] fn test_reclaim_submit_no_token_collision() { let ring = make_ring(8); @@ -1024,7 +1012,6 @@ mod tests { let (_, c) = consumer.poll(1024).unwrap().unwrap(); consumer.complete(c).unwrap(); - // Reclaim pushes the completion to pending (token = tok_old) let count = producer.reclaim().unwrap(); assert_eq!(count, 1); @@ -1045,15 +1032,42 @@ mod tests { wc.write_all(b"result").unwrap(); consumer.complete(wc.into()).unwrap(); - // Poll should return the stale ReadOnly completion first (wrong token) - let cqe1 = producer.poll().unwrap().unwrap(); - assert_eq!(cqe1.token, tok_old); - assert!(cqe1.data.is_empty()); + // Poll returns only the RW completion (RO was discarded by reclaim) + let cqe = producer.poll().unwrap().unwrap(); + assert_eq!(cqe.token, tok_new); + assert_eq!(&cqe.data[..], b"result"); - // Then the new ReadWrite completion (matching token) - let cqe2 = producer.poll().unwrap().unwrap(); - assert_eq!(cqe2.token, tok_new); - assert_eq!(&cqe2.data[..], b"result"); + // No stale RO completion in the queue + assert!(producer.poll().unwrap().is_none()); + } + + /// Verify that repeated oneshot submit/reclaim cycles do not accumulate pending completions. 
+ #[test] + fn test_reclaim_readonly_does_not_leak_pending() { + let ring = make_ring(4); + let (mut producer, mut consumer, _) = make_test_producer(&ring); + + for _ in 0..10 { + // Fill the ring + for _ in 0..4 { + send_readonly(&mut producer, b"msg"); + } + + // Consumer acks all + while let Some((_, completion)) = consumer.poll(1024).unwrap() { + consumer.complete(completion).unwrap(); + } + + // Reclaim frees ring slots; empty completions are discarded + let count = producer.reclaim().unwrap(); + assert_eq!(count, 4); + + // No completions should be buffered in pending + assert!( + producer.poll().unwrap().is_none(), + "pending should be empty after reclaiming RO entries" + ); + } } } #[cfg(all(test, loom))] diff --git a/src/hyperlight_common/src/virtq/producer.rs b/src/hyperlight_common/src/virtq/producer.rs index eda4237d4..81cd31952 100644 --- a/src/hyperlight_common/src/virtq/producer.rs +++ b/src/hyperlight_common/src/virtq/producer.rs @@ -34,6 +34,10 @@ pub struct RecvCompletion { pub token: Token, /// Completion data from the device. pub data: Bytes, + /// Whether this entry is oneshot so there is no writable completion buffer. + /// Oneshot entries are fire-and-forget: the producer does not + /// expect any response data from the device. + pub oneshot: bool, } /// Allocation tracking for an in-flight descriptor chain. @@ -146,7 +150,8 @@ where /// * `pool` - Buffer allocator for entry/completion data pub fn new(layout: Layout, mem: M, notifier: N, pool: P) -> Self { let inner = RingProducer::new(layout, mem); - let inflight = vec![None; inner.len()]; + let ring_len = inner.len(); + let inflight = vec![None; ring_len]; Self { inner, @@ -154,7 +159,7 @@ where notifier, inflight, next_token: 0, - pending: VecDeque::new(), + pending: VecDeque::with_capacity(ring_len), } } @@ -192,6 +197,9 @@ where /// buffer allocations immediately, and buffers completion data for /// later retrieval via [`poll`](Self::poll). 
/// + /// Completions with empty data from read-only/oneshot entries are + /// discarded immediately. + /// /// Use this to free resources under backpressure without losing /// completion data. Returns the number of entries reclaimed. pub fn reclaim(&mut self) -> Result @@ -201,7 +209,11 @@ where { let mut count = 0; while let Some(cqe) = self.poll_ring()? { - self.pending.push_back(cqe); + if !cqe.oneshot { + debug_assert!(self.pending.len() < self.inflight.len()); + debug_assert!(!cqe.data.is_empty()); + self.pending.push_back(cqe); + } count += 1; } Ok(count) @@ -242,6 +254,7 @@ where } // Read completion data + let has_completion = inf.completion().is_some(); let data = match inf.completion() { Some(buf) => { if written > buf.len { @@ -259,7 +272,11 @@ where None => Bytes::new(), }; - Ok(Some(RecvCompletion { token, data })) + Ok(Some(RecvCompletion { + token, + data, + oneshot: !has_completion, + })) } /// Drain all available completions, calling the provided closure for each. diff --git a/src/hyperlight_common/src/virtq/ring.rs b/src/hyperlight_common/src/virtq/ring.rs index 302175631..391a58a0b 100644 --- a/src/hyperlight_common/src/virtq/ring.rs +++ b/src/hyperlight_common/src/virtq/ring.rs @@ -406,7 +406,7 @@ impl RingCursor { /// Advance by n positions using modular arithmetic. 
#[inline] - fn advance_by(&mut self, n: u16) { + pub(crate) fn advance_by(&mut self, n: u16) { debug_assert!(self.head.checked_add(n).is_some()); let new = self.head + n; let wraps = new / self.size; diff --git a/src/hyperlight_host/benches/benchmarks.rs b/src/hyperlight_host/benches/benchmarks.rs index 61e925116..f6e72df3a 100644 --- a/src/hyperlight_host/benches/benchmarks.rs +++ b/src/hyperlight_host/benches/benchmarks.rs @@ -57,13 +57,13 @@ impl SandboxSize { Self::Medium => { let mut cfg = SandboxConfiguration::default(); cfg.set_heap_size(MEDIUM_HEAP_SIZE); - cfg.set_scratch_size(0x50000); + cfg.set_scratch_size(0x80000); Some(cfg) } Self::Large => { let mut cfg = SandboxConfiguration::default(); cfg.set_heap_size(LARGE_HEAP_SIZE); - cfg.set_scratch_size(0x100000); + cfg.set_scratch_size(0x200000); Some(cfg) } } @@ -379,10 +379,15 @@ fn guest_call_benchmark_large_param(c: &mut Criterion) { let large_vec = vec![0u8; SIZE]; let large_string = String::from_utf8(large_vec.clone()).unwrap(); + let h2g_pool_pages = (2 * SIZE + (1024 * 1024)) / 4096; + let heap_size = SIZE as u64 * 15; + let mut config = SandboxConfiguration::default(); - config.set_h2g_pool_pages((2 * SIZE + (1024 * 1024)) / 4096); // pool pages for the large input - config.set_heap_size(SIZE as u64 * 15); - config.set_scratch_size(6 * SIZE + 4 * (1024 * 1024)); // Big enough for any data copies, etc. + config.set_h2g_pool_pages(h2g_pool_pages); + config.set_h2g_queue_depth(h2g_pool_pages.next_power_of_two()); + config.set_heap_size(heap_size); + // Scratch backs all guest physical pages (heap, page tables, pools). 
+ config.set_scratch_size(heap_size as usize + 4 * 1024 * 1024); let sandbox = UninitializedSandbox::new( GuestBinary::FilePath(simple_guest_as_string().unwrap()), diff --git a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs index ab5f33f0e..df8a6a466 100644 --- a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs +++ b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs @@ -1127,21 +1127,21 @@ mod tests { assert_eq!(res, 0); } - // Tests to ensure that many (1000) function calls can be made in a call context with a small stack (24K) and heap(20K). + // Tests to ensure that many (1000) function calls can be made in a call context with a small stack (24K) and heap(32K). // This test effectively ensures that the stack is being properly reset after each call and we are not leaking memory in the Guest. #[test] fn test_with_small_stack_and_heap() { let mut cfg = SandboxConfiguration::default(); - cfg.set_heap_size(20 * 1024); + cfg.set_heap_size(32 * 1024); // min_scratch_size already includes 1 page (4k on most // platforms) of guest stack, so add 20k more to get 24k // total, and then add some more for the eagerly-copied page - // tables on amd64 + // tables on amd64 and virtq pool pages. 
let min_scratch = hyperlight_common::layout::min_scratch_size( cfg.get_g2h_queue_depth(), cfg.get_h2g_queue_depth(), ); - cfg.set_scratch_size(min_scratch + 0x10000 + 0x10000); + cfg.set_scratch_size(min_scratch + 0x10000 + 0x18000); let mut sbox1: MultiUseSandbox = { let path = simple_guest_as_string().unwrap(); @@ -1755,7 +1755,7 @@ mod tests { for (name, heap_size) in test_cases { let mut cfg = SandboxConfiguration::default(); - cfg.set_heap_size(heap_size); + cfg.set_heap_size(128 * 1024); cfg.set_scratch_size(0x100000); let path = simple_guest_as_string().unwrap(); From 3cbe9d13f67f7b10f69ea75b7b5fd65fc3efcaa3 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Fri, 10 Apr 2026 13:55:42 +0200 Subject: [PATCH 18/31] fix(virtq): make clippy happy Signed-off-by: Tomasz Andrzejak --- src/hyperlight_common/src/virtq/buffer.rs | 11 +++++++++-- src/hyperlight_common/src/virtq/mod.rs | 1 + src/hyperlight_common/src/virtq/producer.rs | 4 ++-- .../src/sandbox/initialized_multi_use.rs | 4 ++-- 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/hyperlight_common/src/virtq/buffer.rs b/src/hyperlight_common/src/virtq/buffer.rs index b41708b03..237eedcba 100644 --- a/src/hyperlight_common/src/virtq/buffer.rs +++ b/src/hyperlight_common/src/virtq/buffer.rs @@ -139,7 +139,10 @@ impl AllocGuard { } pub fn release(mut self) -> Allocation { - self.0.take().unwrap().0 + // Safety: AllocGuard is always constructed with Some, and release is only called once + self.0.take().map(|(alloc, _)| alloc).unwrap_or_else(|| { + unreachable!("AllocGuard::release called on dismissed guard") + }) } } @@ -147,7 +150,11 @@ impl core::ops::Deref for AllocGuard { type Target = Allocation; fn deref(&self) -> &Allocation { - &self.0.as_ref().unwrap().0 + // Safety: AllocGuard is always constructed with Some, and the inner value is only + // taken by release() or Drop. 
+ &self.0.as_ref().unwrap_or_else(|| { + unreachable!("AllocGuard::deref called on dismissed guard") + }).0 } } diff --git a/src/hyperlight_common/src/virtq/mod.rs b/src/hyperlight_common/src/virtq/mod.rs index 66e384c31..ac1110bdd 100644 --- a/src/hyperlight_common/src/virtq/mod.rs +++ b/src/hyperlight_common/src/virtq/mod.rs @@ -384,6 +384,7 @@ impl From for Allocation { } const _: () = { + #[allow(clippy::unwrap_used)] const fn verify_layout(num_descs: usize) { let base = 0x1000u64; diff --git a/src/hyperlight_common/src/virtq/producer.rs b/src/hyperlight_common/src/virtq/producer.rs index 81cd31952..eda16327d 100644 --- a/src/hyperlight_common/src/virtq/producer.rs +++ b/src/hyperlight_common/src/virtq/producer.rs @@ -519,8 +519,8 @@ where self.inner.reset_prefilled(&ids); let addrs: SmallVec<[u64; 64]> = (0..prefill_count) - .map(|i| self.pool.slot_addr(i).expect("prefill_count <= pool count")) - .collect(); + .map(|i| self.pool.slot_addr(i).ok_or(VirtqError::InvalidState)) + .collect::>()?; self.pool .restore_allocated(&addrs) diff --git a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs index df8a6a466..bdcd43729 100644 --- a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs +++ b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs @@ -1755,8 +1755,8 @@ mod tests { for (name, heap_size) in test_cases { let mut cfg = SandboxConfiguration::default(); - cfg.set_heap_size(128 * 1024); - cfg.set_scratch_size(0x100000); + cfg.set_heap_size(heap_size); + cfg.set_scratch_size(heap_size as usize + 0x100000); let path = simple_guest_as_string().unwrap(); let sbox = UninitializedSandbox::new(GuestBinary::FilePath(path), Some(cfg)) From 1e4cd16beb6cbf729c64705e873c4423552ea8c3 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Fri, 10 Apr 2026 14:49:06 +0200 Subject: [PATCH 19/31] feat(virtq): add recycle pool tests Signed-off-by: Tomasz Andrzejak --- 
src/hyperlight_common/src/virtq/buffer.rs | 17 +-- src/hyperlight_common/src/virtq/pool.rs | 141 ++++++++++++++++++++++ 2 files changed, 151 insertions(+), 7 deletions(-) diff --git a/src/hyperlight_common/src/virtq/buffer.rs b/src/hyperlight_common/src/virtq/buffer.rs index 237eedcba..7b637e38b 100644 --- a/src/hyperlight_common/src/virtq/buffer.rs +++ b/src/hyperlight_common/src/virtq/buffer.rs @@ -103,7 +103,7 @@ impl BufferProvider for Arc { /// zero-copy `Bytes` backed by shared memory. /// /// When dropped, the allocation is returned to the pool. -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct BufferOwner { pub(crate) pool: P, pub(crate) mem: M, @@ -140,9 +140,10 @@ impl AllocGuard { pub fn release(mut self) -> Allocation { // Safety: AllocGuard is always constructed with Some, and release is only called once - self.0.take().map(|(alloc, _)| alloc).unwrap_or_else(|| { - unreachable!("AllocGuard::release called on dismissed guard") - }) + self.0 + .take() + .map(|(alloc, _)| alloc) + .unwrap_or_else(|| unreachable!("AllocGuard::release called on dismissed guard")) } } @@ -152,9 +153,11 @@ impl core::ops::Deref for AllocGuard { fn deref(&self) -> &Allocation { // Safety: AllocGuard is always constructed with Some, and the inner value is only // taken by release() or Drop. 
- &self.0.as_ref().unwrap_or_else(|| { - unreachable!("AllocGuard::deref called on dismissed guard") - }).0 + &self + .0 + .as_ref() + .unwrap_or_else(|| unreachable!("AllocGuard::deref called on dismissed guard")) + .0 } } diff --git a/src/hyperlight_common/src/virtq/pool.rs b/src/hyperlight_common/src/virtq/pool.rs index d60d432bf..92db1a38e 100644 --- a/src/hyperlight_common/src/virtq/pool.rs +++ b/src/hyperlight_common/src/virtq/pool.rs @@ -680,6 +680,20 @@ impl BufferProvider for RecyclePool { fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError> { let mut inner = self.inner.borrow_mut(); + let end = inner.base_addr + (inner.count * inner.slot_size) as u64; + + if alloc.addr < inner.base_addr || alloc.addr >= end { + return Err(AllocError::InvalidFree(alloc.addr, alloc.len)); + } + + if (alloc.addr - inner.base_addr) % inner.slot_size as u64 != 0 { + return Err(AllocError::InvalidFree(alloc.addr, alloc.len)); + } + + if inner.free.contains(&alloc.addr) { + return Err(AllocError::InvalidFree(alloc.addr, alloc.len)); + } + inner.free.push(alloc.addr); Ok(()) } @@ -1389,6 +1403,133 @@ mod tests { pool.restore_allocated(&[0x80000]).unwrap(); assert_eq!(pool.num_free(), 3); } + + #[test] + fn test_recycle_pool_dealloc_out_of_range() { + let pool = make_recycle_pool(4, 4096); + let _ = pool.alloc(4096).unwrap(); + + let bogus = Allocation { + addr: 0xDEAD, + len: 4096, + }; + assert!(matches!( + pool.dealloc(bogus), + Err(AllocError::InvalidFree(0xDEAD, 4096)) + )); + } + + #[test] + fn test_recycle_pool_dealloc_misaligned() { + let pool = make_recycle_pool(4, 4096); + let _ = pool.alloc(4096).unwrap(); + + let misaligned = Allocation { + addr: 0x80001, + len: 4096, + }; + assert!(matches!( + pool.dealloc(misaligned), + Err(AllocError::InvalidFree(0x80001, 4096)) + )); + } + + #[test] + fn test_recycle_pool_dealloc_double_free() { + let pool = make_recycle_pool(4, 4096); + let a = pool.alloc(4096).unwrap(); + pool.dealloc(a).unwrap(); + + // Second dealloc 
should fail - address is already in the free list + assert!(matches!( + pool.dealloc(a), + Err(AllocError::InvalidFree(_, _)) + )); + } + + #[test] + fn test_recycle_pool_random_order_dealloc() { + let pool = make_recycle_pool(8, 4096); + + let mut allocs: Vec = (0..8).map(|_| pool.alloc(4096).unwrap()).collect(); + assert_eq!(pool.num_free(), 0); + + // Dealloc in reverse order + allocs.reverse(); + for a in &allocs { + pool.dealloc(*a).unwrap(); + } + assert_eq!(pool.num_free(), 8); + + // All slots should be re-allocatable + let reallocs: Vec = (0..8).map(|_| pool.alloc(4096).unwrap()).collect(); + assert_eq!(pool.num_free(), 0); + + // Verify all addresses are distinct + let mut addrs: Vec = reallocs.iter().map(|a| a.addr).collect(); + addrs.sort(); + addrs.dedup(); + assert_eq!(addrs.len(), 8); + } + + #[test] + fn test_recycle_pool_interleaved_alloc_dealloc_order() { + let pool = make_recycle_pool(4, 4096); + + let a0 = pool.alloc(4096).unwrap(); + let a1 = pool.alloc(4096).unwrap(); + let a2 = pool.alloc(4096).unwrap(); + let a3 = pool.alloc(4096).unwrap(); + assert_eq!(pool.num_free(), 0); + + // Free middle slots first (out of allocation order) + pool.dealloc(a2).unwrap(); + pool.dealloc(a0).unwrap(); + assert_eq!(pool.num_free(), 2); + + // Re-alloc gets the out-of-order slots back (LIFO) + let b0 = pool.alloc(4096).unwrap(); + assert_eq!(b0.addr, a0.addr); + let b1 = pool.alloc(4096).unwrap(); + assert_eq!(b1.addr, a2.addr); + + // Free everything in yet another order + pool.dealloc(a1).unwrap(); + pool.dealloc(b0).unwrap(); + pool.dealloc(b1).unwrap(); + pool.dealloc(a3).unwrap(); + assert_eq!(pool.num_free(), 4); + + // All 4 original addresses should be available + let mut final_addrs: Vec = (0..4).map(|_| pool.alloc(4096).unwrap().addr).collect(); + final_addrs.sort(); + let expected: Vec = (0..4).map(|i| 0x80000 + i * 4096).collect(); + assert_eq!(final_addrs, expected); + } + + #[test] + fn 
test_recycle_pool_dealloc_order_independent_of_alloc_order() { + let pool = make_recycle_pool(6, 256); + + // Allocate all + let allocs: Vec = (0..6).map(|_| pool.alloc(256).unwrap()).collect(); + + // Dealloc in scattered order: 4, 1, 5, 0, 3, 2 + let order = [4, 1, 5, 0, 3, 2]; + for &i in &order { + pool.dealloc(allocs[i]).unwrap(); + } + assert_eq!(pool.num_free(), 6); + + // Re-allocate all and verify we get back the full set + let mut realloc_addrs: Vec = (0..6).map(|_| pool.alloc(256).unwrap().addr).collect(); + realloc_addrs.sort(); + + let mut orig_addrs: Vec = allocs.iter().map(|a| a.addr).collect(); + orig_addrs.sort(); + + assert_eq!(realloc_addrs, orig_addrs); + } } #[cfg(test)] From 8ca084413ca5359902042058bf449ccc6006d325 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Fri, 10 Apr 2026 15:27:20 +0200 Subject: [PATCH 20/31] feat(virtq): implement G2H reply backlog guard Signed-off-by: Tomasz Andrzejak --- src/hyperlight_common/src/virtq/producer.rs | 6 +++ src/hyperlight_guest/src/virtq/context.rs | 48 ++++++++++++++++++++- 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/src/hyperlight_common/src/virtq/producer.rs b/src/hyperlight_common/src/virtq/producer.rs index eda16327d..ff8536ff6 100644 --- a/src/hyperlight_common/src/virtq/producer.rs +++ b/src/hyperlight_common/src/virtq/producer.rs @@ -384,6 +384,12 @@ where self.inner.used_cursor() } + /// Number of free (unsubmitted) descriptors in the ring. + #[inline] + pub fn num_free(&self) -> usize { + self.inner.num_free() + } + /// Configure event suppression for used buffer notifications. 
/// /// This controls when the device (consumer) signals us about completed buffers: diff --git a/src/hyperlight_guest/src/virtq/context.rs b/src/hyperlight_guest/src/virtq/context.rs index c3ba6adad..a13288fd3 100644 --- a/src/hyperlight_guest/src/virtq/context.rs +++ b/src/hyperlight_guest/src/virtq/context.rs @@ -79,6 +79,8 @@ pub struct GuestContext { h2g_slot_size: usize, /// snapshot generation counter generation: u64, + /// Number of H2G requests received that still need a G2H response. + pending_replies: u32, /// used by cabi last_host_result: Option>, } @@ -106,6 +108,7 @@ impl GuestContext { g2h_response_cap, h2g_slot_size, generation, + pending_replies: 0, last_host_result: None, }; @@ -114,6 +117,9 @@ impl GuestContext { } /// Call a host function via the G2H virtqueue. + /// + /// The reply guard is checked before submitting the readwrite chain + /// to ensure G2H capacity is reserved for pending responses. pub fn call_host_function>( &mut self, function_name: &str, @@ -139,6 +145,9 @@ impl GuestContext { let entry_len = VirtqMsgHeader::SIZE + payload.len(); + // Reply guard: readwrite chains use 2 descriptors, leave room for pending replies. + self.ensure_reply_capacity(2)?; + let token = match self.try_send_readwrite(hdr_bytes, payload, entry_len) { Ok(tok) => tok, Err(e) if e.is_transient() => { @@ -191,6 +200,9 @@ impl GuestContext { /// Each descriptor carries a [`VirtqMsgHeader`] with `payload_len` for /// that chunk. If [`MsgFlags::MORE`](hyperlight_common::virtq::msg::MsgFlags::MORE) /// is set, more descriptors follow. + /// + /// Increments the reply guard counter so that subsequent G2H sends + /// reserve capacity for the response. pub fn recv_h2g_call(&mut self) -> Result { let Some(first) = self.h2g_producer.poll()? else { bail!("H2G: no pending call"); @@ -209,6 +221,9 @@ impl GuestContext { let chunk_len = hdr.payload_len as usize; + // Track that we owe a response on G2H. 
+ self.pending_replies = self.pending_replies.saturating_add(1); + if !hdr.has_more() { // Single-descriptor fast path let payload = &data[VirtqMsgHeader::SIZE..VirtqMsgHeader::SIZE + chunk_len]; @@ -250,8 +265,11 @@ impl GuestContext { /// Send the result of a host-to-guest call back to the host via the /// G2H queue, then refill H2G descriptor slots until the ring is full. + /// + /// Decrements the reply guard counter after a successful send. pub fn send_h2g_result(&mut self, payload: &[u8]) -> Result<()> { self.send_g2h_oneshot(MsgKind::Response, payload)?; + self.pending_replies = self.pending_replies.saturating_sub(1); self.prefill_h2g() } @@ -343,16 +361,42 @@ impl GuestContext { } } + /// Ensure the G2H ring has enough free descriptors to accommodate + /// both the requested send (`need_descs`) and all pending replies. + fn ensure_reply_capacity(&mut self, need_descs: usize) -> Result<()> { + let reserved = self.pending_replies as usize; + loop { + let free = self.g2h_producer.num_free(); + if free >= need_descs + reserved { + return Ok(()); + } + + self.g2h_producer.notify_backpressure(); + let reclaimed = self.g2h_producer.reclaim()?; + if reclaimed == 0 { + // No progress - host hasn't completed any entries yet. + // Fall through and let the send path handle backpressure + // via its own retry logic. + return Ok(()); + } + } + } + /// Send a one-way message on the G2H queue ReadOnly and no completion. /// - /// If the pool or ring is full, triggers backpressure, VM exit so - /// the host can drain, then retries once. + /// For non-response sends, the reply guard is checked first to + /// ensure enough G2H capacity is reserved for pending replies. 
fn send_g2h_oneshot(&mut self, kind: MsgKind, payload: &[u8]) -> Result<()> { let reqid = REQUEST_ID.fetch_add(1, Relaxed); let hdr = VirtqMsgHeader::new(kind, reqid, payload.len() as u32); let hdr_bytes = bytemuck::bytes_of(&hdr); let entry_len = VirtqMsgHeader::SIZE + payload.len(); + // Reply guard: non-response sends must leave room for pending replies. + if kind != MsgKind::Response { + self.ensure_reply_capacity(1)?; + } + // First attempt match self.try_send_readonly(hdr_bytes, payload, entry_len) { Ok(_) => return Ok(()), From 811fc98c2c0ccb565b79787caa73307557093b1e Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Fri, 10 Apr 2026 15:43:46 +0200 Subject: [PATCH 21/31] fix(virtq): add copyright header to benches Signed-off-by: Tomasz Andrzejak --- src/hyperlight_common/benches/buffer_pool.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/hyperlight_common/benches/buffer_pool.rs b/src/hyperlight_common/benches/buffer_pool.rs index 614f160b0..80d4f9daa 100644 --- a/src/hyperlight_common/benches/buffer_pool.rs +++ b/src/hyperlight_common/benches/buffer_pool.rs @@ -1,3 +1,19 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + use std::hint::black_box; use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main}; From 16344cfe1d6bb7e4033d33f6c02adcea7d6e7056 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Fri, 10 Apr 2026 17:47:28 +0200 Subject: [PATCH 22/31] fix(virtq): we gonna need a bigger boat Move FXSAVE buffer to the middle of scratch to avoid overwriting live page tables that are copied to the beginning of scratch when update_scratch_bookkeeping is called Signed-off-by: Tomasz Andrzejak --- fuzz/fuzz_targets/guest_trace.rs | 4 ++-- .../src/hypervisor/hyperlight_vm/x86_64.rs | 14 ++++++++------ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/fuzz/fuzz_targets/guest_trace.rs b/fuzz/fuzz_targets/guest_trace.rs index 3dfb61c95..43373300e 100644 --- a/fuzz/fuzz_targets/guest_trace.rs +++ b/fuzz/fuzz_targets/guest_trace.rs @@ -69,8 +69,8 @@ impl<'a> Arbitrary<'a> for FuzzInput { fuzz_target!( init: { let mut cfg = SandboxConfiguration::default(); - // In local tests, 256 KiB seemed sufficient for deep recursion - cfg.set_scratch_size(256 * 1024); + // In local tests, 512 KiB seemed sufficient for deep recursion + cfg.set_scratch_size(512 * 1024); let path = simple_guest_for_fuzzing_as_string().expect("Guest Binary Missing"); let u_sbox = UninitializedSandbox::new( GuestBinary::FilePath(path), diff --git a/src/hyperlight_host/src/hypervisor/hyperlight_vm/x86_64.rs b/src/hyperlight_host/src/hypervisor/hyperlight_vm/x86_64.rs index 39bb1b56f..0b3593be1 100644 --- a/src/hyperlight_host/src/hypervisor/hyperlight_vm/x86_64.rs +++ b/src/hyperlight_host/src/hypervisor/hyperlight_vm/x86_64.rs @@ -2110,18 +2110,20 @@ mod tests { } /// Creates VM with guest code that: dirtys FPU (if flag==0), does FXSAVE to buffer, sets flag=1. - /// Uses scratch region after rings for FXSAVE buffer. + /// Uses a scratch region area for the FXSAVE buffer. 
fn hyperlight_vm_with_mem_mgr_fxsave() -> FxsaveTestContext { use iced_x86::code_asm::*; // Compute fixed addresses for FXSAVE buffer and flag. - // We use the page-table area in scratch after rings as a - // convenient 512-byte aligned buffer for FXSAVE. + // Place the buffer at halfway through scratch: well past + // the rings and page tables at the start, and well below + // the stack and scratch-top metadata at the end. let config: SandboxConfiguration = Default::default(); let layout = SandboxMemoryLayout::new(config, 512, 4096, None).unwrap(); - let fxsave_offset = layout.get_pt_base_scratch_offset(); - let fxsave_gva = hyperlight_common::layout::scratch_base_gva(config.get_scratch_size()) - + fxsave_offset as u64; + let scratch_size = config.get_scratch_size(); + let fxsave_offset = (scratch_size / 2) & !0xFFF; // page-aligned + let fxsave_gva = + hyperlight_common::layout::scratch_base_gva(scratch_size) + fxsave_offset as u64; let flag_gva = fxsave_gva + 512; let mut a = CodeAssembler::new(64).unwrap(); From 30f8ffcdef85ef1023d719d47dea02c5dc1b51ad Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Fri, 10 Apr 2026 19:23:15 +0200 Subject: [PATCH 23/31] fix(virtq): add instrumentation to virtq host calls Signed-off-by: Tomasz Andrzejak --- fuzz/fuzz_targets/host_call.rs | 2 ++ src/hyperlight_guest/src/virtq/context.rs | 2 ++ src/hyperlight_host/tests/integration_test.rs | 4 ++-- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/fuzz/fuzz_targets/host_call.rs b/fuzz/fuzz_targets/host_call.rs index 0559b2bd6..c65b10047 100644 --- a/fuzz/fuzz_targets/host_call.rs +++ b/fuzz/fuzz_targets/host_call.rs @@ -61,6 +61,8 @@ fuzz_target!( HyperlightError::GuestError(ErrorCode::HostFunctionError, msg) if msg.contains("The number of arguments to the function is wrong") => {} HyperlightError::ParameterValueConversionFailure(_, _) => {}, HyperlightError::GuestError(ErrorCode::HostFunctionError, msg) if msg.contains("Failed To Convert Parameter Value") => {} 
+ HyperlightError::GuestError(ErrorCode::HostFunctionError, msg) if msg.contains("The parameter value type is unexpected") => {} + HyperlightError::GuestError(ErrorCode::HostFunctionError, msg) if msg.contains("The return value type is unexpected") => {} // any other error should be reported _ => panic!("Guest Aborted with Unexpected Error: {:?}", e), diff --git a/src/hyperlight_guest/src/virtq/context.rs b/src/hyperlight_guest/src/virtq/context.rs index a13288fd3..0724e0540 100644 --- a/src/hyperlight_guest/src/virtq/context.rs +++ b/src/hyperlight_guest/src/virtq/context.rs @@ -33,6 +33,7 @@ use hyperlight_common::virtq::msg::{MsgKind, VirtqMsgHeader}; use hyperlight_common::virtq::{ self, BufferPool, Layout, Notifier, QueueStats, RecyclePool, Token, VirtqProducer, }; +use tracing::instrument; use super::GuestMemOps; use crate::bail; @@ -120,6 +121,7 @@ impl GuestContext { /// /// The reply guard is checked before submitting the readwrite chain /// to ensure G2H capacity is reserved for pending responses. 
+ #[instrument(skip_all, level = "Info")] pub fn call_host_function>( &mut self, function_name: &str, diff --git a/src/hyperlight_host/tests/integration_test.rs b/src/hyperlight_host/tests/integration_test.rs index a0bd48c6f..cd9689d43 100644 --- a/src/hyperlight_host/tests/integration_test.rs +++ b/src/hyperlight_host/tests/integration_test.rs @@ -715,10 +715,10 @@ fn log_message() { // follows: // - logs from trace level tracing spans created as logs because of the tracing `log` feature // - 4 from evolve call (generic_init + hyperlight_main) - // - 8 from guest call + // - 4 from guest call (call_host_function + read_n_bytes_from_user_memory) // and are multiplied because we make 6 calls to `log_test_messages` // NOTE: These numbers need to be updated if log messages or spans are added/removed - let num_fixed_trace_log = 12 * 6; + let num_fixed_trace_log = 8 * 6; // Calculate fixed info logs // - 4 logs per iteration from infrastructure at Info level (internal_dispatch_function) From 1a36bbb938bc19f11de3e96bfbde69ea95ef0ac6 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Fri, 10 Apr 2026 19:55:19 +0200 Subject: [PATCH 24/31] fix(virtq): truncate error message so it fits completion Signed-off-by: Tomasz Andrzejak --- src/hyperlight_host/src/sandbox/outb.rs | 15 +++++++++++++-- src/hyperlight_host/tests/integration_test.rs | 4 ++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/hyperlight_host/src/sandbox/outb.rs b/src/hyperlight_host/src/sandbox/outb.rs index 85f3f1aeb..0b6bfb609 100644 --- a/src/hyperlight_host/src/sandbox/outb.rs +++ b/src/hyperlight_host/src/sandbox/outb.rs @@ -268,12 +268,23 @@ fn outb_virtq_call( let name = call.function_name.clone(); let args: Vec = call.parameters.unwrap_or(vec![]); - let res = host_funcs + + let registry = host_funcs .try_lock() - .map_err(|e| HandleOutbError::LockFailed(file!(), line!(), e.to_string()))? 
+ .map_err(|e| HandleOutbError::LockFailed(file!(), line!(), e.to_string()))?; + + let mut res = registry .call_host_function(&name, args) .map_err(|e| GuestError::new(ErrorCode::HostFunctionError, e.to_string())); + // Truncate oversized error messages so the serialized response + // fits in the completion buffer the guest pre-allocated. + if let Err(err) = &mut res + && err.message.len() > wc.capacity() + { + err.message.truncate(wc.capacity()); + } + // Serialize response: VirtqMsgHeader + FunctionCallResult let func_result = FunctionCallResult::new(res); let mut builder = flatbuffers::FlatBufferBuilder::new(); diff --git a/src/hyperlight_host/tests/integration_test.rs b/src/hyperlight_host/tests/integration_test.rs index cd9689d43..1450783b8 100644 --- a/src/hyperlight_host/tests/integration_test.rs +++ b/src/hyperlight_host/tests/integration_test.rs @@ -535,7 +535,7 @@ fn guest_malloc_abort() { }); // allocate a vector (on heap) that is bigger than the heap - let heap_size = 0x4000; + let heap_size = 0x8000; let size_to_allocate = 0x10000; assert!( size_to_allocate > heap_size, @@ -584,7 +584,7 @@ fn guest_outb_with_invalid_port_poisons_sandbox() { #[test] fn guest_panic_no_alloc() { - let heap_size = 0x4000; + let heap_size = 0x8000; let mut cfg = SandboxConfiguration::default(); cfg.set_heap_size(heap_size); From dccb0774fa7057fa54bcbd5d7b2ccc9a65882b85 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Mon, 13 Apr 2026 14:16:56 +0200 Subject: [PATCH 25/31] fix(virtq): move log tests to integration tests Signed-off-by: Tomasz Andrzejak --- Justfile | 2 +- src/hyperlight_guest/src/virtq/context.rs | 3 - src/hyperlight_guest/src/virtq/mod.rs | 17 +- src/hyperlight_guest_bin/src/host_comm.rs | 2 + src/hyperlight_host/tests/integration_test.rs | 162 +++++++++++++++++ .../tests/sandbox_host_tests.rs | 166 ------------------ 6 files changed, 174 insertions(+), 178 deletions(-) diff --git a/Justfile b/Justfile index bd569fa64..2064cccb8 100644 --- a/Justfile +++ 
b/Justfile @@ -225,7 +225,7 @@ test-unit target=default-target features="": test-isolated target=default-target features="" : {{ cargo-cmd }} test {{ if features =="" {''} else if features=="no-default-features" {"--no-default-features" } else {"--no-default-features -F " + features } }} --profile={{ if target == "debug" { "dev" } else { target } }} {{ target-triple-flag }} -p hyperlight-host --lib -- sandbox::uninitialized::tests::test_log_trace --exact --ignored {{ cargo-cmd }} test {{ if features =="" {''} else if features=="no-default-features" {"--no-default-features" } else {"--no-default-features -F " + features } }} --profile={{ if target == "debug" { "dev" } else { target } }} {{ target-triple-flag }} -p hyperlight-host --lib -- sandbox::outb::tests::test_log_outb_log --exact --ignored - {{ cargo-cmd }} test {{ if features =="" {''} else if features=="no-default-features" {"--no-default-features" } else {"--no-default-features -F " + features } }} --profile={{ if target == "debug" { "dev" } else { target } }} {{ target-triple-flag }} -p hyperlight-host --test integration_test -- log_message --exact --ignored + {{ cargo-cmd }} test {{ if features =="" {''} else if features=="no-default-features" {"--no-default-features" } else {"--no-default-features -F " + features } }} --profile={{ if target == "debug" { "dev" } else { target } }} {{ target-triple-flag }} -p hyperlight-host --test integration_test -- --test-threads=1 --ignored @# metrics tests {{ cargo-cmd }} test {{ if features =="" {''} else if features=="no-default-features" {"--no-default-features" } else {"--no-default-features -F function_call_metrics," + features } }} --profile={{ if target == "debug" { "dev" } else { target } }} {{ target-triple-flag }} -p hyperlight-host --lib -- metrics::tests::test_metrics_are_emitted --exact diff --git a/src/hyperlight_guest/src/virtq/context.rs b/src/hyperlight_guest/src/virtq/context.rs index 0724e0540..39efaba17 100644 --- 
a/src/hyperlight_guest/src/virtq/context.rs +++ b/src/hyperlight_guest/src/virtq/context.rs @@ -33,8 +33,6 @@ use hyperlight_common::virtq::msg::{MsgKind, VirtqMsgHeader}; use hyperlight_common::virtq::{ self, BufferPool, Layout, Notifier, QueueStats, RecyclePool, Token, VirtqProducer, }; -use tracing::instrument; - use super::GuestMemOps; use crate::bail; use crate::error::Result; @@ -121,7 +119,6 @@ impl GuestContext { /// /// The reply guard is checked before submitting the readwrite chain /// to ensure G2H capacity is reserved for pending responses. - #[instrument(skip_all, level = "Info")] pub fn call_host_function>( &mut self, function_name: &str, diff --git a/src/hyperlight_guest/src/virtq/mod.rs b/src/hyperlight_guest/src/virtq/mod.rs index d86ae40cc..119fa2175 100644 --- a/src/hyperlight_guest/src/virtq/mod.rs +++ b/src/hyperlight_guest/src/virtq/mod.rs @@ -22,7 +22,7 @@ limitations under the License. pub mod context; pub mod mem; -use core::cell::UnsafeCell; +use core::cell::RefCell; use core::sync::atomic::{AtomicU8, Ordering}; use context::GuestContext; @@ -31,14 +31,15 @@ pub use mem::GuestMemOps; // Init state machine const UNINITIALIZED: u8 = 0; const INITIALIZED: u8 = 1; + static INIT_STATE: AtomicU8 = AtomicU8::new(UNINITIALIZED); +static GLOBAL_CONTEXT: SyncWrap>> = SyncWrap(RefCell::new(None)); -// Storage: UnsafeCell guarded by atomic init state. +// Sync wrapper for the global context. struct SyncWrap(T); +/// SAFETY: The guest is single-threaded. unsafe impl Sync for SyncWrap {} -static GLOBAL_CONTEXT: SyncWrap>> = SyncWrap(UnsafeCell::new(None)); - /// Check if the global context has been initialized. pub fn is_initialized() -> bool { INIT_STATE.load(Ordering::Acquire) == INITIALIZED @@ -48,14 +49,14 @@ pub fn is_initialized() -> bool { /// /// # Panics /// -/// Panics if the context has not been initialized. +/// Panics if the context has not been initialized, or on re-entrant use.
pub fn with_context(f: impl FnOnce(&mut GuestContext) -> R) -> R { assert!( INIT_STATE.load(Ordering::Acquire) == INITIALIZED, "guest context not initialized" ); - let ctx = unsafe { &mut *GLOBAL_CONTEXT.0.get() }; - f(ctx.as_mut().unwrap()) + let mut borrow = GLOBAL_CONTEXT.0.borrow_mut(); + f(borrow.as_mut().unwrap()) } /// Install the global guest context. Called once during guest init. @@ -75,5 +76,5 @@ pub fn set_global_context(ctx: GuestContext) { { panic!("guest context already initialized"); } - unsafe { *GLOBAL_CONTEXT.0.get() = Some(ctx) }; + unsafe { *GLOBAL_CONTEXT.0.as_ptr() = Some(ctx) }; } diff --git a/src/hyperlight_guest_bin/src/host_comm.rs b/src/hyperlight_guest_bin/src/host_comm.rs index 1fe7f9994..c812da357 100644 --- a/src/hyperlight_guest_bin/src/host_comm.rs +++ b/src/hyperlight_guest_bin/src/host_comm.rs @@ -28,12 +28,14 @@ use hyperlight_common::flatbuffer_wrappers::util::get_flatbuffer_result; use hyperlight_common::func::{ParameterTuple, SupportedReturnType}; use hyperlight_guest::error::{HyperlightGuestError, Result}; use hyperlight_guest::virtq; +use tracing::instrument; const BUFFER_SIZE: usize = 1000; static mut MESSAGE_BUFFER: Vec = Vec::new(); use crate::GUEST_HANDLE; +#[instrument(skip_all, level = "Info")] pub fn call_host_function( function_name: &str, parameters: Option>, diff --git a/src/hyperlight_host/tests/integration_test.rs b/src/hyperlight_host/tests/integration_test.rs index 1450783b8..091ce699a 100644 --- a/src/hyperlight_host/tests/integration_test.rs +++ b/src/hyperlight_host/tests/integration_test.rs @@ -30,6 +30,7 @@ pub mod common; // pub to disable dead_code warning use crate::common::{ new_rust_sandbox, new_rust_uninit_sandbox, with_all_sandboxes, with_c_sandbox, with_c_uninit_sandbox, with_rust_sandbox, with_rust_sandbox_cfg, with_rust_uninit_sandbox, + with_rust_uninit_sandbox_cfg, }; // A host function cannot be interrupted, but we can at least make sure after requesting to interrupt a host call, @@ -800,6 
+801,167 @@ fn log_test_messages(levelfilter: Option) { } } +// The following tests depend on a global SimpleLogger or TracingSubscriber and +// cannot run in parallel with other tests. They are marked #[ignore] and run +// sequentially via `just test-isolated`. + +#[test] +#[ignore] +fn virtq_log_delivery() { + use hyperlight_testing::simplelogger::{LOGGER, SimpleLogger}; + + SimpleLogger::initialize_test_logger(); + LOGGER.clear_log_calls(); + + with_rust_uninit_sandbox(|mut sbox| { + sbox.set_max_guest_log_level(tracing_core::LevelFilter::TRACE); + let mut sandbox = sbox.evolve().unwrap(); + + sandbox + .call::<()>("LogMessage", ("virtq log test message".to_string(), 3_i32)) + .unwrap(); + + let count = LOGGER.num_log_calls(); + let mut found = false; + for i in 0..count { + if let Some(call) = LOGGER.get_log_call(i) + && call.target == "hyperlight_guest" + && call.args.contains("virtq log test") + { + found = true; + break; + } + } + assert!(found, "expected 'virtq log test' message from guest"); + LOGGER.clear_log_calls(); + }); +} + +#[test] +#[ignore] +fn virtq_log_backpressure() { + use hyperlight_testing::simplelogger::{LOGGER, SimpleLogger}; + + SimpleLogger::initialize_test_logger(); + LOGGER.clear_log_calls(); + + let mut cfg = SandboxConfiguration::default(); + cfg.set_g2h_pool_pages(2); + + with_rust_uninit_sandbox_cfg(cfg, |mut sbox| { + sbox.set_max_guest_log_level(tracing_core::LevelFilter::INFO); + let mut sandbox = sbox.evolve().unwrap(); + + sandbox.call::<()>("LogMessageN", 50_i32).unwrap(); + + let res: i32 = sandbox + .call("ThisIsNotARealFunctionButTheNameIsImportant", ()) + .unwrap(); + assert_eq!(res, 99); + + let guest_count = (0..LOGGER.num_log_calls()) + .filter_map(|i| LOGGER.get_log_call(i)) + .filter(|c| c.target == "hyperlight_guest" && c.args.contains("log entry")) + .count(); + assert_eq!(guest_count, 50, "expected 50 guest logs, got {guest_count}"); + LOGGER.clear_log_calls(); + }); +} + +#[test] +#[ignore] +fn 
virtq_log_backpressure_repeated() { + let mut cfg = SandboxConfiguration::default(); + cfg.set_g2h_pool_pages(2); + + with_rust_sandbox_cfg(cfg, |mut sandbox| { + for _ in 0..5 { + sandbox.call::<()>("LogMessageN", 30_i32).unwrap(); + } + }); +} + +#[test] +#[ignore] +fn virtq_backpressure_small_ring() { + use hyperlight_testing::simplelogger::{LOGGER, SimpleLogger}; + + SimpleLogger::initialize_test_logger(); + LOGGER.clear_log_calls(); + + let mut cfg = SandboxConfiguration::default(); + cfg.set_g2h_queue_depth(4); + + with_rust_uninit_sandbox_cfg(cfg, |mut sbox| { + sbox.set_max_guest_log_level(tracing_core::LevelFilter::INFO); + let mut sandbox = sbox.evolve().unwrap(); + + sandbox.call::<()>("LogMessageN", 20_i32).unwrap(); + + let guest_count = (0..LOGGER.num_log_calls()) + .filter_map(|i| LOGGER.get_log_call(i)) + .filter(|c| c.target == "hyperlight_guest" && c.args.contains("log entry")) + .count(); + assert_eq!(guest_count, 20, "expected 20 guest logs, got {guest_count}"); + LOGGER.clear_log_calls(); + }); +} + +#[test] +#[ignore] +fn virtq_log_tracing_delivery() { + use hyperlight_testing::tracing_subscriber::TracingSubscriber; + + let subscriber = TracingSubscriber::new(tracing::Level::TRACE); + + tracing::subscriber::with_default(subscriber.clone(), || { + with_rust_uninit_sandbox(|mut sbox| { + sbox.set_max_guest_log_level(tracing_core::LevelFilter::INFO); + let mut sandbox = sbox.evolve().unwrap(); + + subscriber.clear(); + + sandbox + .call::<()>("LogMessage", ("tracing delivery test".to_string(), 3_i32)) + .unwrap(); + + let events = subscriber.get_events(); + assert!( + !events.is_empty(), + "expected tracing events after guest log call, got none" + ); + }); + }); +} + +#[test] +#[ignore] +fn virtq_log_tracing_levels() { + use hyperlight_testing::tracing_subscriber::TracingSubscriber; + + let subscriber = TracingSubscriber::new(tracing::Level::TRACE); + + tracing::subscriber::with_default(subscriber.clone(), || { + with_rust_uninit_sandbox(|mut 
sbox| { + sbox.set_max_guest_log_level(tracing_core::LevelFilter::TRACE); + let mut sandbox = sbox.evolve().unwrap(); + + for level in [1_i32, 2, 3, 4, 5] { + subscriber.clear(); + let msg = format!("level-test-{}", level); + sandbox.call::<()>("LogMessage", (msg, level)).unwrap(); + + let events = subscriber.get_events(); + assert!( + !events.is_empty(), + "expected tracing events for guest log level {}", + level + ); + } + }); + }); +} + /// Tests whether host is able to return Bool as return type /// or not #[test] diff --git a/src/hyperlight_host/tests/sandbox_host_tests.rs b/src/hyperlight_host/tests/sandbox_host_tests.rs index 8b1d435a3..9778ab482 100644 --- a/src/hyperlight_host/tests/sandbox_host_tests.rs +++ b/src/hyperlight_host/tests/sandbox_host_tests.rs @@ -375,40 +375,6 @@ fn host_function_error() { }); } -#[test] -fn virtq_log_delivery() { - use hyperlight_testing::simplelogger::{LOGGER, SimpleLogger}; - - SimpleLogger::initialize_test_logger(); - LOGGER.clear_log_calls(); - - with_rust_uninit_sandbox(|mut sbox| { - sbox.set_max_guest_log_level(tracing_core::LevelFilter::TRACE); - let mut sandbox = sbox.evolve().unwrap(); - - sandbox - .call::<()>("LogMessage", ("virtq log test message".to_string(), 3_i32)) - .unwrap(); - - // Verify the guest log arrived via virtqueue - let count = LOGGER.num_log_calls(); - assert!(count > 0, "expected at least one guest log, got 0"); - - let mut found = false; - for i in 0..count { - if let Some(call) = LOGGER.get_log_call(i) - && call.target == "hyperlight_guest" - && call.args.contains("virtq log test") - { - found = true; - break; - } - } - assert!(found, "expected 'virtq log test' message from guest"); - LOGGER.clear_log_calls(); - }); -} - #[test] fn virtq_log_with_callback() { // Verify that log messages interleaved with host callbacks work @@ -431,79 +397,6 @@ fn virtq_log_with_callback() { }); } -#[test] -fn virtq_log_backpressure() { - use hyperlight_testing::simplelogger::{LOGGER, SimpleLogger}; - - 
SimpleLogger::initialize_test_logger(); - LOGGER.clear_log_calls(); - - let mut cfg = SandboxConfiguration::default(); - cfg.set_g2h_pool_pages(2); - - with_rust_uninit_sandbox_cfg(cfg, |mut sbox| { - sbox.set_max_guest_log_level(tracing_core::LevelFilter::INFO); - let mut sandbox = sbox.evolve().unwrap(); - - // 50 logs with a 2-page pool should trigger backpressure - sandbox.call::<()>("LogMessageN", 50_i32).unwrap(); - - // Verify sandbox is still functional after backpressure - let res: i32 = sandbox - .call("ThisIsNotARealFunctionButTheNameIsImportant", ()) - .unwrap(); - assert_eq!(res, 99); - - // Verify all 50 log entries were delivered - let guest_count = (0..LOGGER.num_log_calls()) - .filter_map(|i| LOGGER.get_log_call(i)) - .filter(|c| c.target == "hyperlight_guest" && c.args.contains("log entry")) - .count(); - assert_eq!(guest_count, 50, "expected 50 guest logs, got {guest_count}"); - LOGGER.clear_log_calls(); - }); -} - -#[test] -fn virtq_log_backpressure_repeated() { - // Multiple calls that each trigger backpressure, verifying the - // pool recovers correctly each time. - let mut cfg = SandboxConfiguration::default(); - cfg.set_g2h_pool_pages(2); - - with_rust_sandbox_cfg(cfg, |mut sandbox| { - for _ in 0..5 { - sandbox.call::<()>("LogMessageN", 30_i32).unwrap(); - } - }); -} - -#[test] -fn virtq_backpressure_small_ring() { - // Small descriptor table forces ring-level backpressure. 
- use hyperlight_testing::simplelogger::{LOGGER, SimpleLogger}; - - SimpleLogger::initialize_test_logger(); - LOGGER.clear_log_calls(); - - let mut cfg = SandboxConfiguration::default(); - cfg.set_g2h_queue_depth(4); - - with_rust_uninit_sandbox_cfg(cfg, |mut sbox| { - sbox.set_max_guest_log_level(tracing_core::LevelFilter::INFO); - let mut sandbox = sbox.evolve().unwrap(); - - sandbox.call::<()>("LogMessageN", 20_i32).unwrap(); - - let guest_count = (0..LOGGER.num_log_calls()) - .filter_map(|i| LOGGER.get_log_call(i)) - .filter(|c| c.target == "hyperlight_guest" && c.args.contains("log entry")) - .count(); - assert_eq!(guest_count, 20, "expected 20 guest logs, got {guest_count}"); - LOGGER.clear_log_calls(); - }); -} - #[test] fn virtq_backpressure_log_then_callback() { // Logs fill the G2H ring, then a host callback needs ring space. @@ -553,65 +446,6 @@ fn virtq_backpressure_no_data_loss() { }); } -#[test] -fn virtq_log_tracing_delivery() { - // Verify guest logs are emitted as tracing events when a tracing - // subscriber is active, matching the behavior of the old outb_log. - use hyperlight_testing::tracing_subscriber::TracingSubscriber; - - let subscriber = TracingSubscriber::new(tracing::Level::TRACE); - - tracing::subscriber::with_default(subscriber.clone(), || { - with_rust_uninit_sandbox(|mut sbox| { - sbox.set_max_guest_log_level(tracing_core::LevelFilter::INFO); - let mut sandbox = sbox.evolve().unwrap(); - - subscriber.clear(); - - sandbox - .call::<()>("LogMessage", ("tracing delivery test".to_string(), 3_i32)) - .unwrap(); - - // Guest log goes through format_trace which creates tracing - // events with log.target = "hyperlight_guest" as a field. - let events = subscriber.get_events(); - assert!( - !events.is_empty(), - "expected tracing events after guest log call, got none" - ); - }); - }); -} - -#[test] -fn virtq_log_tracing_levels() { - // Verify each guest log level produces tracing events. 
- use hyperlight_testing::tracing_subscriber::TracingSubscriber; - - let subscriber = TracingSubscriber::new(tracing::Level::TRACE); - - tracing::subscriber::with_default(subscriber.clone(), || { - with_rust_uninit_sandbox(|mut sbox| { - sbox.set_max_guest_log_level(tracing_core::LevelFilter::TRACE); - let mut sandbox = sbox.evolve().unwrap(); - - // Test each level: 1=Trace, 2=Debug, 3=Info, 4=Warn, 5=Error - for level in [1_i32, 2, 3, 4, 5] { - subscriber.clear(); - let msg = format!("level-test-{}", level); - sandbox.call::<()>("LogMessage", (msg, level)).unwrap(); - - let events = subscriber.get_events(); - assert!( - !events.is_empty(), - "expected tracing events for guest log level {}", - level - ); - } - }); - }); -} - #[test] fn virtq_invalid_guest_function_returns_error() { // Calling a non-existent guest function should return a proper From 74f3898893ea4b89d3a328e5f7baf4e6b7ba6e05 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Mon, 13 Apr 2026 15:19:12 +0200 Subject: [PATCH 26/31] chore(virtq): remove unused test imports Signed-off-by: Tomasz Andrzejak --- src/hyperlight_guest/src/virtq/context.rs | 1 + src/hyperlight_host/tests/sandbox_host_tests.rs | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hyperlight_guest/src/virtq/context.rs b/src/hyperlight_guest/src/virtq/context.rs index 39efaba17..a13288fd3 100644 --- a/src/hyperlight_guest/src/virtq/context.rs +++ b/src/hyperlight_guest/src/virtq/context.rs @@ -33,6 +33,7 @@ use hyperlight_common::virtq::msg::{MsgKind, VirtqMsgHeader}; use hyperlight_common::virtq::{ self, BufferPool, Layout, Notifier, QueueStats, RecyclePool, Token, VirtqProducer, }; + use super::GuestMemOps; use crate::bail; use crate::error::Result; diff --git a/src/hyperlight_host/tests/sandbox_host_tests.rs b/src/hyperlight_host/tests/sandbox_host_tests.rs index 9778ab482..cc480e9fe 100644 --- a/src/hyperlight_host/tests/sandbox_host_tests.rs +++ b/src/hyperlight_host/tests/sandbox_host_tests.rs @@ -26,8 +26,7 @@
use hyperlight_testing::simple_guest_as_string; pub mod common; // pub to disable dead_code warning use crate::common::{ with_all_sandboxes, with_all_sandboxes_cfg, with_all_sandboxes_with_writer, - with_all_uninit_sandboxes, with_rust_sandbox_cfg, with_rust_uninit_sandbox, - with_rust_uninit_sandbox_cfg, + with_all_uninit_sandboxes, with_rust_sandbox_cfg, with_rust_uninit_sandbox_cfg, }; #[test] From 068d73a2b941a2a3d0bbfe8177e18c2600f9987f Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Thu, 16 Apr 2026 11:39:17 +0200 Subject: [PATCH 27/31] feat(virtq): add with_hint function call API Signed-off-by: Tomasz Andrzejak --- src/hyperlight_guest/src/virtq/context.rs | 37 ++++++-- src/hyperlight_guest_bin/src/host_comm.rs | 18 ++++ src/hyperlight_host/src/sandbox/outb.rs | 26 +++--- .../tests/sandbox_host_tests.rs | 85 +++++++++++++++++++ src/tests/rust_guests/simpleguest/src/main.rs | 46 +++++++++- 5 files changed, 192 insertions(+), 20 deletions(-) diff --git a/src/hyperlight_guest/src/virtq/context.rs b/src/hyperlight_guest/src/virtq/context.rs index a13288fd3..e51598254 100644 --- a/src/hyperlight_guest/src/virtq/context.rs +++ b/src/hyperlight_guest/src/virtq/context.rs @@ -118,13 +118,35 @@ impl GuestContext { /// Call a host function via the G2H virtqueue. /// - /// The reply guard is checked before submitting the readwrite chain - /// to ensure G2H capacity is reserved for pending responses. + /// Uses the default completion capacity (4096 bytes) for the response + /// buffer. For host functions known to return large payloads, use + /// [`call_host_function_with_hint`](Self::call_host_function_with_hint). pub fn call_host_function>( &mut self, function_name: &str, parameters: Option>, return_type: ReturnType, + ) -> Result { + self.call_host_function_with_hint( + function_name, + parameters, + return_type, + self.g2h_response_cap, + ) + } + + /// Call a host function with an explicit response capacity hint. 
+ /// + /// `resp_hint` is the total completion buffer size in bytes + /// (including wire overhead: VirtqMsgHeader + FlatBuffer framing). + /// The BufferPool allocates multiple adjacent slots when the hint + /// exceeds a single slot size, so this is zero-copy for the host. + pub fn call_host_function_with_hint>( + &mut self, + function_name: &str, + parameters: Option>, + return_type: ReturnType, + resp_hint: usize, ) -> Result { let params = parameters.as_deref().unwrap_or_default(); let estimated_capacity = estimate_flatbuffer_capacity(function_name, params); @@ -140,15 +162,15 @@ impl GuestContext { let payload = fc.encode(&mut builder); let reqid = REQUEST_ID.fetch_add(1, Relaxed); - let hdr = VirtqMsgHeader::new(MsgKind::Request, reqid, payload.len() as u32); - let hdr_bytes = bytemuck::bytes_of(&hdr); + let msg = VirtqMsgHeader::new(MsgKind::Request, reqid, payload.len() as u32); + let hdr = bytemuck::bytes_of(&msg); let entry_len = VirtqMsgHeader::SIZE + payload.len(); // Reply guard: readwrite chains use 2 descriptors, leave room for pending replies. 
self.ensure_reply_capacity(2)?; - let token = match self.try_send_readwrite(hdr_bytes, payload, entry_len) { + let token = match self.try_send_readwrite(hdr, payload, entry_len, resp_hint) { Ok(tok) => tok, Err(e) if e.is_transient() => { self.g2h_producer.notify_backpressure(); @@ -157,7 +179,7 @@ impl GuestContext { bail!("G2H reclaim: {err}"); } - let Ok(tok) = self.try_send_readwrite(hdr_bytes, payload, entry_len) else { + let Ok(tok) = self.try_send_readwrite(hdr, payload, entry_len, resp_hint) else { bail!("G2H call retry"); }; @@ -436,12 +458,13 @@ impl GuestContext { header: &[u8], payload: &[u8], entry_len: usize, + completion_cap: usize, ) -> result::Result { let mut entry = self .g2h_producer .chain() .entry(entry_len) - .completion(self.g2h_response_cap) + .completion(completion_cap) .build()?; entry.write_all(header)?; diff --git a/src/hyperlight_guest_bin/src/host_comm.rs b/src/hyperlight_guest_bin/src/host_comm.rs index c812da357..18cc458b9 100644 --- a/src/hyperlight_guest_bin/src/host_comm.rs +++ b/src/hyperlight_guest_bin/src/host_comm.rs @@ -47,6 +47,24 @@ where virtq::with_context(|ctx| ctx.call_host_function(function_name, parameters, return_type)) } +/// Call a host function with an explicit response capacity hint. +/// +/// `response_hint` is the total completion buffer size in bytes including wire overhead. +/// Use this when you know the host function returns a large payload (e.g., >4096 bytes). 
+pub fn call_host_function_with_hint( + function_name: &str, + parameters: Option>, + return_type: ReturnType, + response_hint: usize, +) -> Result +where + T: TryFrom, +{ + virtq::with_context(|ctx| { + ctx.call_host_function_with_hint(function_name, parameters, return_type, response_hint) + }) +} + pub fn call_host(function_name: impl AsRef, args: impl ParameterTuple) -> Result where T: SupportedReturnType + TryFrom, diff --git a/src/hyperlight_host/src/sandbox/outb.rs b/src/hyperlight_host/src/sandbox/outb.rs index 0b6bfb609..f6d7cc4a8 100644 --- a/src/hyperlight_host/src/sandbox/outb.rs +++ b/src/hyperlight_host/src/sandbox/outb.rs @@ -273,29 +273,31 @@ fn outb_virtq_call( .try_lock() .map_err(|e| HandleOutbError::LockFailed(file!(), line!(), e.to_string()))?; - let mut res = registry + let res = registry .call_host_function(&name, args) .map_err(|e| GuestError::new(ErrorCode::HostFunctionError, e.to_string())); - // Truncate oversized error messages so the serialized response - // fits in the completion buffer the guest pre-allocated. 
- if let Err(err) = &mut res - && err.message.len() > wc.capacity() - { - err.message.truncate(wc.capacity()); - } - - // Serialize response: VirtqMsgHeader + FunctionCallResult let func_result = FunctionCallResult::new(res); let mut builder = flatbuffers::FlatBufferBuilder::new(); - let result_payload = func_result.encode(&mut builder); + let mut result_payload = func_result.encode(&mut builder).to_vec(); + + let total = VirtqMsgHeader::SIZE + result_payload.len(); + if total > wc.capacity() { + let too_large = GuestError::new( + ErrorCode::HostFunctionError, + "response too large for completion buffer".into(), + ); + let fallback = FunctionCallResult::new(Err(too_large)); + let mut fb = flatbuffers::FlatBufferBuilder::new(); + result_payload = fallback.encode(&mut fb).to_vec(); + } let resp_header = VirtqMsgHeader::new(MsgKind::Response, 0, result_payload.len() as u32); let resp_header_bytes = bytemuck::bytes_of(&resp_header); wc.write_all(resp_header_bytes) .map_err(|e| HandleOutbError::WriteHostFunctionResponse(format!("{e}")))?; - wc.write_all(result_payload) + wc.write_all(&result_payload) .map_err(|e| HandleOutbError::WriteHostFunctionResponse(format!("{e}")))?; consumer .complete(wc.into()) diff --git a/src/hyperlight_host/tests/sandbox_host_tests.rs b/src/hyperlight_host/tests/sandbox_host_tests.rs index cc480e9fe..e48a0f38c 100644 --- a/src/hyperlight_host/tests/sandbox_host_tests.rs +++ b/src/hyperlight_host/tests/sandbox_host_tests.rs @@ -563,3 +563,88 @@ fn virtq_multi_descriptor_h2g_repeated_calls() { } }); } + +/// Helper to create a sandbox with a "GetLargeResponse" host function +/// that returns `size` bytes filled with 0xAB. 
+fn sandbox_with_large_response(cfg: SandboxConfiguration) -> MultiUseSandbox { + let mut sandbox = UninitializedSandbox::new( + GuestBinary::FilePath(simple_guest_as_string().unwrap()), + Some(cfg), + ) + .unwrap(); + sandbox + .register("GetLargeResponse", |size: i32| -> Result> { + Ok(vec![0xABu8; size as usize]) + }) + .unwrap(); + sandbox.evolve().unwrap() +} + +#[test] +fn virtq_large_g2h_response_with_hint() { + // Host function returns >4096 bytes. Guest uses a sized completion + // hint so the pool allocates multiple adjacent slots. + let mut cfg = SandboxConfiguration::default(); + cfg.set_g2h_pool_pages(16); + let mut sandbox = sandbox_with_large_response(cfg); + + // 8000 bytes of response payload. With FlatBuffer + header overhead, + // the wire size is ~8100 bytes. A hint of 3*4096 = 12288 is enough. + let hint = 3 * 4096i32; + let res: Vec = sandbox + .call("CallGetLargeResponseWithHint", (8000i32, hint)) + .unwrap(); + assert_eq!(res.len(), 8000); + assert!(res.iter().all(|&b| b == 0xAB)); +} + +#[test] +fn virtq_large_g2h_response_too_large_without_hint() { + // Without a hint, the default 4096-byte completion buffer is used. + // A response >4096 bytes should trigger the host's "response too + // large" fallback error instead of a transport crash. + let mut cfg = SandboxConfiguration::default(); + cfg.set_g2h_pool_pages(16); + let mut sandbox = sandbox_with_large_response(cfg); + + let res = sandbox.call::>("CallGetLargeResponseDefault", 8000i32); + assert!( + res.is_err(), + "expected error for oversized response without hint" + ); +} + +#[test] +fn virtq_large_g2h_response_boundary() { + // Response that fits exactly in one page (with overhead) should work + // without needing a hint. 
+ let mut cfg = SandboxConfiguration::default(); + cfg.set_g2h_pool_pages(16); + let mut sandbox = sandbox_with_large_response(cfg); + + // Small response that fits in default 4096 buffer + let res: Vec = sandbox + .call("CallGetLargeResponseDefault", 1000i32) + .unwrap(); + assert_eq!(res.len(), 1000); + assert!(res.iter().all(|&b| b == 0xAB)); +} + +#[test] +fn virtq_large_g2h_response_after_log_backpressure() { + // Logs fill the G2H pool, then a large host response (with hint) + // needs multi-slot allocation. The backpressure path must drain + // completed log entries to free pool slots for the large completion. + let mut cfg = SandboxConfiguration::default(); + cfg.set_g2h_pool_pages(16); + let mut sandbox = sandbox_with_large_response(cfg); + + // Emit 20 log entries to consume pool slots, then request 8KB + // response with a 12KB hint (3 upper-slab slots). + let hint = 3 * 4096i32; + let res: Vec = sandbox + .call("LogThenLargeResponse", (20i32, 8000i32, hint)) + .unwrap(); + assert_eq!(res.len(), 8000); + assert!(res.iter().all(|&b| b == 0xAB)); +} diff --git a/src/tests/rust_guests/simpleguest/src/main.rs b/src/tests/rust_guests/simpleguest/src/main.rs index 0b61cb233..acb672f1c 100644 --- a/src/tests/rust_guests/simpleguest/src/main.rs +++ b/src/tests/rust_guests/simpleguest/src/main.rs @@ -49,7 +49,8 @@ use hyperlight_guest_bin::exception::arch::{Context, ExceptionInfo}; use hyperlight_guest_bin::guest_function::definition::{GuestFunc, GuestFunctionDefinition}; use hyperlight_guest_bin::guest_function::register::register_function; use hyperlight_guest_bin::host_comm::{ - call_host_function, print_output_with_host_print, read_n_bytes_from_user_memory, + call_host_function, call_host_function_with_hint, print_output_with_host_print, + read_n_bytes_from_user_memory, }; use hyperlight_guest_bin::memory::malloc; use hyperlight_guest_bin::{guest_function, guest_logger, host_function}; @@ -379,6 +380,49 @@ fn echo(value: String) -> String { value } +/// Calls 
a host function "GetLargeResponse" with an explicit response +/// capacity hint. The `hint` parameter is the total completion buffer +/// size in bytes. +#[guest_function("CallGetLargeResponseWithHint")] +fn call_get_large_response_with_hint(size: i32, hint: i32) -> Vec { + call_host_function_with_hint::>( + "GetLargeResponse", + Some(vec![ParameterValue::Int(size)]), + ReturnType::VecBytes, + hint as usize, + ) + .expect("GetLargeResponse call failed") +} + +/// Calls a host function "GetLargeResponse" WITHOUT a hint, using the +/// default 4096-byte completion buffer. +#[guest_function("CallGetLargeResponseDefault")] +fn call_get_large_response_default(size: i32) -> Vec { + call_host_function::>( + "GetLargeResponse", + Some(vec![ParameterValue::Int(size)]), + ReturnType::VecBytes, + ) + .expect("GetLargeResponse call failed") +} + +/// Emits `log_count` log entries to fill the G2H queue, then calls +/// "GetLargeResponse" with a sized hint. Tests that backpressure +/// draining of logs frees pool slots for the large completion. 
+#[guest_function("LogThenLargeResponse")] +fn log_then_large_response(log_count: i32, size: i32, hint: i32) -> Vec { + for i in 0..log_count { + log::info!("backpressure log {}", i); + } + call_host_function_with_hint::>( + "GetLargeResponse", + Some(vec![ParameterValue::Int(size)]), + ReturnType::VecBytes, + hint as usize, + ) + .expect("GetLargeResponse after logs failed") +} + #[guest_function("GetSizePrefixedBuffer")] fn get_size_prefixed_buffer(data: Vec) -> Vec { data From a3542ac5d1b7d56c6c944268ee5801f2f1293a52 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Thu, 16 Apr 2026 14:30:29 +0200 Subject: [PATCH 28/31] feat(virtq): estimate completion capacity based on ret type Signed-off-by: Tomasz Andrzejak --- src/hyperlight_common/src/virtq/pool.rs | 9 ++++++++- src/hyperlight_guest/src/virtq/context.rs | 20 +++++++------------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/hyperlight_common/src/virtq/pool.rs b/src/hyperlight_common/src/virtq/pool.rs index 92db1a38e..e6dc2ae34 100644 --- a/src/hyperlight_common/src/virtq/pool.rs +++ b/src/hyperlight_common/src/virtq/pool.rs @@ -416,10 +416,17 @@ impl BufferPool { inner: SyncWrap(Rc::new(RefCell::new(inner))), }) } +} +impl BufferPool { /// Upper slab slot size in bytes. pub const fn upper_slot_size() -> usize { - U + 4096 + } + + /// Lower slab slot size in bytes. + pub const fn lower_slot_size() -> usize { + 256 } } diff --git a/src/hyperlight_guest/src/virtq/context.rs b/src/hyperlight_guest/src/virtq/context.rs index e51598254..c0dff8af1 100644 --- a/src/hyperlight_guest/src/virtq/context.rs +++ b/src/hyperlight_guest/src/virtq/context.rs @@ -72,9 +72,6 @@ pub struct GuestContext { g2h_producer: G2hProducer, /// host-to-guest driver h2g_producer: H2gProducer, - /// Max writable bytes the host can write into a G2H completion. - /// Derived from the G2H pool upper slab slot size. - g2h_response_cap: usize, /// H2G slot size in bytes (each prefilled writable descriptor). 
h2g_slot_size: usize, /// snapshot generation counter @@ -91,7 +88,6 @@ impl GuestContext { let size = g2h.pool_pages * PAGE_SIZE_USIZE; let g2h_pool = BufferPool::new(g2h.pool_gva, size).expect("failed to create G2H buffer pool"); - let g2h_response_cap = BufferPool::<256, 4096>::upper_slot_size(); let g2h_producer = VirtqProducer::new(g2h.layout, GuestMemOps, GuestNotifier, g2h_pool.clone()); @@ -105,7 +101,6 @@ impl GuestContext { let mut ctx = Self { g2h_producer, h2g_producer, - g2h_response_cap, h2g_slot_size, generation, pending_replies: 0, @@ -118,8 +113,7 @@ impl GuestContext { /// Call a host function via the G2H virtqueue. /// - /// Uses the default completion capacity (4096 bytes) for the response - /// buffer. For host functions known to return large payloads, use + /// For host functions known to return payloads larger than 4096 bytes, use /// [`call_host_function_with_hint`](Self::call_host_function_with_hint). pub fn call_host_function>( &mut self, @@ -127,12 +121,12 @@ impl GuestContext { parameters: Option>, return_type: ReturnType, ) -> Result { - self.call_host_function_with_hint( - function_name, - parameters, - return_type, - self.g2h_response_cap, - ) + let hint = if matches!(return_type, ReturnType::String | ReturnType::VecBytes) { + BufferPool::upper_slot_size() + } else { + BufferPool::lower_slot_size() + }; + self.call_host_function_with_hint(function_name, parameters, return_type, hint) } /// Call a host function with an explicit response capacity hint. 
From 2763ac028b153861dbac27bd16805dfcbc0770c7 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Tue, 12 May 2026 13:31:42 +0200 Subject: [PATCH 29/31] fix(virtq): cargo fmt Signed-off-by: Tomasz Andrzejak --- src/hyperlight_common/src/virtq/desc.rs.rej | 23 ------------------- src/hyperlight_common/src/virtq/mod.rs | 4 +++- .../src/guest_function/call.rs | 3 +-- 3 files changed, 4 insertions(+), 26 deletions(-) delete mode 100644 src/hyperlight_common/src/virtq/desc.rs.rej diff --git a/src/hyperlight_common/src/virtq/desc.rs.rej b/src/hyperlight_common/src/virtq/desc.rs.rej deleted file mode 100644 index 2172452ba..000000000 --- a/src/hyperlight_common/src/virtq/desc.rs.rej +++ /dev/null @@ -1,23 +0,0 @@ -diff a/src/hyperlight_common/src/virtq/desc.rs b/src/hyperlight_common/src/virtq/desc.rs (rejected hunks) -@@ -58,12 +58,15 @@ pub struct Descriptor { - pub flags: u16, - } - --const _: () = assert!(core::mem::size_of::() == 16); --const _: () = assert!(Descriptor::ALIGN == 16); --const _: () = assert!(Descriptor::ADDR_OFFSET == 0); --const _: () = assert!(Descriptor::LEN_OFFSET == 8); --const _: () = assert!(Descriptor::ID_OFFSET == 12); --const _: () = assert!(Descriptor::FLAGS_OFFSET == 14); -+#[allow(clippy::disallowed_macros)] -+const _: () = { -+ assert!(core::mem::size_of::() == 16); -+ assert!(Descriptor::ALIGN == 16); -+ assert!(Descriptor::ADDR_OFFSET == 0); -+ assert!(Descriptor::LEN_OFFSET == 8); -+ assert!(Descriptor::ID_OFFSET == 12); -+ assert!(Descriptor::FLAGS_OFFSET == 14); -+}; - - impl Descriptor { - // VIRTIO spec requires 16-byte alignment for descriptors diff --git a/src/hyperlight_common/src/virtq/mod.rs b/src/hyperlight_common/src/virtq/mod.rs index ac1110bdd..688fd8a9b 100644 --- a/src/hyperlight_common/src/virtq/mod.rs +++ b/src/hyperlight_common/src/virtq/mod.rs @@ -515,7 +515,9 @@ pub(crate) mod test_utils { /// Create test infrastructure: a producer, consumer, and notifier backed /// by the supplied [`OwnedRing`]. 
- pub(crate) fn make_test_producer(ring: &OwnedRing) -> (TestProducer, TestConsumer, TestNotifier) { + pub(crate) fn make_test_producer( + ring: &OwnedRing, + ) -> (TestProducer, TestConsumer, TestNotifier) { let layout = ring.layout(); let mem = ring.mem(); diff --git a/src/hyperlight_guest_bin/src/guest_function/call.rs b/src/hyperlight_guest_bin/src/guest_function/call.rs index ecee52ace..105053932 100644 --- a/src/hyperlight_guest_bin/src/guest_function/call.rs +++ b/src/hyperlight_guest_bin/src/guest_function/call.rs @@ -21,9 +21,8 @@ use flatbuffers::FlatBufferBuilder; use hyperlight_common::flatbuffer_wrappers::function_call::{FunctionCall, FunctionCallType}; use hyperlight_common::flatbuffer_wrappers::function_types::{FunctionCallResult, ParameterType}; use hyperlight_common::flatbuffer_wrappers::guest_error::{ErrorCode, GuestError}; -use hyperlight_guest::bail; use hyperlight_guest::error::{HyperlightGuestError, Result}; -use hyperlight_guest::virtq; +use hyperlight_guest::{bail, virtq}; use tracing::instrument; use crate::REGISTERED_GUEST_FUNCTIONS; From e4fd644b2346ff39793249c8f05ae18bb5db42d9 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Tue, 12 May 2026 17:36:44 +0200 Subject: [PATCH 30/31] feat(virtq): add producer side batch API Signed-off-by: Tomasz Andrzejak --- src/hyperlight_common/src/virtq/buffer.rs | 30 +- src/hyperlight_common/src/virtq/mod.rs | 10 +- src/hyperlight_common/src/virtq/producer.rs | 288 ++++++++++++++++++-- src/hyperlight_guest/src/virtq/context.rs | 21 +- 4 files changed, 304 insertions(+), 45 deletions(-) diff --git a/src/hyperlight_common/src/virtq/buffer.rs b/src/hyperlight_common/src/virtq/buffer.rs index 7b637e38b..522ffcf84 100644 --- a/src/hyperlight_common/src/virtq/buffer.rs +++ b/src/hyperlight_common/src/virtq/buffer.rs @@ -117,15 +117,37 @@ impl Drop for BufferOwner { } } +impl BufferOwner { + pub(crate) fn try_new( + pool: P, + mem: M, + alloc: Allocation, + written: usize, + ) -> Result { + // Pre check 
direct access before handing the owner to Bytes::from_owner + let len = written.min(alloc.len); + let _ = unsafe { mem.as_slice(alloc.addr, len) }?; + + Ok(Self { + pool, + mem, + alloc, + written, + }) + } +} + impl AsRef<[u8]> for BufferOwner { fn as_ref(&self) -> &[u8] { let len = self.written.min(self.alloc.len); - // Safety: BufferOwner keeps both the pool allocation and the M - // alive, so the memory region is valid. Protocol-level descriptor - // ownership transfer guarantees no concurrent writes. + // Safety: BufferOwner keeps both the pool allocation and the M alive, + // so the memory region is valid. match unsafe { self.mem.as_slice(self.alloc.addr, len) } { Ok(slice) => slice, - Err(_) => &[], + Err(_) => { + debug_assert!(false, "BufferOwner direct slice failed"); + &[] + } } } } diff --git a/src/hyperlight_common/src/virtq/mod.rs b/src/hyperlight_common/src/virtq/mod.rs index 688fd8a9b..ed6d7f2cb 100644 --- a/src/hyperlight_common/src/virtq/mod.rs +++ b/src/hyperlight_common/src/virtq/mod.rs @@ -79,17 +79,21 @@ limitations under the License. //! //! ## Multiple Entries //! -//! Each submit checks event suppression and notifies independently: +//! Each submit checks event suppression and notifies independently. Use +//! [`VirtqProducer::batch`] when a higher-level protocol wants to publish +//! multiple entries and kick the queue once. //! //! ```ignore +//! let mut batch = producer.batch(); //! for data in entries { -//! let mut se = producer.chain() +//! let mut se = batch.chain() //! .entry(data.len()) //! .completion(64) //! .build()?; //! se.write_all(data)?; -//! producer.submit(se)?; +//! batch.submit(se)?; //! } +//! batch.finish()?; //! ``` //! //! 
## Completion Batching with Event Suppression diff --git a/src/hyperlight_common/src/virtq/producer.rs b/src/hyperlight_common/src/virtq/producer.rs index ff8536ff6..8e52fe357 100644 --- a/src/hyperlight_common/src/virtq/producer.rs +++ b/src/hyperlight_common/src/virtq/producer.rs @@ -253,20 +253,28 @@ where self.pool.dealloc(entry)?; } + let completion_guard = inf.completion().map(|buf| { + let pool = self.pool.clone(); + AllocGuard::new(buf, move |a| { + let _ = pool.dealloc(a); + }) + }); + // Read completion data - let has_completion = inf.completion().is_some(); - let data = match inf.completion() { + let has_completion = completion_guard.is_some(); + let data = match completion_guard { Some(buf) => { if written > buf.len { - let _ = self.pool.dealloc(buf); return Err(VirtqError::InvalidState); } - let owner = BufferOwner { - pool: self.pool.clone(), - mem: self.inner.mem().clone(), - alloc: buf, + let owner = BufferOwner::try_new( + self.pool.clone(), + self.inner.mem().clone(), + *buf, written, - }; + ) + .map_err(|_| VirtqError::MemoryReadError)?; + let _ = buf.release(); Bytes::from_owner(owner) } None => Bytes::new(), @@ -315,23 +323,42 @@ where ChainBuilder::new(self.inner.mem().clone(), self.pool.clone()) } + /// Begin a batch of submissions. + /// + /// Entries submitted through the returned [`SubmitBatch`] are published to + /// the ring immediately, but the consumer is notified at most once when + /// [`SubmitBatch::finish`] is called. This mirrors the virtio pattern of + /// adding multiple buffers and then kicking the queue once. + pub fn batch(&mut self) -> SubmitBatch<'_, M, N, P> { + SubmitBatch::new(self) + } + /// Submit a [`SendEntry`] to the ring. /// /// Publishes the descriptor chain, stores the in-flight tracking state, - /// and notifies the consumer if event suppression allows. + /// and notifies the consumer if event suppression allows. 
Notifications + /// are layout-neutral; use [`batch`](Self::batch) when a higher-level + /// protocol wants to publish multiple entries and kick once. /// /// # Errors /// /// - [`VirtqError::EntryTooLarge`] - written exceeds entry buffer capacity /// - [`VirtqError::RingError`] - ring is full /// - [`VirtqError::InvalidState`] - descriptor ID collision - pub fn submit(&mut self, mut entry: SendEntry) -> Result { + pub fn submit(&mut self, entry: SendEntry) -> Result { + let cursor_before = self.inner.avail_cursor(); + let token = self.publish(entry)?; + self.notify_since(cursor_before)?; + Ok(token) + } + + fn publish(&mut self, mut entry: SendEntry) -> Result { let written = entry.written; - let inflight = entry.inflight.take().ok_or(VirtqError::InvalidState)?; + let inflight = *entry.inflight.as_ref().ok_or(VirtqError::InvalidState)?; - let cursor_before = self.inner.avail_cursor(); let chain = inflight.try_into_chain(written)?; let id = self.inner.submit_available(&chain)?; + let inflight = entry.inflight.take().ok_or(VirtqError::InvalidState)?; let token = Token(self.next_token, id); self.next_token = self.next_token.wrapping_add(1); @@ -347,33 +374,31 @@ where *slot = Some((token, inflight)); - let should_notify = self.inner.should_notify_since(cursor_before)?; - - // TODO(virtq): for now simulate current outb behavior of only - // notifying on bidirectional (request/response) entries. - // Eventually this should be decoupled from the buffer layout - // and driven entirely by event suppression rules. - let should_notify = should_notify && matches!(inflight, Inflight::ReadWrite { .. 
}); + Ok(token) + } + fn notify_since(&mut self, cursor: RingCursor) -> Result { + let should_notify = self.inner.should_notify_since(cursor)?; if should_notify { - self.notifier.notify(QueueStats { - num_free: self.inner.num_free(), - num_inflight: self.inner.num_inflight(), - }); + self.notify_now(); } + Ok(should_notify) + } - Ok(token) + fn notify_now(&self) { + self.notifier.notify(QueueStats { + num_free: self.inner.num_free(), + num_inflight: self.inner.num_inflight(), + }); } /// Signal backpressure to the consumer. /// - /// Bypasses event suppression. Call this when submit fails with a backpressure error and the consumer needs to drain. + /// Bypasses event suppression. Call this when submit fails with a + /// backpressure error and the consumer needs to drain. #[inline] pub fn notify_backpressure(&self) { - self.notifier.notify(QueueStats { - num_free: self.inner.num_free(), - num_inflight: self.inner.num_inflight(), - }); + self.notify_now(); } /// Get the current used cursor position. @@ -454,6 +479,58 @@ where } } +/// A scoped batch of producer submissions. +/// +/// Submissions are published immediately, while notification is delayed until +/// [`finish`](Self::finish). `finish` is explicit because the event-suppression +/// check can fail; dropping a batch does not notify. +#[must_use = "call finish to notify the consumer about batched submissions"] +pub struct SubmitBatch<'a, M, N, P> { + producer: &'a mut VirtqProducer, + notify_from: Option, +} + +impl<'a, M, N, P> SubmitBatch<'a, M, N, P> +where + M: MemOps + Clone, + N: Notifier, + P: BufferProvider + Clone, +{ + fn new(producer: &'a mut VirtqProducer) -> Self { + Self { + producer, + notify_from: None, + } + } + + /// Begin building a descriptor chain for this batch. + pub fn chain(&self) -> ChainBuilder { + self.producer.chain() + } + + /// Publish an entry as part of this batch without notifying yet. 
+ pub fn submit(&mut self, entry: SendEntry) -> Result { + let cursor_before = self.producer.inner.avail_cursor(); + let token = self.producer.publish(entry)?; + if self.notify_from.is_none() { + self.notify_from = Some(cursor_before); + } + Ok(token) + } + + /// Finish the batch and notify the consumer once if event suppression + /// requires it for the whole published range. + /// + /// Returns `true` if a notification was sent. + pub fn finish(mut self) -> Result { + let Some(notify_from) = self.notify_from.take() else { + return Ok(false); + }; + + self.producer.notify_since(notify_from) + } +} + /// Snapshot restore support for producers backed by [`RecyclePool`]. impl VirtqProducer where @@ -800,6 +877,39 @@ mod tests { RecyclePool::new(pool_base, slot_count * SLOT_SIZE, SLOT_SIZE).unwrap() } + #[derive(Clone)] + struct NoDirectSliceMem(TestMem); + + // SAFETY: Delegates all non-slice memory operations to TestMem. Direct + // slices are intentionally unsupported to exercise producer error handling. 
+ unsafe impl MemOps for NoDirectSliceMem { + type Error = (); + + fn read(&self, addr: u64, dst: &mut [u8]) -> Result<(), Self::Error> { + self.0.read(addr, dst).map_err(|e| match e {}) + } + + fn write(&self, addr: u64, src: &[u8]) -> Result<(), Self::Error> { + self.0.write(addr, src).map_err(|e| match e {}) + } + + fn load_acquire(&self, addr: u64) -> Result { + self.0.load_acquire(addr).map_err(|e| match e {}) + } + + fn store_release(&self, addr: u64, val: u16) -> Result<(), Self::Error> { + self.0.store_release(addr, val).map_err(|e| match e {}) + } + + unsafe fn as_slice(&self, _addr: u64, _len: usize) -> Result<&[u8], Self::Error> { + Err(()) + } + + unsafe fn as_mut_slice(&self, _addr: u64, _len: usize) -> Result<&mut [u8], Self::Error> { + Err(()) + } + } + #[test] fn test_chain_readwrite_build() { let ring = make_ring(16); @@ -940,6 +1050,99 @@ mod tests { assert!(notifier.notification_count() > initial_count); } + #[test] + fn test_submit_read_only_notifies_by_default() { + let ring = make_ring(16); + let (mut producer, _consumer, notifier) = make_test_producer(&ring); + + let initial_count = notifier.notification_count(); + + let mut se = producer.chain().entry(64).build().unwrap(); + se.write_all(b"fire-and-forget").unwrap(); + producer.submit(se).unwrap(); + + assert!(notifier.notification_count() > initial_count); + } + + #[test] + fn test_submit_write_only_notifies_by_default() { + let ring = make_ring(16); + let (mut producer, _consumer, notifier) = make_test_producer(&ring); + + let initial_count = notifier.notification_count(); + + let se = producer.chain().completion(128).build().unwrap(); + producer.submit(se).unwrap(); + + assert!(notifier.notification_count() > initial_count); + } + + #[test] + fn test_batch_notifies_once_on_finish() { + let ring = make_ring(16); + let (mut producer, mut consumer, notifier) = make_test_producer(&ring); + + let initial_count = notifier.notification_count(); + + let mut batch = producer.batch(); + + let mut 
first = batch.chain().entry(64).build().unwrap(); + first.write_all(b"first").unwrap(); + batch.submit(first).unwrap(); + + let mut second = batch.chain().entry(64).build().unwrap(); + second.write_all(b"second").unwrap(); + batch.submit(second).unwrap(); + + assert_eq!(notifier.notification_count(), initial_count); + assert!(batch.finish().unwrap()); + assert_eq!(notifier.notification_count(), initial_count + 1); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.data().as_ref(), b"first"); + consumer.complete(completion).unwrap(); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.data().as_ref(), b"second"); + consumer.complete(completion).unwrap(); + } + + #[test] + fn test_batch_finish_notifies_from_batch_start_cursor() { + let ring = make_ring(16); + let (mut producer, mut consumer, notifier) = make_test_producer(&ring); + + let cursor = consumer.avail_cursor(); + consumer + .set_avail_suppression(SuppressionKind::Descriptor(cursor)) + .unwrap(); + + let mut batch = producer.batch(); + + let mut first = batch.chain().entry(64).build().unwrap(); + first.write_all(b"first").unwrap(); + batch.submit(first).unwrap(); + assert_eq!(notifier.notification_count(), 0); + + let mut second = batch.chain().entry(64).completion(64).build().unwrap(); + second.write_all(b"second").unwrap(); + batch.submit(second).unwrap(); + + assert!(batch.finish().unwrap()); + + assert_eq!(notifier.notification_count(), 1); + } + + #[test] + fn test_empty_batch_finish_does_not_notify() { + let ring = make_ring(16); + let (mut producer, _consumer, notifier) = make_test_producer(&ring); + + let batch = producer.batch(); + assert!(!batch.finish().unwrap()); + assert_eq!(notifier.notification_count(), 0); + } + #[test] fn test_set_written_too_large() { let ring = make_ring(16); @@ -1019,6 +1222,33 @@ mod tests { assert_eq!(&cqe.data[..], b"response data"); } + #[test] + fn test_poll_completion_requires_direct_slice() { 
+ let ring = make_ring(16); + let layout = ring.layout(); + let test_mem = ring.mem(); + let pool_base = test_mem.base_addr() + Layout::query_size(ring.len()) as u64 + 0x100; + let pool = TestPool::new(pool_base, 0x8000); + let notifier = TestNotifier::new(); + let mem = NoDirectSliceMem(test_mem); + let mut producer = VirtqProducer::new(layout, mem.clone(), notifier.clone(), pool); + let mut consumer = VirtqConsumer::new(layout, mem, notifier); + + let mut se = producer.chain().entry(64).completion(128).build().unwrap(); + se.write_all(b"request data").unwrap(); + producer.submit(se).unwrap(); + + let (_entry, completion) = consumer.poll(1024).unwrap().unwrap(); + if let SendCompletion::Writable(mut wc) = completion { + wc.write_all(b"response data").unwrap(); + consumer.complete(wc.into()).unwrap(); + } else { + panic!("expected Writable"); + } + + assert!(matches!(producer.poll(), Err(VirtqError::MemoryReadError))); + } + #[test] fn test_virtq_producer_reset() { let ring = make_ring(16); diff --git a/src/hyperlight_guest/src/virtq/context.rs b/src/hyperlight_guest/src/virtq/context.rs index c0dff8af1..08b668865 100644 --- a/src/hyperlight_guest/src/virtq/context.rs +++ b/src/hyperlight_guest/src/virtq/context.rs @@ -357,21 +357,24 @@ impl GuestContext { /// Pre-fill the H2G queue with completion-only descriptors so the host /// can write incoming call payloads into them. 
fn prefill_h2g(&mut self) -> Result<()> { + let mut batch = self.h2g_producer.batch(); + loop { - let entry = match self - .h2g_producer - .chain() - .completion(self.h2g_slot_size) - .build() - { + let entry = match batch.chain().completion(self.h2g_slot_size).build() { Ok(e) => e, - Err(e) if e.is_transient() => return Ok(()), + Err(e) if e.is_transient() => { + batch.finish()?; + return Ok(()); + } Err(e) => bail!("H2G prefill build: {e}"), }; - match self.h2g_producer.submit(entry) { + match batch.submit(entry) { Ok(_) => {} - Err(e) if e.is_transient() => return Ok(()), + Err(e) if e.is_transient() => { + batch.finish()?; + return Ok(()); + } Err(e) => bail!("H2G prefill submit: {e}"), } } From 93cde039e417a5c11f5cce08957de071e64808b3 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Tue, 12 May 2026 20:39:07 +0200 Subject: [PATCH 31/31] feat(virtq): introduce pool alloc for managing allocation lifetime --- src/hyperlight_common/src/virtq/buffer.rs | 117 +++++++++++--------- src/hyperlight_common/src/virtq/producer.rs | 49 +++----- 2 files changed, 78 insertions(+), 88 deletions(-) diff --git a/src/hyperlight_common/src/virtq/buffer.rs b/src/hyperlight_common/src/virtq/buffer.rs index 522ffcf84..bcdbea0b6 100644 --- a/src/hyperlight_common/src/virtq/buffer.rs +++ b/src/hyperlight_common/src/virtq/buffer.rs @@ -97,7 +97,7 @@ impl BufferProvider for Arc { /// The owner of a mapped buffer, ensuring its lifetime. /// -/// Holds a pool allocation and provides direct access to the underlying +/// Holds a [`PoolAlloc`] and provides direct access to the underlying /// shared memory via [`MemOps::as_slice`]. Implements `AsRef<[u8]>` so it /// can be used with [`Bytes::from_owner`](bytes::Bytes::from_owner) for /// zero-copy `Bytes` backed by shared memory. @@ -105,44 +105,18 @@ impl BufferProvider for Arc { /// When dropped, the allocation is returned to the pool. #[derive(Debug)] pub struct BufferOwner { - pub(crate) pool: P, + pub(crate) alloc: PoolAlloc

, pub(crate) mem: M, - pub(crate) alloc: Allocation, pub(crate) written: usize, } -impl Drop for BufferOwner { - fn drop(&mut self) { - let _ = self.pool.dealloc(self.alloc); - } -} - -impl BufferOwner { - pub(crate) fn try_new( - pool: P, - mem: M, - alloc: Allocation, - written: usize, - ) -> Result { - // Pre check direct access before handing the owner to Bytes::from_owner - let len = written.min(alloc.len); - let _ = unsafe { mem.as_slice(alloc.addr, len) }?; - - Ok(Self { - pool, - mem, - alloc, - written, - }) - } -} - impl AsRef<[u8]> for BufferOwner { fn as_ref(&self) -> &[u8] { - let len = self.written.min(self.alloc.len); + let alloc = self.alloc.allocation(); + let len = self.written.min(alloc.len); // Safety: BufferOwner keeps both the pool allocation and the M alive, // so the memory region is valid. - match unsafe { self.mem.as_slice(self.alloc.addr, len) } { + match unsafe { self.mem.as_slice(alloc.addr, len) } { Ok(slice) => slice, Err(_) => { debug_assert!(false, "BufferOwner direct slice failed"); @@ -152,41 +126,74 @@ impl AsRef<[u8]> for BufferOwner { } } -/// A guard that runs a cleanup function when dropped, unless dismissed. -pub struct AllocGuard(Option<(Allocation, F)>); +/// Pool-owned allocation that is returned to the pool on drop. +/// +/// Use [`into_raw`](Self::into_raw) to transfer ownership to a descriptor +/// state that will deallocate the raw [`Allocation`] through another path. +#[derive(Debug)] +pub struct PoolAlloc { + inner: Option>, +} + +#[derive(Debug)] +struct PoolAllocInner { + pool: P, + alloc: Allocation, +} -impl AllocGuard { - pub fn new(alloc: Allocation, cleanup: F) -> Self { - Self(Some((alloc, cleanup))) +impl PoolAlloc

{ + /// Wrap an existing allocation with its owning pool. + pub fn new(pool: P, alloc: Allocation) -> Self { + Self { + inner: Some(PoolAllocInner { pool, alloc }), + } + } + + /// Allocate from `pool` and return an owning guard. + pub fn allocate(pool: P, len: usize) -> Result { + let alloc = pool.alloc(len)?; + Ok(Self::new(pool, alloc)) + } + + /// The raw allocation currently owned by this guard. + pub fn allocation(&self) -> Allocation { + self.inner + .as_ref() + .map(|inner| inner.alloc) + .unwrap_or_else(|| { + unreachable!("PoolAlloc::allocation called after ownership transfer") + }) } - pub fn release(mut self) -> Allocation { - // Safety: AllocGuard is always constructed with Some, and release is only called once - self.0 + /// Release ownership and return the raw allocation. + pub fn into_raw(mut self) -> Allocation { + self.inner .take() - .map(|(alloc, _)| alloc) - .unwrap_or_else(|| unreachable!("AllocGuard::release called on dismissed guard")) + .map(|inner| inner.alloc) + .unwrap_or_else(|| unreachable!("PoolAlloc::into_raw called after ownership transfer")) } -} -impl core::ops::Deref for AllocGuard { - type Target = Allocation; + pub(crate) fn into_buffer_owner( + self, + mem: M, + written: usize, + ) -> Result, M::Error> { + let alloc = self.allocation(); + let len = written.min(alloc.len); + let _ = unsafe { mem.as_slice(alloc.addr, len) }?; - fn deref(&self) -> &Allocation { - // Safety: AllocGuard is always constructed with Some, and the inner value is only - // taken by release() or Drop. - &self - .0 - .as_ref() - .unwrap_or_else(|| unreachable!("AllocGuard::deref called on dismissed guard")) - .0 + Ok(BufferOwner { + alloc: self, + mem, + written, + }) } } -impl Drop for AllocGuard { +impl Drop for PoolAlloc

{ fn drop(&mut self) { - if let Some((alloc, cleanup)) = self.0.take() { - cleanup(alloc) + if let Some(PoolAllocInner { pool, alloc }) = self.inner.take() { + let _ = pool.dealloc(alloc); } } } diff --git a/src/hyperlight_common/src/virtq/producer.rs b/src/hyperlight_common/src/virtq/producer.rs index 8e52fe357..7d3e35ab2 100644 --- a/src/hyperlight_common/src/virtq/producer.rs +++ b/src/hyperlight_common/src/virtq/producer.rs @@ -253,30 +253,21 @@ where self.pool.dealloc(entry)?; } - let completion_guard = inf.completion().map(|buf| { - let pool = self.pool.clone(); - AllocGuard::new(buf, move |a| { - let _ = pool.dealloc(a); - }) - }); + let completion_guard = inf + .completion() + .map(|buf| PoolAlloc::new(self.pool.clone(), buf)); // Read completion data let has_completion = completion_guard.is_some(); let data = match completion_guard { - Some(buf) => { - if written > buf.len { - return Err(VirtqError::InvalidState); - } - let owner = BufferOwner::try_new( - self.pool.clone(), - self.inner.mem().clone(), - *buf, - written, - ) - .map_err(|_| VirtqError::MemoryReadError)?; - let _ = buf.release(); - Bytes::from_owner(owner) + Some(buf) if written > buf.allocation().len => { + // This is a protocol violation + return Err(VirtqError::InvalidState); } + Some(buf) => Bytes::from_owner( + buf.into_buffer_owner(self.inner.mem().clone(), written) + .map_err(|_| VirtqError::MemoryReadError)?, + ), None => Bytes::new(), }; @@ -642,16 +633,8 @@ impl ChainBuilder { } } - fn alloc( - &self, - size: usize, - ) -> Result>, VirtqError> { - let alloc = self.pool.alloc(size)?; - let pool = self.pool.clone(); - - Ok(AllocGuard::new(alloc, move |a| { - let _ = pool.dealloc(a); - })) + fn alloc(&self, size: usize) -> Result, VirtqError> { + Ok(PoolAlloc::allocate(self.pool.clone(), size)?) } /// Request an entry buffer of `cap` bytes. 
@@ -688,14 +671,14 @@ impl ChainBuilder { let inflight = match (entry_alloc, completion_alloc) { (Some(entry), Some(cqe)) => Inflight::ReadWrite { - entry: entry.release(), - completion: cqe.release(), + entry: entry.into_raw(), + completion: cqe.into_raw(), }, (Some(entry), None) => Inflight::ReadOnly { - entry: entry.release(), + entry: entry.into_raw(), }, (None, Some(cqe)) => Inflight::WriteOnly { - completion: cqe.release(), + completion: cqe.into_raw(), }, (None, None) => unreachable!(), };