diff --git a/.gitmodules b/.gitmodules index b7b8a5c57..8c039127b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -2,3 +2,6 @@ path = src/hyperlight_libc/third_party/picolibc url = https://github.com/hyperlight-dev/picolibc-bsd.git shallow = true +[submodule "src/hyperlight_libc/third_party/mimalloc"] + path = src/hyperlight_libc/third_party/mimalloc + url = https://github.com/microsoft/mimalloc diff --git a/Cargo.lock b/Cargo.lock index 40790f9f1..58f446f61 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -925,6 +925,12 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" +[[package]] +name = "fixedbitset" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" + [[package]] name = "flatbuffers" version = "25.12.19" @@ -1485,8 +1491,13 @@ dependencies = [ "arbitrary", "bitflags 2.11.1", "bytemuck", + "bytes", + "criterion", + "fixedbitset", "flatbuffers", + "hyperlight-testing", "log", + "loom", "quickcheck", "rand 0.9.2", "smallvec", @@ -1538,6 +1549,7 @@ name = "hyperlight-guest" version = "0.15.0" dependencies = [ "anyhow", + "bytemuck", "flatbuffers", "hyperlight-common", "hyperlight-guest-tracing", @@ -1590,6 +1602,7 @@ dependencies = [ "bitflags 2.11.1", "blake3", "built", + "bytemuck", "cfg-if", "cfg_aliases", "chrono", diff --git a/Cargo.toml b/Cargo.toml index e693368d0..84b38bb62 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,6 +50,7 @@ hyperlight-component-macro = { path = "src/hyperlight_component_macro", version [workspace.lints.rust] unsafe_op_in_unsafe_fn = "deny" +unexpected_cfgs = { level = "warn", check-cfg = [ 'cfg(loom)' ] } # this will generate symbols for release builds # so is handy for debugging issues in release builds diff --git a/Justfile b/Justfile index bd569fa64..2064cccb8 100644 --- a/Justfile +++ b/Justfile @@ -225,7 +225,7 @@ test-unit target=default-target features="": test-isolated target=default-target features="" : {{ cargo-cmd }} test {{ if features =="" {''} else if features=="no-default-features" {"--no-default-features" } else {"--no-default-features -F " + features } }} --profile={{ if target == "debug" { "dev" } else { target } }} {{ target-triple-flag }} -p hyperlight-host --lib -- sandbox::uninitialized::tests::test_log_trace --exact --ignored {{ cargo-cmd }} test {{ if features =="" {''} else if features=="no-default-features" {"--no-default-features" } else {"--no-default-features -F " + features } }} --profile={{ if target == "debug" { "dev" } else { target } }} {{ target-triple-flag }} -p hyperlight-host --lib -- sandbox::outb::tests::test_log_outb_log --exact --ignored - {{ cargo-cmd }} test {{ if features =="" {''} else if features=="no-default-features" {"--no-default-features" } else {"--no-default-features -F " + features } }} --profile={{ if target == "debug" { "dev" } else { target } }} {{ target-triple-flag }} -p hyperlight-host --test integration_test -- log_message --exact --ignored + {{ cargo-cmd }} test {{ if features =="" {''} else if features=="no-default-features" {"--no-default-features" } else {"--no-default-features -F " + features } }} --profile={{ if target == "debug" { "dev" } else { target } }} {{ target-triple-flag }} -p hyperlight-host --test integration_test -- --test-threads=1 --ignored @# metrics tests {{ cargo-cmd }} test {{ if features =="" {''} else if features=="no-default-features" 
{"--no-default-features" } else {"--no-default-features -F function_call_metrics," + features } }} --profile={{ if target == "debug" { "dev" } else { target } }} {{ target-triple-flag }} -p hyperlight-host --lib -- metrics::tests::test_metrics_are_emitted --exact diff --git a/fuzz/fuzz_targets/guest_trace.rs b/fuzz/fuzz_targets/guest_trace.rs index 3dfb61c95..43373300e 100644 --- a/fuzz/fuzz_targets/guest_trace.rs +++ b/fuzz/fuzz_targets/guest_trace.rs @@ -69,8 +69,8 @@ impl<'a> Arbitrary<'a> for FuzzInput { fuzz_target!( init: { let mut cfg = SandboxConfiguration::default(); - // In local tests, 256 KiB seemed sufficient for deep recursion - cfg.set_scratch_size(256 * 1024); + // In local tests, 512 KiB seemed sufficient for deep recursion + cfg.set_scratch_size(512 * 1024); let path = simple_guest_for_fuzzing_as_string().expect("Guest Binary Missing"); let u_sbox = UninitializedSandbox::new( GuestBinary::FilePath(path), diff --git a/fuzz/fuzz_targets/host_call.rs b/fuzz/fuzz_targets/host_call.rs index b0d37cf1a..c65b10047 100644 --- a/fuzz/fuzz_targets/host_call.rs +++ b/fuzz/fuzz_targets/host_call.rs @@ -33,9 +33,9 @@ static SANDBOX: OnceLock> = OnceLock::new(); fuzz_target!( init: { let mut cfg = SandboxConfiguration::default(); - cfg.set_output_data_size(64 * 1024); // 64 KB output buffer - cfg.set_input_data_size(64 * 1024); // 64 KB input buffer - cfg.set_scratch_size(512 * 1024); // large scratch region to contain those buffers, any data copies, etc. + cfg.set_g2h_pool_pages(16); // 64 KB / 4096 = 16 pages + cfg.set_h2g_pool_pages(16); // 64 KB / 4096 = 16 pages + cfg.set_scratch_size(512 * 1024); // large scratch region let u_sbox = UninitializedSandbox::new( GuestBinary::FilePath(simple_guest_for_fuzzing_as_string().expect("Guest Binary Missing")), Some(cfg) @@ -61,6 +61,8 @@ fuzz_target!( HyperlightError::GuestError(ErrorCode::HostFunctionError, msg) if msg.contains("The number of arguments to the function is wrong") => {} HyperlightError::ParameterValueConversionFailure(_, _) => {}, HyperlightError::GuestError(ErrorCode::HostFunctionError, msg) if msg.contains("Failed To Convert Parameter Value") => {} + HyperlightError::GuestError(ErrorCode::HostFunctionError, msg) if msg.contains("The parameter value type is unexpected") => {} + HyperlightError::GuestError(ErrorCode::HostFunctionError, msg) if msg.contains("The return value type is unexpected") => {} // any other error should be reported _ => panic!("Guest Aborted with Unexpected Error: {:?}", e), diff --git a/src/hyperlight_common/Cargo.toml b/src/hyperlight_common/Cargo.toml index ad9558bc5..1b688cd54 100644 --- a/src/hyperlight_common/Cargo.toml +++ b/src/hyperlight_common/Cargo.toml @@ -19,6 +19,8 @@ arbitrary = {version = "1.4.2", optional = true, features = ["derive"]} anyhow = { version = "1.0.102", default-features = false } bitflags = "2.10.0" bytemuck = { version = "1.24", features = ["derive"] } +bytes = { version = "1", default-features = false } +fixedbitset = { version = "0.5.7", default-features = false } flatbuffers = { version = "25.12.19", default-features = false } log = "0.4.29" smallvec = "1.15.1" @@ -39,9 +41,18 @@ nanvix-unstable = ["i686-guest"] guest-counter = [] [dev-dependencies] +criterion = "0.8.1" +hyperlight-testing = { workspace = true } quickcheck = "1.0.3" rand = "0.9.2" +[target.'cfg(loom)'.dev-dependencies] +loom = "0.7" + [lib] bench = false # see https://bheisler.github.io/criterion.rs/book/faq.html#cargo-bench-gives-unrecognized-option-errors-for-valid-command-line-options doctest = 
false # reduce noise in test output + +[[bench]] +name = "buffer_pool" +harness = false diff --git a/src/hyperlight_common/benches/buffer_pool.rs b/src/hyperlight_common/benches/buffer_pool.rs new file mode 100644 index 000000000..80d4f9daa --- /dev/null +++ b/src/hyperlight_common/benches/buffer_pool.rs @@ -0,0 +1,192 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +use std::hint::black_box; + +use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main}; +use hyperlight_common::virtq::{BufferPool, BufferProvider}; + +// Helper to create a pool for benchmarking +fn make_pool(size: usize) -> BufferPool { + let base = 0x10000; + BufferPool::::new(base, size).unwrap() +} + +// Single allocation performance +fn bench_alloc_single(c: &mut Criterion) { + let mut group = c.benchmark_group("alloc_single"); + + for size in [64, 128, 256, 512, 1024, 1500, 4096].iter() { + group.throughput(Throughput::Elements(1)); + group.bench_with_input(BenchmarkId::from_parameter(size), size, |b, &size| { + let pool = make_pool::<256, 4096>(4 * 1024 * 1024); + b.iter(|| { + let alloc = pool.alloc(black_box(size)).unwrap(); + pool.dealloc(alloc).unwrap(); + }); + }); + } + group.finish(); +} + +// LIFO recycling +fn bench_alloc_lifo(c: &mut Criterion) { + let mut group = c.benchmark_group("alloc_lifo"); + + for size in [256, 1500, 4096].iter() { + group.throughput(Throughput::Elements(100)); + group.bench_with_input(BenchmarkId::from_parameter(size), size, |b, &size| { + let pool = make_pool::<256, 4096>(4 * 1024 * 1024); + b.iter(|| { + for _ in 0..100 { + let alloc = pool.alloc(black_box(size)).unwrap(); + pool.dealloc(alloc).unwrap(); + } + }); + }); + } + group.finish(); +} + +// Fragmented allocation worst case +fn bench_alloc_fragmented(c: &mut Criterion) { + let mut group = c.benchmark_group("alloc_fragmented"); + + group.bench_function("fragmented_256", |b| { + let pool = make_pool::<256, 4096>(4 * 1024 * 1024); + + // Create fragmentation pattern: allocate many, free every other + let mut allocations = Vec::new(); + for _ in 0..100 { + allocations.push(pool.alloc(128).unwrap()); + } + for i in (0..100).step_by(2) { + pool.dealloc(allocations[i]).unwrap(); + } + + b.iter(|| { + let alloc = pool.alloc(black_box(256)).unwrap(); + pool.dealloc(alloc).unwrap(); + }); + }); + + group.finish(); +} + +// Realloc operations +fn bench_realloc(c: &mut Criterion) { + let mut group = c.benchmark_group("realloc"); + + // In-place grow (same tier) + group.bench_function("grow_inplace", |b| { + let pool = make_pool::<256, 4096>(4 * 1024 * 1024); + b.iter(|| { + let alloc = pool.alloc(256).unwrap(); + let grown = pool.resize(alloc, black_box(512)).unwrap(); + pool.dealloc(grown).unwrap(); + }); + }); + + // Relocate grow (cross tier) + group.bench_function("grow_relocate", |b| { + let pool = make_pool::<256, 4096>(4 * 1024 * 1024); + b.iter(|| { + let alloc = pool.alloc(128).unwrap(); + // Block in-place growth + let blocker = pool.alloc(256).unwrap(); + let grown = 
pool.resize(alloc, black_box(1500)).unwrap(); + pool.dealloc(grown).unwrap(); + pool.dealloc(blocker).unwrap(); + }); + }); + + // Shrink + group.bench_function("shrink", |b| { + let pool = make_pool::<256, 4096>(4 * 1024 * 1024); + b.iter(|| { + let alloc = pool.alloc(1500).unwrap(); + let shrunk = pool.resize(alloc, black_box(256)).unwrap(); + pool.dealloc(shrunk).unwrap(); + }); + }); + + group.finish(); +} + +// Free performance +fn bench_free(c: &mut Criterion) { + let mut group = c.benchmark_group("free"); + + for size in [256, 1500, 4096].iter() { + group.bench_with_input(BenchmarkId::from_parameter(size), size, |b, &size| { + let pool = make_pool::<256, 4096>(4 * 1024 * 1024); + b.iter(|| { + let alloc = pool.alloc(size).unwrap(); + pool.dealloc(black_box(alloc)).unwrap(); + }); + }); + } + + group.finish(); +} + +// Cursor optimization +fn bench_last_free_run(c: &mut Criterion) { + let mut group = c.benchmark_group("last_free_run"); + + // With cursor optimization (LIFO) + group.bench_function("lifo_pattern", |b| { + let pool = make_pool::<256, 4096>(4 * 1024 * 1024); + b.iter(|| { + let alloc = pool.alloc(256).unwrap(); + pool.dealloc(alloc).unwrap(); + let alloc2 = pool.alloc(black_box(256)).unwrap(); + pool.dealloc(alloc2).unwrap(); + }); + }); + + // Without cursor benefit (FIFO-like) + group.bench_function("fifo_pattern", |b| { + let pool = make_pool::<256, 4096>(4 * 1024 * 1024); + let mut queue = Vec::new(); + + // Pre-fill queue + for _ in 0..10 { + queue.push(pool.alloc(256).unwrap()); + } + + b.iter(|| { + // FIFO: free oldest, allocate new + let old = queue.remove(0); + pool.dealloc(old).unwrap(); + queue.push(pool.alloc(black_box(256)).unwrap()); + }); + }); + + group.finish(); +} + +criterion_group!( + benches, + bench_alloc_single, + bench_alloc_lifo, + bench_alloc_fragmented, + bench_realloc, + bench_free, + bench_last_free_run, +); + +criterion_main!(benches); diff --git a/src/hyperlight_common/src/arch/aarch64/layout.rs b/src/hyperlight_common/src/arch/aarch64/layout.rs index 20f17026c..9f9c504a6 100644 --- a/src/hyperlight_common/src/arch/aarch64/layout.rs +++ b/src/hyperlight_common/src/arch/aarch64/layout.rs @@ -20,6 +20,6 @@ pub const SNAPSHOT_PT_GVA_MIN: usize = 0xffff_8000_0000_0000; pub const SNAPSHOT_PT_GVA_MAX: usize = 0xffff_80ff_ffff_ffff; pub const MAX_GPA: usize = 0x0000_000f_ffff_ffff; -pub fn min_scratch_size(_input_data_size: usize, _output_data_size: usize) -> usize { +pub fn min_scratch_size(_g2h_num_descs: usize, _h2g_num_descs: usize) -> usize { unimplemented!("min_scratch_size") } diff --git a/src/hyperlight_common/src/arch/amd64/layout.rs b/src/hyperlight_common/src/arch/amd64/layout.rs index 14a9cd62a..12644de6c 100644 --- a/src/hyperlight_common/src/arch/amd64/layout.rs +++ b/src/hyperlight_common/src/arch/amd64/layout.rs @@ -37,8 +37,11 @@ pub const MAX_GPA: usize = 0x0000_000f_ffff_ffff; /// - A page for the smallest possible non-exception stack /// - (up to) 3 pages for mapping that /// - Two pages for the exception stack and metadata -/// - A page-aligned amount of memory for I/O buffers (for now) -pub fn min_scratch_size(input_data_size: usize, output_data_size: usize) -> usize { - (input_data_size + output_data_size).next_multiple_of(crate::vmem::PAGE_SIZE) +/// - A page-aligned amount of memory for virtqueue rings +pub fn min_scratch_size(g2h_num_descs: usize, h2g_num_descs: usize) -> usize { + let g2h_ring_size = crate::virtq::Layout::query_size(g2h_num_descs); + let h2g_ring_size = crate::virtq::Layout::query_size(h2g_num_descs); + 
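+    // Rings are laid out back to back; round their combined size up to a whole
+    // page, then add 12 pages for the stack and metadata items listed above.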
+ (g2h_ring_size + h2g_ring_size).next_multiple_of(crate::vmem::PAGE_SIZE) + 12 * crate::vmem::PAGE_SIZE } diff --git a/src/hyperlight_common/src/arch/i686/layout.rs b/src/hyperlight_common/src/arch/i686/layout.rs index cdc3af7d1..54a651cc2 100644 --- a/src/hyperlight_common/src/arch/i686/layout.rs +++ b/src/hyperlight_common/src/arch/i686/layout.rs @@ -21,10 +21,6 @@ pub const MAX_GVA: usize = 0xffff_ffff; /// regions are large enough to reach that address. pub const MAX_GPA: usize = 0xFEDF_FFFF; -/// Minimum scratch region size: IO buffers (page-aligned) plus 12 pages -/// for bookkeeping and the exception stack. Page table space is validated -/// separately by `set_pt_size()`. -pub fn min_scratch_size(input_data_size: usize, output_data_size: usize) -> usize { - (input_data_size + output_data_size).next_multiple_of(crate::vmem::PAGE_SIZE) - + 12 * crate::vmem::PAGE_SIZE +pub fn min_scratch_size(_g2h_num_descs: usize, _h2g_num_descs: usize) -> usize { + crate::vmem::PAGE_SIZE } diff --git a/src/hyperlight_common/src/layout.rs b/src/hyperlight_common/src/layout.rs index 1a7ca0880..7448b7d0d 100644 --- a/src/hyperlight_common/src/layout.rs +++ b/src/hyperlight_common/src/layout.rs @@ -33,12 +33,33 @@ pub use arch::{MAX_GPA, MAX_GVA}; ))] pub use arch::{SNAPSHOT_PT_GVA_MAX, SNAPSHOT_PT_GVA_MIN}; -// offsets down from the top of scratch memory for various things pub const SCRATCH_TOP_SIZE_OFFSET: u64 = 0x08; pub const SCRATCH_TOP_ALLOCATOR_OFFSET: u64 = 0x10; pub const SCRATCH_TOP_SNAPSHOT_PT_GPA_BASE_OFFSET: u64 = 0x18; pub const SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET: u64 = 0x20; -pub const SCRATCH_TOP_EXN_STACK_OFFSET: u64 = 0x30; +pub const SCRATCH_TOP_G2H_RING_GVA_OFFSET: u64 = 0x28; +pub const SCRATCH_TOP_H2G_RING_GVA_OFFSET: u64 = 0x30; +pub const SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET: u64 = 0x38; +pub const SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET: u64 = 0x3A; +pub const SCRATCH_TOP_G2H_POOL_PAGES_OFFSET: u64 = 0x3C; +pub const SCRATCH_TOP_H2G_POOL_PAGES_OFFSET: u64 = 0x3E; +pub const SCRATCH_TOP_H2G_POOL_GVA_OFFSET: u64 = 0x48; +pub const SCRATCH_TOP_EXN_STACK_OFFSET: u64 = 0x50; + +const _: () = { + assert!(SCRATCH_TOP_ALLOCATOR_OFFSET >= SCRATCH_TOP_SIZE_OFFSET + 8); + assert!(SCRATCH_TOP_SNAPSHOT_PT_GPA_BASE_OFFSET >= SCRATCH_TOP_ALLOCATOR_OFFSET + 8); + assert!(SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET >= SCRATCH_TOP_SNAPSHOT_PT_GPA_BASE_OFFSET + 8); + assert!(SCRATCH_TOP_G2H_RING_GVA_OFFSET >= SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET + 8); + assert!(SCRATCH_TOP_H2G_RING_GVA_OFFSET >= SCRATCH_TOP_G2H_RING_GVA_OFFSET + 8); + assert!(SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET >= SCRATCH_TOP_H2G_RING_GVA_OFFSET + 8); + assert!(SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET >= SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET + 2); + assert!(SCRATCH_TOP_G2H_POOL_PAGES_OFFSET >= SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET + 2); + assert!(SCRATCH_TOP_H2G_POOL_PAGES_OFFSET >= SCRATCH_TOP_G2H_POOL_PAGES_OFFSET + 2); + assert!(SCRATCH_TOP_H2G_POOL_GVA_OFFSET >= SCRATCH_TOP_H2G_POOL_PAGES_OFFSET + 8); + assert!(SCRATCH_TOP_EXN_STACK_OFFSET >= SCRATCH_TOP_H2G_POOL_GVA_OFFSET + 8); + assert!(SCRATCH_TOP_EXN_STACK_OFFSET % 0x10 == 0); +}; /// Offset from the top of scratch memory for a shared host-guest u64 counter. /// @@ -56,5 +77,27 @@ pub fn scratch_base_gva(size: usize) -> u64 { (MAX_GVA - size + 1) as u64 } +pub const fn scratch_top_ptr(offset: u64) -> *mut T { + (MAX_GVA as u64 - offset + 1) as *mut T +} + +/// Compute the byte offset from the scratch base to the G2H ring. 
+/// +/// The G2H ring starts at offset 0, aligned to descriptor alignment. +pub const fn g2h_ring_scratch_offset() -> usize { + 0 +} + +/// Compute the byte offset from the scratch base to the H2G ring. +/// +/// The H2G ring follows immediately after the G2H ring, aligned to +/// descriptor alignment. +pub const fn h2g_ring_scratch_offset(g2h_num_descs: usize) -> usize { + let g2h_size = crate::virtq::Layout::query_size(g2h_num_descs); + let align = crate::virtq::Descriptor::ALIGN; + + (g2h_size + align - 1) & !(align - 1) +} + /// Compute the minimum scratch region size needed for a sandbox. pub use arch::min_scratch_size; diff --git a/src/hyperlight_common/src/mem.rs b/src/hyperlight_common/src/mem.rs index fb850acc8..1cdd65cef 100644 --- a/src/hyperlight_common/src/mem.rs +++ b/src/hyperlight_common/src/mem.rs @@ -68,8 +68,6 @@ impl Default for FileMappingInfo { #[derive(Debug, Clone, Copy)] #[repr(C)] pub struct HyperlightPEB { - pub input_stack: GuestMemoryRegion, - pub output_stack: GuestMemoryRegion, pub init_data: GuestMemoryRegion, pub guest_heap: GuestMemoryRegion, /// File mappings array descriptor. diff --git a/src/hyperlight_common/src/outb.rs b/src/hyperlight_common/src/outb.rs index 3bfb99848..0f9c25e00 100644 --- a/src/hyperlight_common/src/outb.rs +++ b/src/hyperlight_common/src/outb.rs @@ -105,6 +105,8 @@ pub enum OutBAction { TraceMemoryAlloc = 105, #[cfg(feature = "mem_profile")] TraceMemoryFree = 106, + /// Notification that entries are available on a virtqueue. + VirtqNotify = 109, } /// IO-port actions intercepted at the hypervisor level (in `run_vcpu`) @@ -137,6 +139,7 @@ impl TryFrom for OutBAction { 105 => Ok(OutBAction::TraceMemoryAlloc), #[cfg(feature = "mem_profile")] 106 => Ok(OutBAction::TraceMemoryFree), + 109 => Ok(OutBAction::VirtqNotify), _ => Err(anyhow::anyhow!("Invalid OutBAction value: {}", val)), } } diff --git a/src/hyperlight_common/src/virtq/buffer.rs b/src/hyperlight_common/src/virtq/buffer.rs new file mode 100644 index 000000000..bcdbea0b6 --- /dev/null +++ b/src/hyperlight_common/src/virtq/buffer.rs @@ -0,0 +1,199 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Buffer allocation traits and shared types for virtqueue buffer management. 
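+//!
+//! As a rough sketch (illustrative only; `BumpAlloc` is hypothetical and mirrors
+//! the `TestPool` used by this patch's tests), a provider can be as simple as a
+//! bump allocator over a fixed region; real pools such as the slab-based
+//! `BufferPool` also reclaim and resize blocks:
+//!
+//! ```ignore
+//! use core::sync::atomic::{AtomicU64, Ordering};
+//!
+//! struct BumpAlloc { base: u64, size: usize, next: AtomicU64 }
+//!
+//! impl BufferProvider for BumpAlloc {
+//!     fn alloc(&self, len: usize) -> Result<Allocation, AllocError> {
+//!         let addr = self.next.fetch_add(len as u64, Ordering::Relaxed);
+//!         if addr + len as u64 > self.base + self.size as u64 {
+//!             return Err(AllocError::NoSpace);
+//!         }
+//!         Ok(Allocation { addr, len })
+//!     }
+//!     // A bump allocator never reclaims, so dealloc is a no-op...
+//!     fn dealloc(&self, _alloc: Allocation) -> Result<(), AllocError> { Ok(()) }
+//!     // ...and resize just takes a fresh block.
+//!     fn resize(&self, old: Allocation, new_len: usize) -> Result<Allocation, AllocError> {
+//!         self.dealloc(old)?;
+//!         self.alloc(new_len)
+//!     }
+//! }
+//! ```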
+
+use alloc::rc::Rc;
+use alloc::sync::Arc;
+
+use thiserror::Error;
+
+use super::access::MemOps;
+
+#[derive(Debug, Error, Copy, Clone)]
+pub enum AllocError {
+    #[error("Invalid region addr {0}")]
+    InvalidAlign(u64),
+    #[error("Invalid free addr {0} and size {1}")]
+    InvalidFree(u64, usize),
+    #[error("Invalid argument")]
+    InvalidArg,
+    #[error("Empty region")]
+    EmptyRegion,
+    #[error("No space available")]
+    NoSpace,
+    #[error("Requested size exceeds pool capacity")]
+    OutOfMemory,
+    #[error("Overflow")]
+    Overflow,
+}
+
+/// Allocation result
+#[derive(Debug, Clone, Copy)]
+pub struct Allocation {
+    /// Starting address of the allocation
+    pub addr: u64,
+    /// Length of the allocation in bytes rounded up to slab size
+    pub len: usize,
+}
+
+/// Trait for buffer providers.
+pub trait BufferProvider {
+    /// Allocate at least `len` bytes.
+    fn alloc(&self, len: usize) -> Result<Allocation, AllocError>;
+
+    /// Free a previously allocated block.
+    fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError>;
+
+    /// Resize by trying in-place grow; otherwise reserve a new block and free old.
+    fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result<Allocation, AllocError>;
+
+    /// Reset the pool to initial state.
+    fn reset(&self) {}
+}
+
+impl<T: BufferProvider> BufferProvider for Rc<T> {
+    fn alloc(&self, len: usize) -> Result<Allocation, AllocError> {
+        (**self).alloc(len)
+    }
+    fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError> {
+        (**self).dealloc(alloc)
+    }
+    fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result<Allocation, AllocError> {
+        (**self).resize(old_alloc, new_len)
+    }
+    fn reset(&self) {
+        (**self).reset()
+    }
+}
+
+impl<T: BufferProvider> BufferProvider for Arc<T> {
+    fn alloc(&self, len: usize) -> Result<Allocation, AllocError> {
+        (**self).alloc(len)
+    }
+    fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError> {
+        (**self).dealloc(alloc)
+    }
+    fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result<Allocation, AllocError> {
+        (**self).resize(old_alloc, new_len)
+    }
+    fn reset(&self) {
+        (**self).reset()
+    }
+}
+
+/// The owner of a mapped buffer, ensuring its lifetime.
+///
+/// Holds a [`PoolAlloc`] and provides direct access to the underlying
+/// shared memory via [`MemOps::as_slice`]. Implements `AsRef<[u8]>` so it
+/// can be used with [`Bytes::from_owner`](bytes::Bytes::from_owner) for
+/// zero-copy `Bytes` backed by shared memory.
+///
+/// When dropped, the allocation is returned to the pool.
+#[derive(Debug)]
+pub struct BufferOwner<P: BufferProvider, M: MemOps> {
+    pub(crate) alloc: PoolAlloc<P>,
+    pub(crate) mem: M,
+    pub(crate) written: usize,
+}
+
+impl<P: BufferProvider, M: MemOps> AsRef<[u8]> for BufferOwner<P, M> {
+    fn as_ref(&self) -> &[u8] {
+        let alloc = self.alloc.allocation();
+        let len = self.written.min(alloc.len);
+        // Safety: BufferOwner keeps both the pool allocation and the M alive,
+        // so the memory region is valid.
+        match unsafe { self.mem.as_slice(alloc.addr, len) } {
+            Ok(slice) => slice,
+            Err(_) => {
+                debug_assert!(false, "BufferOwner direct slice failed");
+                &[]
+            }
+        }
+    }
+}
+
+/// Pool-owned allocation that is returned to the pool on drop.
+///
+/// Use [`into_raw`](Self::into_raw) to transfer ownership to a descriptor
+/// state that will deallocate the raw [`Allocation`] through another path.
+#[derive(Debug)]
+pub struct PoolAlloc<P: BufferProvider> {
+    inner: Option<PoolAllocInner<P>>,
+}
+
+#[derive(Debug)]
+struct PoolAllocInner<P: BufferProvider> {
+    pool: P,
+    alloc: Allocation,
+}
+
+impl<P: BufferProvider> PoolAlloc<P> {
+    /// Wrap an existing allocation with its owning pool.
+    pub fn new(pool: P, alloc: Allocation) -> Self {
+        Self {
+            inner: Some(PoolAllocInner { pool, alloc }),
+        }
+    }
+
+    /// Allocate from `pool` and return an owning guard.
+    pub fn allocate(pool: P, len: usize) -> Result<Self, AllocError> {
+        let alloc = pool.alloc(len)?;
+        Ok(Self::new(pool, alloc))
+    }
+
+    /// The raw allocation currently owned by this guard.
+    pub fn allocation(&self) -> Allocation {
+        self.inner
+            .as_ref()
+            .map(|inner| inner.alloc)
+            .unwrap_or_else(|| {
+                unreachable!("PoolAlloc::allocation called after ownership transfer")
+            })
+    }
+
+    /// Release ownership and return the raw allocation.
+    pub fn into_raw(mut self) -> Allocation {
+        self.inner
+            .take()
+            .map(|inner| inner.alloc)
+            .unwrap_or_else(|| unreachable!("PoolAlloc::into_raw called after ownership transfer"))
+    }
+
+    pub(crate) fn into_buffer_owner<M: MemOps>(
+        self,
+        mem: M,
+        written: usize,
+    ) -> Result<BufferOwner<P, M>, M::Error> {
+        let alloc = self.allocation();
+        let len = written.min(alloc.len);
+        let _ = unsafe { mem.as_slice(alloc.addr, len) }?;
+
+        Ok(BufferOwner {
+            alloc: self,
+            mem,
+            written,
+        })
+    }
+}
+
+impl<P: BufferProvider> Drop for PoolAlloc<P>
{ + fn drop(&mut self) { + if let Some(PoolAllocInner { pool, alloc }) = self.inner.take() { + let _ = pool.dealloc(alloc); + } + } +} diff --git a/src/hyperlight_common/src/virtq/consumer.rs b/src/hyperlight_common/src/virtq/consumer.rs new file mode 100644 index 000000000..d3da1020c --- /dev/null +++ b/src/hyperlight_common/src/virtq/consumer.rs @@ -0,0 +1,680 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +use alloc::vec; + +use bytes::Bytes; +use fixedbitset::FixedBitSet; + +use super::*; + +/// Data received from the producer, safely copied out of shared memory. +/// +/// Created by [`VirtqConsumer::poll`]. The entry data is eagerly copied +/// from shared memory during poll using [`MemOps::read`] (volatile on +/// the host side), so accessing it requires no unsafe code and no +/// references into shared memory. +#[derive(Debug, Clone)] +pub struct RecvEntry { + token: Token, + data: Bytes, +} + +impl RecvEntry { + /// The token identifying this entry. + pub fn token(&self) -> Token { + self.token + } + + /// The entry payload, copied from shared memory. + /// + /// Returns empty [`Bytes`] when the chain has no readable buffers. + pub fn data(&self) -> &Bytes { + &self.data + } + + /// Consume the entry, taking ownership of the data. + pub fn into_data(self) -> Bytes { + self.data + } +} + +/// A pending completion, either writable or ack-only. +/// +/// Created by [`VirtqConsumer::poll`]. Must be submitted back via +/// [`VirtqConsumer::complete`] to release the descriptor. +#[must_use = "dropping without completing leaks the descriptor"] +pub enum SendCompletion { + /// Completion with a writable buffer (for chains with a completion buffer). + /// Use the `write*` methods on [`WritableCompletion`] to fill the + /// response buffer. + Writable(WritableCompletion), + /// Ack-only completion (for chains with only entry buffers). No response buffer. + /// Just pass back to [`VirtqConsumer::complete`] to acknowledge. + Ack(AckCompletion), +} + +impl SendCompletion { + /// The token identifying this completion. + pub fn token(&self) -> Token { + match self { + SendCompletion::Writable(wc) => wc.token(), + SendCompletion::Ack(ack) => ack.token(), + } + } + + /// Number of bytes written (0 for Ack). + pub fn written(&self) -> usize { + match self { + SendCompletion::Writable(wc) => wc.written(), + SendCompletion::Ack(_) => 0, + } + } + + fn id(&self) -> u16 { + match self { + SendCompletion::Writable(wc) => wc.id, + SendCompletion::Ack(ack) => ack.id, + } + } +} + +/// A completion with a writable buffer for response data. 
+/// +/// # Example +/// +/// ```ignore +/// if let SendCompletion::Writable(mut wc) = completion { +/// wc.write_all(b"response data")?; +/// consumer.complete(wc.into())?; +/// } +/// ``` +#[must_use = "dropping without completing leaks the descriptor"] +pub struct WritableCompletion { + mem: M, + id: u16, + token: Token, + elem: BufferElement, + written: usize, +} + +impl WritableCompletion { + fn new(mem: M, id: u16, token: Token, elem: BufferElement) -> Self { + Self { + mem, + id, + token, + elem, + written: 0, + } + } + + /// The token identifying this completion. + pub fn token(&self) -> Token { + self.token + } + + /// Total capacity of the completion buffer in bytes. + pub fn capacity(&self) -> usize { + self.elem.len as usize + } + + /// Number of bytes written so far. + pub fn written(&self) -> usize { + self.written + } + + /// Remaining writable capacity. + pub fn remaining(&self) -> usize { + self.capacity() - self.written + } + + /// Write bytes into the completion buffer, returning how many were written. + /// + /// Appends at the current write position. If `buf` is larger than the + /// remaining capacity, writes as many bytes as will fit (partial write). + /// + /// Returns the number of bytes actually written. + /// + /// # Errors + /// + /// - [`VirtqError::MemoryWriteError`] - underlying MemOps write failed + pub fn write(&mut self, buf: &[u8]) -> Result { + let to_write = buf.len().min(self.remaining()); + if to_write == 0 { + return Ok(0); + } + + let addr = self.elem.addr + self.written as u64; + self.mem + .write(addr, &buf[..to_write]) + .map_err(|_| VirtqError::MemoryWriteError)?; + + self.written += to_write; + Ok(to_write) + } + + /// Write the entire buffer or return an error. + /// + /// # Errors + /// + /// - [`VirtqError::CqeTooLarge`] - buf exceeds remaining capacity + /// - [`VirtqError::MemoryWriteError`] - underlying MemOps write failed + pub fn write_all(&mut self, buf: &[u8]) -> Result<(), VirtqError> { + if buf.len() > self.remaining() { + return Err(VirtqError::CqeTooLarge); + } + + let addr = self.elem.addr + self.written as u64; + self.mem + .write(addr, buf) + .map_err(|_| VirtqError::MemoryWriteError)?; + + self.written += buf.len(); + Ok(()) + } + + /// Reset the write cursor to the beginning. + /// + /// Previously written bytes in shared memory are not zeroed; the + /// `written` count is simply reset to 0. + pub fn reset(&mut self) { + self.written = 0; + } +} + +/// An ack-only completion for chains with no writable buffers. +/// +/// No response buffer - just pass back to [`VirtqConsumer::complete`] +/// to acknowledge processing and release the descriptor. +#[must_use = "dropping without completing leaks the descriptor"] +pub struct AckCompletion { + id: u16, + token: Token, +} + +impl AckCompletion { + fn new(id: u16, token: Token) -> Self { + Self { id, token } + } + + pub fn token(&self) -> Token { + self.token + } +} + +/// A high-level virtqueue consumer (device side). +/// +/// The consumer receives entries from the producer (driver), processes them, +/// and sends back completions. This is typically used on the device/host side. +/// +/// # Example +/// +/// ```ignore +/// let mut consumer = VirtqConsumer::new(layout, mem, notifier); +/// +/// // Poll and process +/// while let Some((entry, completion)) = consumer.poll(MAX_ENTRY_SIZE)? 
{ +/// let data = entry.data(); +/// match completion { +/// SendCompletion::Writable(mut wc) => { +/// let response = handle_request(data); +/// wc.write_all(&response)?; +/// consumer.complete(wc.into())?; +/// } +/// SendCompletion::Ack(ack) => { +/// consumer.complete(ack.into())?; +/// } +/// } +/// } +/// +/// // Or defer completions +/// let mut pending = Vec::new(); +/// while let Some((entry, completion)) = consumer.poll(MAX_ENTRY_SIZE)? { +/// pending.push((process(entry), completion)); +/// } +/// for (result, completion) in pending { +/// // ... complete later ... +/// consumer.complete(completion)?; +/// } +/// ``` +pub struct VirtqConsumer { + inner: RingConsumer, + notifier: N, + inflight: FixedBitSet, + next_token: u32, +} + +impl VirtqConsumer { + /// Create a new virtqueue consumer. + /// + /// # Arguments + /// + /// * `layout` - Ring memory layout (descriptor table and event suppression addresses) + /// * `mem` - Memory operations implementation for reading/writing to shared memory + /// * `notifier` - Callback for notifying the driver (producer) about completions + pub fn new(layout: Layout, mem: M, notifier: N) -> Self { + let inner = RingConsumer::new(layout, mem); + let inflight = FixedBitSet::with_capacity(inner.len()); + + Self { + inner, + notifier, + inflight, + next_token: 0, + } + } + + /// Poll for a single incoming entry from the driver. + /// + /// Returns a [`RecvEntry`] (data copied from shared memory) and a + /// [`SendCompletion`] (writable handle or ack token). Both are + /// independent owned values with no borrow on the consumer. + /// + /// # Arguments + /// + /// * `max_entry` - Maximum entry size to accept. Entries larger than + /// this will return [`VirtqError::EntryTooLarge`]. + /// + /// # Errors + /// + /// - [`VirtqError::EntryTooLarge`] - Entry data exceeds `max_entry` bytes + /// - [`VirtqError::BadChain`] - Descriptor chain format not recognized + /// - [`VirtqError::InvalidState`] - Descriptor ID collision (driver bug) + /// - [`VirtqError::MemoryReadError`] - Failed to read entry from shared memory + pub fn poll( + &mut self, + max_entry: usize, + ) -> Result)>, VirtqError> { + let (id, chain) = match self.inner.poll_available() { + Ok(x) => x, + Err(RingError::WouldBlock) => return Ok(None), + Err(e) => return Err(e.into()), + }; + + let (entry_elem, cqe_elem) = parse_chain(&chain)?; + + // Validate entry size + if let Some(ref elem) = entry_elem + && elem.len as usize > max_entry + { + return Err(VirtqError::EntryTooLarge); + } + + // Reserve the inflight slot + let id_idx = id as usize; + if id_idx >= self.inflight.len() { + return Err(VirtqError::InvalidState); + } + + if self.inflight.contains(id_idx) { + return Err(VirtqError::InvalidState); + } + + self.inflight.insert(id_idx); + let token = Token(self.next_token, id); + self.next_token = self.next_token.wrapping_add(1); + + // Copy entry data from shared memory + let data = match entry_elem.map(|elem| self.read_element(&elem)).transpose() { + Ok(d) => d.unwrap_or_default(), + Err(e) => { + // Read failed - clear inflight before propagating + self.inflight.set(id_idx, false); + return Err(e); + } + }; + + let entry = RecvEntry { token, data }; + + // Build the appropriate completion handle + let completion = if let Some(elem) = cqe_elem { + let mem = self.inner.mem().clone(); + let cqe = WritableCompletion::new(mem, id, token, elem); + SendCompletion::Writable(cqe) + } else { + let ack = AckCompletion::new(id, token); + SendCompletion::Ack(ack) + }; + + Ok(Some((entry, 
completion))) + } + + /// Submit a completed entry back to the ring. + /// + /// Accepts both [`WritableCompletion`] (with written byte count) and + /// [`AckCompletion`] (zero-length) via the [`SendCompletion`] enum. + /// Clears the inflight slot and notifies the producer if event + /// suppression allows. + pub fn complete(&mut self, completion: SendCompletion) -> Result<(), VirtqError> { + let id = completion.id(); + let written = completion.written() as u32; + + let id_idx = id as usize; + let slot_set = id_idx < self.inflight.len() && self.inflight.contains(id_idx); + if !slot_set { + return Err(VirtqError::InvalidState); + } + + self.inflight.set(id_idx, false); + + if self.inner.submit_used_with_notify(id, written)? { + self.notifier.notify(QueueStats { + num_free: self.inner.num_free(), + num_inflight: self.inner.num_inflight(), + }); + } + + Ok(()) + } + + /// Get the current available cursor position. + /// + /// Returns the position where the next available descriptor will be + /// consumed. Useful for setting up descriptor-based event suppression. + #[inline] + pub fn avail_cursor(&self) -> RingCursor { + self.inner.avail_cursor() + } + + /// Get the current used cursor position. + /// + /// Returns the position where the next used descriptor will be written. + /// Useful for setting up descriptor-based event suppression. + #[inline] + pub fn used_cursor(&self) -> RingCursor { + self.inner.used_cursor() + } + + /// Configure event suppression for available buffer notifications. + /// + /// This controls when the driver (producer) signals us about new buffers: + /// + /// - [`SuppressionKind::Enable`] - Always signal (default) - good for latency + /// - [`SuppressionKind::Disable`] - Never signal - caller must poll + /// - [`SuppressionKind::Descriptor`] - Signal only at specific cursor position + /// + /// # Example: Polling Mode + /// ```ignore + /// consumer.set_avail_suppression(SuppressionKind::Disable)?; + /// loop { + /// while let Some((entry, completion)) = consumer.poll(1024)? { + /// process(entry, completion); + /// } + /// // ... do other work ... + /// } + /// ``` + pub fn set_avail_suppression(&mut self, kind: SuppressionKind) -> Result<(), VirtqError> { + match kind { + SuppressionKind::Enable => self.inner.enable_avail_notifications()?, + SuppressionKind::Disable => self.inner.disable_avail_notifications()?, + SuppressionKind::Descriptor(cursor) => self + .inner + .enable_avail_notifications_desc(cursor.head(), cursor.wrap())?, + } + Ok(()) + } + + /// Read a buffer element from shared memory into `Bytes`. + fn read_element(&self, elem: &BufferElement) -> Result { + let mut buf = vec![0u8; elem.len as usize]; + self.inner + .mem() + .read(elem.addr, &mut buf) + .map_err(|_| VirtqError::MemoryReadError)?; + + Ok(Bytes::from(buf)) + } + + /// Reset ring and inflight state to initial values. + pub fn reset(&mut self) { + self.inner.reset(); + self.inflight.clear(); + } +} + +/// Parse a descriptor chain into entry/completion buffer elements. +/// +/// Returns `(entry_element, completion_element)`. 
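+///
+/// Accepted shapes: one readable and one writable element (entry with a
+/// completion buffer), writable only (completion-only), or readable only
+/// (ack-only). Any other combination is rejected as [`VirtqError::BadChain`].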
+fn parse_chain( + chain: &BufferChain, +) -> Result<(Option, Option), VirtqError> { + let r = chain.readables(); + let w = chain.writables(); + + match (r.len(), w.len()) { + (1, 1) => Ok((Some(r[0]), Some(w[0]))), + (0, 1) => Ok((None, Some(w[0]))), + (1, 0) => Ok((Some(r[0]), None)), + _ => Err(VirtqError::BadChain), + } +} + +impl From> for SendCompletion { + fn from(wc: WritableCompletion) -> Self { + SendCompletion::Writable(wc) + } +} + +impl From for SendCompletion { + fn from(ack: AckCompletion) -> Self { + SendCompletion::Ack(ack) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::virtq::ring::tests::make_ring; + use crate::virtq::test_utils::*; + + #[test] + fn test_write_only_entry_is_empty() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let se = producer.chain().completion(16).build().unwrap(); + producer.submit(se).unwrap(); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert!(entry.data().is_empty()); + assert!(matches!(completion, SendCompletion::Writable(_))); + + if let SendCompletion::Writable(mut wc) = completion { + wc.write_all(b"response").unwrap(); + consumer.complete(wc.into()).unwrap(); + } + } + + #[test] + fn test_read_only_ack_completion() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let mut se = producer.chain().entry(16).build().unwrap(); + se.write_all(b"hello").unwrap(); + producer.submit(se).unwrap(); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.data().as_ref(), b"hello"); + assert!(matches!(completion, SendCompletion::Ack(_))); + + consumer.complete(completion).unwrap(); + } + + #[test] + fn test_readwrite_round_trip() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let mut se = producer.chain().entry(32).completion(64).build().unwrap(); + se.write_all(b"hello world").unwrap(); + producer.submit(se).unwrap(); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.data().as_ref(), b"hello world"); + + if let SendCompletion::Writable(mut wc) = completion { + assert_eq!(wc.capacity(), 64); + assert_eq!(wc.written(), 0); + assert_eq!(wc.remaining(), 64); + wc.write_all(b"response").unwrap(); + assert_eq!(wc.written(), 8); + assert_eq!(wc.remaining(), 56); + consumer.complete(wc.into()).unwrap(); + } else { + panic!("expected Writable completion for entry+completion chain"); + } + } + + #[test] + fn test_writable_partial_write() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let se = producer.chain().completion(8).build().unwrap(); + producer.submit(se).unwrap(); + + let (_entry, completion) = consumer.poll(1024).unwrap().unwrap(); + + if let SendCompletion::Writable(mut wc) = completion { + let n = wc.write(b"hello world!").unwrap(); + assert_eq!(n, 8); + assert_eq!(wc.remaining(), 0); + consumer.complete(wc.into()).unwrap(); + } else { + panic!("expected Writable"); + } + } + + #[test] + fn test_writable_write_all_too_large() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let se = producer.chain().completion(4).build().unwrap(); + producer.submit(se).unwrap(); + let (_entry, completion) = consumer.poll(1024).unwrap().unwrap(); + + if let SendCompletion::Writable(mut wc) = completion { + let err = wc.write_all(b"too long").unwrap_err(); + 
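+            // Unlike write(), which truncates to the remaining capacity,
+            // write_all() is all-or-nothing.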
assert!(matches!(err, VirtqError::CqeTooLarge)); + } else { + panic!("expected Writable"); + } + } + + #[test] + fn test_writable_reset() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let se = producer.chain().completion(16).build().unwrap(); + producer.submit(se).unwrap(); + + let (_entry, completion) = consumer.poll(1024).unwrap().unwrap(); + + if let SendCompletion::Writable(mut wc) = completion { + wc.write_all(b"first").unwrap(); + assert_eq!(wc.written(), 5); + wc.reset(); + assert_eq!(wc.written(), 0); + assert_eq!(wc.remaining(), 16); + wc.write_all(b"second").unwrap(); + assert_eq!(wc.written(), 6); + consumer.complete(wc.into()).unwrap(); + } else { + panic!("expected Writable"); + } + } + + #[test] + fn test_multiple_pending_completions() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let se1 = producer.chain().completion(16).build().unwrap(); + producer.submit(se1).unwrap(); + let se2 = producer.chain().completion(16).build().unwrap(); + producer.submit(se2).unwrap(); + + let (_e1, c1) = consumer.poll(1024).unwrap().unwrap(); + let (_e2, c2) = consumer.poll(1024).unwrap().unwrap(); + + // Complete in reverse order + consumer.complete(c2).unwrap(); + consumer.complete(c1).unwrap(); + } + + #[test] + fn test_entry_into_data() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let mut se = producer.chain().entry(16).build().unwrap(); + se.write_all(b"abc").unwrap(); + producer.submit(se).unwrap(); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + let data = entry.into_data(); + assert_eq!(data.as_ref(), b"abc"); + consumer.complete(completion).unwrap(); + } + + #[test] + fn test_virtq_consumer_reset() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + // Submit and poll (but do not complete) + let se = producer.chain().completion(16).build().unwrap(); + producer.submit(se).unwrap(); + + let (_entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert!(consumer.inflight.count_ones(..) > 0); + + // Complete first so we do not leak + consumer.complete(completion).unwrap(); + + consumer.reset(); + + assert_eq!(consumer.inflight.count_ones(..), 0); + assert_eq!(consumer.inner.num_inflight(), 0); + } + + #[test] + fn test_virtq_consumer_reset_clears_inflight() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + // Submit two entries and poll both + let se1 = producer.chain().completion(16).build().unwrap(); + producer.submit(se1).unwrap(); + let se2 = producer.chain().completion(16).build().unwrap(); + producer.submit(se2).unwrap(); + + let (_e1, c1) = consumer.poll(1024).unwrap().unwrap(); + let (_e2, c2) = consumer.poll(1024).unwrap().unwrap(); + // Complete both before reset + consumer.complete(c1).unwrap(); + consumer.complete(c2).unwrap(); + + consumer.reset(); + + assert_eq!(consumer.inflight.count_ones(..), 0); + assert_eq!(consumer.inner.num_inflight(), 0); + } +} diff --git a/src/hyperlight_common/src/virtq/mod.rs b/src/hyperlight_common/src/virtq/mod.rs index 326aac933..ed6d7f2cb 100644 --- a/src/hyperlight_common/src/virtq/mod.rs +++ b/src/hyperlight_common/src/virtq/mod.rs @@ -14,14 +14,20 @@ See the License for the specific language governing permissions and limitations under the License. */ -//! Packed Virtqueue - Ring Primitives +//! 
Packed Virtqueue Implementation //! -//! This module provides low-level ring primitives for virtio packed virtqueues, -//! implementing the VIRTIO 1.1+ packed ring format with proper memory ordering -//! and event suppression support. +//! This module provides a high-level API for virtio packed virtqueues, built on top of +//! the lower-level ring primitives. It implements the VIRTIO 1.1+ packed ring format +//! with proper memory ordering and event suppression support. //! //! # Architecture //! +//! The implementation is split into layers: +//! +//! - **High-level API** ([`VirtqProducer`], [`VirtqConsumer`]): Manages buffer allocation, +//! entry/completion lifecycle, and notification decisions. This is the recommended API +//! for most use cases. +//! //! - **Ring primitives** ([`RingProducer`], [`RingConsumer`]): Low-level descriptor ring //! operations with explicit buffer chain management. Use this when you need full control //! over buffer layouts or custom allocation strategies. @@ -29,11 +35,112 @@ limitations under the License. //! - **Descriptor and event types** ([`Descriptor`], [`EventSuppression`]): Raw virtio //! data structures for direct memory manipulation. //! -//! - **Memory access** ([`MemOps`]): Trait abstracting memory read/write operations, -//! allowing the ring to work with different memory backends (host vs guest). +//! # Quick Start +//! +//! ## Single Entry/Completion +//! +//! ```ignore +//! // Producer (driver) side - build entry, submit, get completion +//! let mut entry = producer.chain() +//! .entry(64) +//! .completion(128) +//! .build()?; +//! entry.write_all(b"entry data")?; +//! let token = producer.submit(entry)?; +//! // ... wait for notification ... +//! if let Some(completion) = producer.poll()? { +//! process(completion.data); +//! } +//! +//! // Consumer (device) side - receive entry, send completion +//! if let Some((entry, completion)) = consumer.poll(max_request_size)? { +//! let request = entry.data(); +//! match completion { +//! SendCompletion::Writable(mut wc) => { +//! let response = handle(request); +//! wc.write_all(&response)?; +//! consumer.complete(wc.into())?; +//! } +//! SendCompletion::Ack(ack) => { +//! consumer.complete(ack.into())?; +//! } +//! } +//! } +//! +//! // Multiple pending completions (no borrow on consumer) +//! let mut pending = Vec::new(); +//! while let Some((entry, completion)) = consumer.poll(max_request_size)? { +//! pending.push((process(entry), completion)); +//! } +//! for (result, completion) in pending { +//! consumer.complete(completion)?; +//! } +//! ``` +//! +//! ## Multiple Entries +//! +//! Each submit checks event suppression and notifies independently. Use +//! [`VirtqProducer::batch`] when a higher-level protocol wants to publish +//! multiple entries and kick the queue once. +//! +//! ```ignore +//! let mut batch = producer.batch(); +//! for data in entries { +//! let mut se = batch.chain() +//! .entry(data.len()) +//! .completion(64) +//! .build()?; +//! se.write_all(data)?; +//! batch.submit(se)?; +//! } +//! batch.finish()?; +//! ``` +//! +//! ## Completion Batching with Event Suppression +//! +//! To receive a single notification when multiple requests complete: +//! +//! ```ignore +//! // Submit entries +//! for data in entries { +//! let mut se = producer.chain() +//! .entry(data.len()) +//! .completion(64) +//! .build()?; +//! se.write_all(data)?; +//! producer.submit(se)?; +//! } +//! +//! // Tell device: "notify me only after completing past this cursor" +//! 
let cursor = producer.used_cursor(); +//! producer.set_used_suppression(SuppressionKind::Descriptor(cursor))?; +//! +//! // Wait for single notification, then drain all responses +//! producer.drain(|token, data| { +//! handle_response(token, data); +//! })?; +//! ``` +//! +//! # Event Suppression +//! +//! Both sides can control when they want to be notified using [`SuppressionKind`]: +//! +//! - [`SuppressionKind::Enable`]: Always notify (default, lowest latency) +//! - [`SuppressionKind::Disable`]: Never notify (polling mode, lowest overhead) +//! - [`SuppressionKind::Descriptor`]: Notify at specific ring position (batching) +//! +//! See [`VirtqProducer::set_used_suppression`] and [`VirtqConsumer::set_avail_suppression`]. //! //! # Low-Level API //! +//! For advanced use cases, the ring module exposes lower-level primitives: +//! +//! - [`RingProducer`] / [`RingConsumer`]: Direct ring access with [`BufferChain`] submission +//! - [`BufferChainBuilder`]: Construct scatter-gather buffer lists +//! - [`RingCursor`]: Track ring positions for event suppression +//! +//! Example using low-level API: +//! //! ```ignore //! let chain = BufferChainBuilder::new() //! .readable(header_addr, header_len) @@ -48,16 +155,87 @@ limitations under the License. //! ``` mod access; +mod buffer; +mod consumer; mod desc; mod event; +pub mod msg; +mod pool; +mod producer; mod ring; use core::num::NonZeroU16; pub use access::*; +pub use buffer::*; +pub use consumer::*; pub use desc::*; pub use event::*; +pub use pool::*; +pub use producer::*; pub use ring::*; +use thiserror::Error; + +/// A trait for notifying the consumer about virtqueue events. +pub trait Notifier { + fn notify(&self, stats: QueueStats); +} + +/// Errors that can occur in the virtqueue operations. +#[derive(Error, Debug)] +pub enum VirtqError { + #[error("Ring error: {0}")] + RingError(RingError), + #[error("Allocation error: {0}")] + Alloc(AllocError), + #[error("Ring or pool temporarily full")] + Backpressure, + #[error("Allocation exceeds pool capacity")] + OutOfMemory, + #[error("Invalid token")] + BadToken, + #[error("Invalid chain received")] + BadChain, + #[error("Entry data too large for allocated buffer")] + EntryTooLarge, + #[error("Completion data too large for allocated buffer")] + CqeTooLarge, + #[error("Internal state error")] + InvalidState, + #[error("Memory write error")] + MemoryWriteError, + #[error("Memory read error")] + MemoryReadError, + #[error("No readable buffer in this entry")] + NoReadableBuffer, +} + +impl VirtqError { + /// Check if this error is transient or unrecoverable. + #[inline(always)] + pub fn is_transient(&self) -> bool { + matches!(self, Self::Backpressure) + } +} + +impl From for VirtqError { + fn from(e: RingError) -> Self { + match e { + RingError::WouldBlock => Self::Backpressure, + other => Self::RingError(other), + } + } +} + +impl From for VirtqError { + fn from(e: AllocError) -> Self { + match e { + AllocError::NoSpace => Self::Backpressure, + AllocError::OutOfMemory => Self::OutOfMemory, + other => Self::Alloc(other), + } + } +} /// Layout of a packed virtqueue ring in shared memory. /// @@ -166,6 +344,49 @@ impl Layout { } } +/// Statistics about the current virtqueue state. +/// +/// Provided to the [`Notifier`] when sending notifications, allowing +/// the notifier to make decisions based on queue pressure. +#[derive(Debug, Clone, Copy)] +pub struct QueueStats { + /// Number of free descriptor slots available. 
+ pub num_free: usize, + /// Number of descriptors currently in-flight (submitted but not completed). + pub num_inflight: usize, +} + +/// Event suppression mode for controlling when notifications are sent. +/// +/// This configures when the other side should signal (interrupt/kick) us +/// about new data. Used to optimize batching and reduce interrupt overhead. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SuppressionKind { + /// Always signal after each operation (default behavior). + Enable, + /// Never signal. + Disable, + /// Signal only when reaching a specific descriptor position. + Descriptor(RingCursor), +} + +/// A token representing a sent entry in the virtqueue. +/// +/// Tokens uniquely identify in-flight requests and are used to correlate requests with their responses. +/// The first element is a monotonically increasing generation counter. The second element is the +/// underlying descriptor ID +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct Token(pub u32, pub u16); + +impl From for Allocation { + fn from(value: BufferElement) -> Self { + Allocation { + addr: value.addr, + len: value.len as usize, + } + } +} + const _: () = { #[allow(clippy::unwrap_used)] const fn verify_layout(num_descs: usize) { @@ -219,3 +440,979 @@ const _: () = { verify_layout(512); verify_layout(1024); }; + +/// Shared test utilities for virtqueue tests. +#[cfg(test)] +pub(crate) mod test_utils { + use alloc::sync::Arc; + use core::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; + + use super::*; + use crate::virtq::ring::tests::{OwnedRing, TestMem}; + + /// Simple notifier that tracks notification count. + #[derive(Debug, Clone)] + pub(crate) struct TestNotifier { + pub(crate) count: Arc, + } + + impl TestNotifier { + pub(crate) fn new() -> Self { + Self { + count: Arc::new(AtomicUsize::new(0)), + } + } + + pub(crate) fn notification_count(&self) -> usize { + self.count.load(Ordering::Relaxed) + } + } + + impl Notifier for TestNotifier { + fn notify(&self, _stats: QueueStats) { + self.count.fetch_add(1, Ordering::Relaxed); + } + } + + /// Simple test buffer pool that allocates from a range. + #[derive(Clone)] + pub(crate) struct TestPool { + base: u64, + next: Arc, + size: usize, + } + + impl TestPool { + pub(crate) fn new(base: u64, size: usize) -> Self { + Self { + base, + next: Arc::new(AtomicU64::new(base)), + size, + } + } + } + + impl BufferProvider for TestPool { + fn alloc(&self, len: usize) -> Result { + let addr = self.next.fetch_add(len as u64, Ordering::Relaxed); + let end = addr + len as u64; + if end > self.base + self.size as u64 { + return Err(AllocError::NoSpace); + } + Ok(Allocation { addr, len }) + } + + fn dealloc(&self, _alloc: Allocation) -> Result<(), AllocError> { + // Simple pool doesn't track individual allocations + Ok(()) + } + + fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result { + // Simple implementation: always allocate new + self.dealloc(old_alloc)?; + self.alloc(new_len) + } + } + + type TestProducer = VirtqProducer; + type TestConsumer = VirtqConsumer; + + /// Create test infrastructure: a producer, consumer, and notifier backed + /// by the supplied [`OwnedRing`]. 
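+    /// The buffer pool is carved out of the same test memory, just past the
+    /// ring layout.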
+ pub(crate) fn make_test_producer( + ring: &OwnedRing, + ) -> (TestProducer, TestConsumer, TestNotifier) { + let layout = ring.layout(); + let mem = ring.mem(); + + // Pool needs to be in memory accessible via mem - use memory after ring layout + let pool_base = mem.base_addr() + Layout::query_size(ring.len()) as u64 + 0x100; + let pool = TestPool::new(pool_base, 0x8000); + let notifier = TestNotifier::new(); + + let producer = VirtqProducer::new(layout, mem.clone(), notifier.clone(), pool); + let consumer = VirtqConsumer::new(layout, mem, notifier.clone()); + + (producer, consumer, notifier) + } +} + +#[cfg(test)] +mod tests { + use alloc::sync::Arc; + use core::sync::atomic::{AtomicUsize, Ordering}; + + use super::*; + use crate::virtq::ring::tests::{TestMem, make_ring}; + use crate::virtq::test_utils::*; + + /// Helper: build and submit an entry+completion chain using the chain() builder. + fn send_readwrite( + producer: &mut VirtqProducer, + entry_data: &[u8], + cqe_cap: usize, + ) -> Token { + let mut se = producer + .chain() + .entry(entry_data.len()) + .completion(cqe_cap) + .build() + .unwrap(); + se.write_all(entry_data).unwrap(); + producer.submit(se).unwrap() + } + + #[test] + fn test_submit_notifies() { + let ring = make_ring(16); + let (mut producer, mut consumer, notifier) = make_test_producer(&ring); + + let initial_count = notifier.notification_count(); + + let token = send_readwrite(&mut producer, b"hello", 64); + assert!(notifier.notification_count() > initial_count); + + let (entry, _completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.token(), token); + } + + #[test] + fn test_multiple_submits() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let tok1 = send_readwrite(&mut producer, b"request1", 64); + let tok2 = send_readwrite(&mut producer, b"request2", 64); + let tok3 = send_readwrite(&mut producer, b"request3", 64); + + // Consumer sees all requests + for _ in 0..3 { + let (_entry, completion) = consumer.poll(1024).unwrap().unwrap(); + consumer.complete(completion).unwrap(); + } + + // All completions available + let cqe1 = producer.poll().unwrap().unwrap(); + let cqe2 = producer.poll().unwrap().unwrap(); + let cqe3 = producer.poll().unwrap().unwrap(); + assert!( + [cqe1.token, cqe2.token, cqe3.token].contains(&tok1) + && [cqe1.token, cqe2.token, cqe3.token].contains(&tok2) + && [cqe1.token, cqe2.token, cqe3.token].contains(&tok3) + ); + } + + #[test] + fn test_completion_batching_with_suppression() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + // Submit entries + let tok1 = send_readwrite(&mut producer, b"req1", 64); + let tok2 = send_readwrite(&mut producer, b"req2", 64); + let tok3 = send_readwrite(&mut producer, b"req3", 64); + + // Set up completion batching via used suppression + let cursor = producer.used_cursor(); + producer + .set_used_suppression(SuppressionKind::Descriptor(cursor)) + .unwrap(); + + // Consumer processes requests + for _ in 0..3 { + let (_entry, completion) = consumer.poll(1024).unwrap().unwrap(); + let SendCompletion::Writable(mut wc) = completion else { + panic!("expected writable completion"); + }; + wc.write_all(b"cqe-data").unwrap(); + consumer.complete(wc.into()).unwrap(); + } + + // Producer can drain all responses + let mut responses = Vec::new(); + producer + .drain(|tok, _data| { + responses.push(tok); + }) + .unwrap(); + + assert_eq!(responses.len(), 3); + 
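+        // Check membership rather than order: completions may be drained in
+        // any order relative to submission.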
assert!(responses.contains(&tok1)); + assert!(responses.contains(&tok2)); + assert!(responses.contains(&tok3)); + } + + #[test] + fn test_notifier_receives_context() { + #[derive(Debug, Clone)] + struct CtxNotifier { + last_num_free: Arc, + last_num_inflight: Arc, + count: Arc, + } + + impl Notifier for CtxNotifier { + fn notify(&self, stats: QueueStats) { + self.last_num_free.store(stats.num_free, Ordering::Relaxed); + self.last_num_inflight + .store(stats.num_inflight, Ordering::Relaxed); + self.count.fetch_add(1, Ordering::Relaxed); + } + } + + let ring = make_ring(16); + let layout = ring.layout(); + let mem = ring.mem(); + let pool_base = mem.base_addr() + Layout::query_size(ring.len()) as u64 + 0x100; + let pool = TestPool::new(pool_base, 0x8000); + let notifier = CtxNotifier { + last_num_free: Arc::new(AtomicUsize::new(0)), + last_num_inflight: Arc::new(AtomicUsize::new(0)), + count: Arc::new(AtomicUsize::new(0)), + }; + + let mut producer = VirtqProducer::new(layout, mem, notifier.clone(), pool); + + let mut se = producer.chain().entry(4).completion(32).build().unwrap(); + se.write_all(b"test").unwrap(); + producer.submit(se).unwrap(); + assert_eq!(notifier.count.load(Ordering::Relaxed), 1); + assert!(notifier.last_num_inflight.load(Ordering::Relaxed) > 0); + } + + #[test] + fn test_chain_zero_copy_batch() { + let ring = make_ring(16); + let (mut producer, mut consumer, notifier) = make_test_producer(&ring); + + let initial_count = notifier.notification_count(); + + // Zero-copy entry via buf_mut + let mut se1 = producer.chain().entry(64).completion(128).build().unwrap(); + let buf = se1.buf_mut().unwrap(); + buf[..6].copy_from_slice(b"zc-ent"); + se1.set_written(6).unwrap(); + let _tok1 = producer.submit(se1).unwrap(); + + // Write-based entry + let mut se2 = producer.chain().entry(64).completion(64).build().unwrap(); + se2.write_all(b"copy-ent").unwrap(); + let _tok2 = producer.submit(se2).unwrap(); + + // Completion-only chain + let se3 = producer.chain().completion(32).build().unwrap(); + let tok3 = producer.submit(se3).unwrap(); + + // Each submit may notify independently + assert!(notifier.notification_count() > initial_count); + + // Consumer sees all three entries + let (entry1, completion1) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry1.data().as_ref(), b"zc-ent"); + consumer.complete(completion1).unwrap(); + + let (entry2, completion2) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry2.data().as_ref(), b"copy-ent"); + consumer.complete(completion2).unwrap(); + + let (_entry3, completion3) = consumer.poll(1024).unwrap().unwrap(); + let SendCompletion::Writable(mut wc) = completion3 else { + panic!("expected writable completion"); + }; + wc.write_all(b"resp").unwrap(); + consumer.complete(wc.into()).unwrap(); + + // Drain completions + let _ = producer.poll().unwrap().unwrap(); + let _ = producer.poll().unwrap().unwrap(); + + let cqe = producer.poll().unwrap().unwrap(); + assert_eq!(cqe.token, tok3); + assert_eq!(&cqe.data[..], b"resp"); + } + + #[test] + fn test_chain_zero_copy_send() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + // Zero-copy send: allocate, write directly, submit + let mut se = producer.chain().entry(64).completion(128).build().unwrap(); + let buf = se.buf_mut().unwrap(); + assert_eq!(buf.len(), 64); + buf[..5].copy_from_slice(b"hello"); + se.set_written(5).unwrap(); + let token = producer.submit(se).unwrap(); + + // Consumer sees the data + let (entry, completion) = 
+        assert_eq!(entry.token(), token);
+        assert_eq!(entry.data().as_ref(), b"hello");
+
+        // Write response
+        let SendCompletion::Writable(mut wc) = completion else {
+            panic!("expected writable completion");
+        };
+        wc.write_all(b"world").unwrap();
+        consumer.complete(wc.into()).unwrap();
+        let cqe = producer.poll().unwrap().unwrap();
+        assert_eq!(&cqe.data[..], b"world");
+    }
+
+    #[test]
+    fn test_full_round_trip() {
+        let ring = make_ring(16);
+        let (mut producer, mut consumer, _notifier) = make_test_producer(&ring);
+
+        // Send an entry
+        let token = send_readwrite(&mut producer, b"round-trip-entry", 128);
+
+        // Consumer receives and responds
+        let (entry, completion) = consumer.poll(1024).unwrap().unwrap();
+        assert_eq!(entry.token(), token);
+        assert_eq!(entry.data().as_ref(), b"round-trip-entry");
+
+        let SendCompletion::Writable(mut wc) = completion else {
+            panic!("expected writable completion");
+        };
+        assert!(wc.capacity() >= 128);
+        wc.write_all(b"round-trip-rsp").unwrap();
+        consumer.complete(wc.into()).unwrap();
+
+        // Producer gets the completion
+        let cqe = producer.poll().unwrap().unwrap();
+        assert_eq!(cqe.token, token);
+        assert_eq!(&cqe.data[..], b"round-trip-rsp");
+    }
+
+    #[test]
+    fn test_cancel_submits_zero_length() {
+        let ring = make_ring(16);
+        let (mut producer, mut consumer, _notifier) = make_test_producer(&ring);
+
+        let token = send_readwrite(&mut producer, b"entry-data", 64);
+
+        let (_entry, completion) = consumer.poll(1024).unwrap().unwrap();
+        consumer.complete(completion).unwrap();
+
+        let cqe = producer.poll().unwrap().unwrap();
+        assert_eq!(cqe.token, token);
+        assert_eq!(cqe.data.len(), 0);
+        assert!(cqe.data.is_empty());
+    }
+
+    #[test]
+    fn test_hold_completion_and_complete() {
+        let ring = make_ring(16);
+        let (mut producer, mut consumer, _notifier) = make_test_producer(&ring);
+
+        let token = send_readwrite(&mut producer, b"deferred", 64);
+
+        // Poll and hold the completion
+        let (entry, completion) = consumer.poll(1024).unwrap().unwrap();
+        assert_eq!(entry.token(), token);
+        assert_eq!(entry.data().as_ref(), b"deferred");
+
+        let SendCompletion::Writable(mut wc) = completion else {
+            panic!("expected writable completion");
+        };
+        wc.write_all(b"deferred-cqe").unwrap();
+        consumer.complete(wc.into()).unwrap();
+
+        let cqe = producer.poll().unwrap().unwrap();
+        assert_eq!(cqe.token, token);
+        assert_eq!(&cqe.data[..], b"deferred-cqe");
+    }
+
+    #[test]
+    fn test_concurrent_pending_completions() {
+        let ring = make_ring(16);
+        let (mut producer, mut consumer, _notifier) = make_test_producer(&ring);
+
+        let tok1 = send_readwrite(&mut producer, b"first", 64);
+        let tok2 = send_readwrite(&mut producer, b"second", 64);
+
+        // Poll both
+        let (entry1, completion1) = consumer.poll(1024).unwrap().unwrap();
+        assert_eq!(entry1.token(), tok1);
+        assert_eq!(entry1.data().as_ref(), b"first");
+
+        let (entry2, completion2) = consumer.poll(1024).unwrap().unwrap();
+        assert_eq!(entry2.token(), tok2);
+        assert_eq!(entry2.data().as_ref(), b"second");
+
+        // Complete second first (out of order)
+        let SendCompletion::Writable(mut wc2) = completion2 else {
+            panic!("expected writable");
+        };
+        wc2.write_all(b"resp2").unwrap();
+        consumer.complete(wc2.into()).unwrap();
+
+        let SendCompletion::Writable(mut wc1) = completion1 else {
+            panic!("expected writable");
+        };
+        wc1.write_all(b"resp1").unwrap();
+        consumer.complete(wc1.into()).unwrap();
+
+        let cqe1 = producer.poll().unwrap().unwrap();
+        let cqe2 = producer.poll().unwrap().unwrap();
+        let mut responses: Vec<_> = vec![
+            (cqe1.token, cqe1.data.to_vec()),
+            (cqe2.token, cqe2.data.to_vec()),
+        ];
+        responses.sort_by_key(|(t, _)| t.0);
+
+        let expected_first = responses.iter().find(|(t, _)| *t == tok1).unwrap();
+        let expected_second = responses.iter().find(|(t, _)| *t == tok2).unwrap();
+        assert_eq!(&expected_first.1[..], b"resp1");
+        assert_eq!(&expected_second.1[..], b"resp2");
+    }
+
+    /// Helper: submit a ReadOnly entry (entry data, no completion).
+    fn send_readonly(
+        producer: &mut VirtqProducer<TestMem, TestNotifier, TestPool>,
+        entry_data: &[u8],
+    ) -> Token {
+        let mut se = producer.chain().entry(entry_data.len()).build().unwrap();
+        se.write_all(entry_data).unwrap();
+        producer.submit(se).unwrap()
+    }
+
+    #[test]
+    fn test_reclaim_frees_ring_slots() {
+        let ring = make_ring(4);
+        let (mut producer, mut consumer, _) = make_test_producer(&ring);
+
+        // Fill the ring with ReadOnly entries
+        send_readonly(&mut producer, b"a");
+        send_readonly(&mut producer, b"b");
+        send_readonly(&mut producer, b"c");
+        send_readonly(&mut producer, b"d");
+
+        // Ring is now full - next submit should fail with Backpressure
+        let mut se = producer.chain().entry(1).build().unwrap();
+        se.write_all(b"e").unwrap();
+        let res = producer.submit(se);
+        assert!(
+            matches!(res, Err(VirtqError::Backpressure)),
+            "expected Backpressure from full ring"
+        );
+
+        // Consumer acks all entries
+        while let Some((_, completion)) = consumer.poll(1024).unwrap() {
+            consumer.complete(completion).unwrap();
+        }
+
+        // Reclaim should free ring slots without losing data
+        let count = producer.reclaim().unwrap();
+        assert_eq!(count, 4, "expected 4 reclaimed entries");
+
+        // Ring should have space now
+        send_readonly(&mut producer, b"e");
+    }
+
+    #[test]
+    fn test_reclaim_buffers_rw_completions() {
+        let ring = make_ring(4);
+        let (mut producer, mut consumer, _) = make_test_producer(&ring);
+
+        // Submit a ReadWrite entry
+        let tok = send_readwrite(&mut producer, b"request", 64);
+
+        // Consumer processes and writes response
+        let (_, completion) = consumer.poll(1024).unwrap().unwrap();
+        let SendCompletion::Writable(mut wc) = completion else {
+            panic!("expected writable");
+        };
+        wc.write_all(b"response-data").unwrap();
+        consumer.complete(wc.into()).unwrap();
+
+        // Reclaim buffers the completion (doesn't discard it)
+        let count = producer.reclaim().unwrap();
+        assert_eq!(count, 1);
+
+        // poll() should return the buffered completion
+        let cqe = producer.poll().unwrap().unwrap();
+        assert_eq!(cqe.token, tok);
+        assert_eq!(&cqe.data[..], b"response-data");
+    }
+
+    #[test]
+    fn test_reclaim_discards_readonly_completions() {
+        let ring = make_ring(8);
+        let (mut producer, mut consumer, _) = make_test_producer(&ring);
+
+        // Submit 3 entries: RO, RW, RO
+        let _tok_ro1 = send_readonly(&mut producer, b"log1");
+        let tok_rw = send_readwrite(&mut producer, b"call", 64);
+        let _tok_ro2 = send_readonly(&mut producer, b"log2");
+
+        // Consumer processes all 3
+        let (_, c1) = consumer.poll(1024).unwrap().unwrap();
+        consumer.complete(c1).unwrap(); // ack RO
+
+        let (_, c2) = consumer.poll(1024).unwrap().unwrap();
+        let SendCompletion::Writable(mut wc) = c2 else {
+            panic!("expected writable");
+        };
+        wc.write_all(b"result").unwrap();
+        consumer.complete(wc.into()).unwrap(); // complete RW
+
+        let (_, c3) = consumer.poll(1024).unwrap().unwrap();
+        consumer.complete(c3).unwrap(); // ack RO
+
+        // Reclaim all 3 - RO completions are discarded, only RW is buffered
+        let count = producer.reclaim().unwrap();
+        assert_eq!(count, 3);
+
+        // poll() returns only the RW completion
+        let cqe = producer.poll().unwrap().unwrap();
+        assert_eq!(cqe.token, tok_rw);
+        assert_eq!(&cqe.data[..], b"result");
+
+        // No more - RO completions were discarded
+        assert!(producer.poll().unwrap().is_none());
+    }
+
+    #[test]
+    fn test_reclaim_mixed_with_poll() {
+        let ring = make_ring(8);
+        let (mut producer, mut consumer, _) = make_test_producer(&ring);
+
+        // Submit and complete 2 entries
+        send_readonly(&mut producer, b"x");
+        let tok_rw = send_readwrite(&mut producer, b"y", 64);
+
+        let (_, c1) = consumer.poll(1024).unwrap().unwrap();
+        consumer.complete(c1).unwrap();
+
+        let (_, c2) = consumer.poll(1024).unwrap().unwrap();
+        let SendCompletion::Writable(mut wc) = c2 else {
+            panic!("expected writable");
+        };
+        wc.write_all(b"reply").unwrap();
+        consumer.complete(wc.into()).unwrap();
+
+        // poll() consumes first entry directly from ring
+        let cqe1 = producer.poll().unwrap().unwrap();
+        assert!(cqe1.data.is_empty());
+
+        // reclaim() buffers second entry
+        let count = producer.reclaim().unwrap();
+        assert_eq!(count, 1);
+
+        // poll() returns the buffered one
+        let cqe2 = producer.poll().unwrap().unwrap();
+        assert_eq!(cqe2.token, tok_rw);
+        assert_eq!(&cqe2.data[..], b"reply");
+    }
+
+    /// reclaim + submit must not cause token collisions.
+    #[test]
+    fn test_reclaim_submit_no_token_collision() {
+        let ring = make_ring(8);
+        let (mut producer, mut consumer, _) = make_test_producer(&ring);
+
+        // Submit and complete a ReadOnly entry
+        let tok_old = send_readonly(&mut producer, b"log");
+
+        let (_, c) = consumer.poll(1024).unwrap().unwrap();
+        consumer.complete(c).unwrap();
+
+        let count = producer.reclaim().unwrap();
+        assert_eq!(count, 1);
+
+        // Submit a new ReadWrite entry - may reuse the same descriptor ID
+        let tok_new = send_readwrite(&mut producer, b"call", 64);
+
+        // Tokens must differ even if the descriptor ID was recycled
+        assert_ne!(
+            tok_old, tok_new,
+            "tokens must be unique across reclaim/submit cycles"
+        );
+
+        // Complete the ReadWrite entry
+        let (_, c) = consumer.poll(1024).unwrap().unwrap();
+        let SendCompletion::Writable(mut wc) = c else {
+            panic!("expected writable");
+        };
+        wc.write_all(b"result").unwrap();
+        consumer.complete(wc.into()).unwrap();
+
+        // Poll returns only the RW completion (RO was discarded by reclaim)
+        let cqe = producer.poll().unwrap().unwrap();
+        assert_eq!(cqe.token, tok_new);
+        assert_eq!(&cqe.data[..], b"result");
+
+        // No stale RO completion in the queue
+        assert!(producer.poll().unwrap().is_none());
+    }
+
+    /// Verify that repeated oneshot submit/reclaim cycles do not accumulate pending completions.
+    #[test]
+    fn test_reclaim_readonly_does_not_leak_pending() {
+        let ring = make_ring(4);
+        let (mut producer, mut consumer, _) = make_test_producer(&ring);
+
+        for _ in 0..10 {
+            // Fill the ring
+            for _ in 0..4 {
+                send_readonly(&mut producer, b"msg");
+            }
+
+            // Consumer acks all
+            while let Some((_, completion)) = consumer.poll(1024).unwrap() {
+                consumer.complete(completion).unwrap();
+            }
+
+            // Reclaim frees ring slots; empty completions are discarded
+            let count = producer.reclaim().unwrap();
+            assert_eq!(count, 4);
+
+            // No completions should be buffered in pending
+            assert!(
+                producer.poll().unwrap().is_none(),
+                "pending should be empty after reclaiming RO entries"
+            );
+        }
+    }
+}
+
+#[cfg(all(test, loom))]
+mod fuzz {
+    //! Loom-based concurrency testing for the virtqueue implementation.
+    //!
+    //! Loom explores all possible thread interleavings to find data races
+    //! and other concurrency bugs. However, it has specific requirements that
+    //! make our memory model more involved:
+    //!
+    //! ## Flag-Based Synchronization
+    //!
+    //! The virtqueue protocol uses flag-based synchronization:
+    //! 1. Producer writes descriptor fields (addr, len, id), then writes flags with release semantics
+    //! 2. Consumer reads flags with acquire semantics, then reads descriptor fields
+    //!
+    //! Loom would see this as concurrent access to the same memory and report a race, even though
+    //! acquire/release on flags provides proper synchronization.
+    //!
+    //! ## Shadow Atomics for Flags
+    //!
+    //! We maintain shadow atomics that loom tracks for synchronization:
+    //!
+    //! - `desc_flags`: One `AtomicU16` per descriptor for the flags field
+    //! - `drv_flags`: `AtomicU16` for driver event suppression flags
+    //! - `dev_flags`: `AtomicU16` for device event suppression flags
+    //!
+    //! The `load_acquire`/`store_release` operations use these loom atomics,
+    //! while `read`/`write` access the underlying data directly.
+    //!
+    //! ## Memory Regions
+    //!
+    //! We use a `BTreeMap` to map addresses to memory regions:
+    //! - `Desc(idx)`: Individual descriptors in the ring
+    //! - `DrvEvt`: Driver event suppression structure
+    //! - `DevEvt`: Device event suppression structure
+    //! - `Pool`: Buffer pool for entry/completion data
+
+    use alloc::collections::BTreeMap;
+    use alloc::sync::Arc;
+    use alloc::vec;
+    use core::num::NonZeroU16;
+
+    use bytemuck::Zeroable;
+    use loom::sync::atomic::{AtomicU16, AtomicUsize, Ordering};
+    use loom::thread;
+
+    use super::*;
+    use crate::virtq::desc::Descriptor;
+    use crate::virtq::pool::BufferPoolSync;
+
+    #[derive(Debug)]
+    pub struct MemErr;
+
+    #[derive(Debug, Clone, Copy)]
+    enum RegionKind {
+        Desc(usize),
+        DrvEvt,
+        DevEvt,
+        Pool,
+    }
+
+    #[derive(Debug, Clone, Copy)]
+    struct RegionInfo {
+        kind: RegionKind,
+        size: usize,
+    }
+
+    #[derive(Debug)]
+    pub struct LoomMem {
+        descs: Vec<Descriptor>,
+        drv: core::cell::UnsafeCell<EventSuppression>,
+        dev: core::cell::UnsafeCell<EventSuppression>,
+        pool: loom::cell::UnsafeCell<Vec<u8>>,
+
+        desc_flags: Vec<AtomicU16>,
+        drv_flags: AtomicU16,
+        dev_flags: AtomicU16,
+
+        regions: BTreeMap<u64, RegionInfo>,
+        layout: Layout,
+    }
+
+    unsafe impl Sync for LoomMem {}
+    unsafe impl Send for LoomMem {}
+
+    impl LoomMem {
+        pub fn new(ring_base: u64, num_descs: usize, pool_base: u64, pool_size: usize) -> Self {
+            let descs_nz = NonZeroU16::new(num_descs as u16).unwrap();
+            let layout = unsafe { Layout::from_base(ring_base, descs_nz).unwrap() };
+
+            let descs: Vec<_> = (0..num_descs).map(|_| Descriptor::zeroed()).collect();
+            let desc_flags: Vec<_> = (0..num_descs).map(|_| AtomicU16::new(0)).collect();
+
+            let mut regions = BTreeMap::new();
+
+            // Register each descriptor as a separate region
+            for i in 0..num_descs {
+                let addr = layout.desc_table_addr() + (i * Descriptor::SIZE) as u64;
+                regions.insert(
+                    addr,
+                    RegionInfo {
+                        kind: RegionKind::Desc(i),
+                        size: Descriptor::SIZE,
+                    },
+                );
+            }
+
+            regions.insert(
+                layout.drv_evt_addr(),
+                RegionInfo {
+                    kind: RegionKind::DrvEvt,
+                    size: EventSuppression::SIZE,
+                },
+            );
+
+            regions.insert(
+                layout.dev_evt_addr(),
+                RegionInfo {
+                    kind: RegionKind::DevEvt,
+                    size: EventSuppression::SIZE,
+                },
+            );
+
+            regions.insert(
+                pool_base,
+                RegionInfo {
+                    kind: RegionKind::Pool,
+                    size: pool_size,
+                },
+            );
+
+            Self {
+                descs,
+                drv: core::cell::UnsafeCell::new(EventSuppression::zeroed()),
+                dev: core::cell::UnsafeCell::new(EventSuppression::zeroed()),
+                pool: loom::cell::UnsafeCell::new(vec![0u8; pool_size]),
+                desc_flags,
+                drv_flags: AtomicU16::new(0),
+                dev_flags: AtomicU16::new(0),
+                regions,
+                layout,
+            }
+        }
+
+        pub fn layout(&self) -> Layout {
+            self.layout
+        }
+
+        fn region(&self, addr: u64) -> Option<(RegionInfo, usize)> {
+            let (&base, &info) = self.regions.range(..=addr).next_back()?;
+            let offset = (addr - base) as usize;
+
+            if offset < info.size {
+                Some((info, offset))
+            } else {
+                None
+            }
+        }
+
+        fn desc_ptr(&self, idx: usize) -> *mut Descriptor {
+            self.descs.as_ptr().cast_mut().wrapping_add(idx)
+        }
+    }
+
+    unsafe impl MemOps for LoomMem {
+        type Error = MemErr;
+
+        fn read(&self, addr: u64, dst: &mut [u8]) -> Result<(), Self::Error> {
+            let (info, offset) = self.region(addr).ok_or(MemErr)?;
+
+            match info.kind {
+                RegionKind::Desc(idx) => {
+                    let desc = unsafe { &*self.desc_ptr(idx) };
+                    let bytes = bytemuck::bytes_of(desc);
+                    dst.copy_from_slice(&bytes[offset..offset + dst.len()]);
+                }
+                RegionKind::DrvEvt => {
+                    let evt = unsafe { &*self.drv.get() };
+                    let bytes = bytemuck::bytes_of(evt);
+                    dst.copy_from_slice(&bytes[offset..offset + dst.len()]);
+                }
+                RegionKind::DevEvt => {
+                    let evt = unsafe { &*self.dev.get() };
+                    let bytes = bytemuck::bytes_of(evt);
+                    dst.copy_from_slice(&bytes[offset..offset + dst.len()]);
+                }
+                RegionKind::Pool => {
+                    self.pool.with(|buf| {
+                        dst.copy_from_slice(&(unsafe { &*buf })[offset..offset + dst.len()]);
+                    });
+                }
+            }
+            Ok(())
+        }
+
+        fn write(&self, addr: u64, src: &[u8]) -> Result<(), Self::Error> {
+            let (info, offset) = self.region(addr).ok_or(MemErr)?;
+
+            match info.kind {
+                RegionKind::Desc(idx) => {
+                    let desc = unsafe { &mut *self.desc_ptr(idx) };
+                    let bytes = bytemuck::bytes_of_mut(desc);
+                    bytes[offset..offset + src.len()].copy_from_slice(src);
+                }
+                RegionKind::DrvEvt => {
+                    let evt = unsafe { &mut *self.drv.get() };
+                    let bytes = bytemuck::bytes_of_mut(evt);
+                    bytes[offset..offset + src.len()].copy_from_slice(src);
+                }
+                RegionKind::DevEvt => {
+                    let evt = unsafe { &mut *self.dev.get() };
+                    let bytes = bytemuck::bytes_of_mut(evt);
+                    bytes[offset..offset + src.len()].copy_from_slice(src);
+                }
+                RegionKind::Pool => {
+                    self.pool.with_mut(|buf| {
+                        (unsafe { &mut *buf })[offset..offset + src.len()].copy_from_slice(src);
+                    });
+                }
+            }
+            Ok(())
+        }
+
+        fn load_acquire(&self, addr: u64) -> Result<u16, Self::Error> {
+            let (info, _offset) = self.region(addr).ok_or(MemErr)?;
+
+            Ok(match info.kind {
+                RegionKind::Desc(idx) => self.desc_flags[idx].load(Ordering::Acquire),
+                RegionKind::DrvEvt => self.drv_flags.load(Ordering::Acquire),
+                RegionKind::DevEvt => self.dev_flags.load(Ordering::Acquire),
+                RegionKind::Pool => return Err(MemErr),
+            })
+        }
+
+        fn store_release(&self, addr: u64, val: u16) -> Result<(), Self::Error> {
+            let (info, _offset) = self.region(addr).ok_or(MemErr)?;
+
+            match info.kind {
+                RegionKind::Desc(idx) => self.desc_flags[idx].store(val, Ordering::Release),
+                RegionKind::DrvEvt => self.drv_flags.store(val, Ordering::Release),
+                RegionKind::DevEvt => self.dev_flags.store(val, Ordering::Release),
+                RegionKind::Pool => return Err(MemErr),
+            }
+            Ok(())
+        }
+
+        unsafe fn as_slice(&self, addr: u64, len: usize) -> Result<&[u8], Self::Error> {
+            let (info, offset) = self.region(addr).ok_or(MemErr)?;
+
+            match info.kind {
+                RegionKind::Pool => {
+                    // Safety: pool memory is a contiguous Vec; caller ensures
+                    // no concurrent writes for the lifetime of the returned slice.
+                    let buf = unsafe { &*self.pool.get() };
+                    Ok(&buf[offset..offset + len])
+                }
+                _ => Err(MemErr),
+            }
+        }
+
+        unsafe fn as_mut_slice(&self, addr: u64, len: usize) -> Result<&mut [u8], Self::Error> {
+            let (info, offset) = self.region(addr).ok_or(MemErr)?;
+
+            match info.kind {
+                RegionKind::Pool => {
+                    let buf = unsafe { &mut *self.pool.get() };
+                    Ok(&mut buf[offset..offset + len])
+                }
+                _ => Err(MemErr),
+            }
+        }
+    }
+
+    #[derive(Debug)]
+    pub struct Notify {
+        kicks: AtomicUsize,
+    }
+
+    impl Notify {
+        pub fn new() -> Self {
+            Self {
+                kicks: AtomicUsize::new(0),
+            }
+        }
+    }
+
+    impl Notifier for Arc<Notify> {
+        fn notify(&self, _stats: QueueStats) {
+            self.kicks.fetch_add(1, Ordering::Relaxed);
+        }
+    }
+
+    #[test]
+    fn virtq_ping_pong() {
+        loom::model(|| {
+            let ring_base = 0x10000;
+            let pool_base = 0x40000;
+            let pool_size = 0x10000;
+
+            let mem = Arc::new(LoomMem::new(ring_base, 8, pool_base, pool_size));
+            let pool = BufferPoolSync::<256, 4096>::new(pool_base, pool_size).unwrap();
+            let notify = Arc::new(Notify::new());
+
+            let mut prod = VirtqProducer::new(mem.layout(), mem.clone(), notify.clone(), pool);
+            let mut cons = VirtqConsumer::new(mem.layout(), mem.clone(), notify.clone());
+
+            let t_prod = thread::spawn(move || {
+                let mut se = prod.chain().entry(4).completion(32).build().unwrap();
+                se.write_all(b"ping").unwrap();
+                let tok = prod.submit(se).unwrap();
+                loop {
+                    if let Some(r) = prod.poll().unwrap() {
+                        assert_eq!(r.token, tok);
+                        assert_eq!(&r.data[..], b"pong");
+                        break;
+                    }
+                    thread::yield_now();
+                }
+            });
+
+            let t_cons = thread::spawn(move || {
+                let (entry, completion) = loop {
+                    if let Some(r) = cons.poll(1024).unwrap() {
+                        break r;
+                    }
+                    thread::yield_now();
+                };
+                assert_eq!(entry.data().as_ref(), b"ping");
+                let SendCompletion::Writable(mut wc) = completion else {
+                    panic!("expected writable completion");
+                };
+                wc.write_all(b"pong").unwrap();
+                cons.complete(wc.into()).unwrap();
+            });
+
+            t_prod.join().unwrap();
+            t_cons.join().unwrap();
+        });
+    }
+}
diff --git a/src/hyperlight_common/src/virtq/msg.rs b/src/hyperlight_common/src/virtq/msg.rs
new file mode 100644
index 000000000..090c2eb5b
--- /dev/null
+++ b/src/hyperlight_common/src/virtq/msg.rs
@@ -0,0 +1,120 @@
+/*
+Copyright 2026 The Hyperlight Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+//! Wire format header for all virtqueue messages.
+//!
+//! Every payload on both the G2H and H2G queues starts with this
+//! fixed 8-byte header, enabling message type discrimination and
+//! request/response correlation.
+
+use bitflags::bitflags;
+
+/// Message types for the virtqueue wire protocol.
+#[repr(u8)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum MsgKind {
+    /// A function call request (FunctionCall payload follows).
+    Request = 0x01,
+    /// A function call response (FunctionCallResult payload follows).
+    Response = 0x02,
+    /// A stream data chunk.
+    StreamChunk = 0x03,
+    /// End-of-stream marker.
+    StreamEnd = 0x04,
+    /// Cancel a pending request.
+    Cancel = 0x05,
+    /// A guest log message (GuestLogData payload follows).
+    Log = 0x06,
+}
+
+impl TryFrom<u8> for MsgKind {
+    type Error = u8;
+
+    fn try_from(value: u8) -> Result<Self, Self::Error> {
+        match value {
+            0x01 => Ok(Self::Request),
+            0x02 => Ok(Self::Response),
+            0x03 => Ok(Self::StreamChunk),
+            0x04 => Ok(Self::StreamEnd),
+            0x05 => Ok(Self::Cancel),
+            0x06 => Ok(Self::Log),
+            other => Err(other),
+        }
+    }
+}
+
+bitflags! {
+    #[repr(transparent)]
+    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+    pub struct MsgFlags: u8 {
+        /// More descriptors follow for this message.
+        const MORE = 1 << 0;
+    }
+}
+
+/// Wire header for all virtqueue messages.
+#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
+#[repr(C)]
+pub struct VirtqMsgHeader {
+    /// Discriminates the message type.
+    pub kind: u8,
+    /// Per-message flags (see [`MsgFlags`]).
+    pub flags: u8,
+    /// Caller-assigned correlation ID. Responses echo the request's ID.
+    pub req_id: u16,
+    /// Byte length of the payload following this header in this descriptor.
+    pub payload_len: u32,
+}
+
+impl VirtqMsgHeader {
+    pub const SIZE: usize = core::mem::size_of::<Self>();
+
+    /// Create a new message header with no flags set.
+    pub const fn new(kind: MsgKind, req_id: u16, payload_len: u32) -> Self {
+        Self {
+            kind: kind as u8,
+            flags: 0,
+            req_id,
+            payload_len,
+        }
+    }
+
+    /// Create a new header with flags.
+    pub const fn with_flags(kind: MsgKind, flags: MsgFlags, req_id: u16, payload_len: u32) -> Self {
+        Self {
+            kind: kind as u8,
+            flags: flags.bits(),
+            req_id,
+            payload_len,
+        }
+    }
+
+    /// Parse the kind field into a [`MsgKind`] enum.
+    pub fn msg_kind(&self) -> Result<MsgKind, u8> {
+        MsgKind::try_from(self.kind)
+    }
+
+    /// Interpret the raw flags field as [`MsgFlags`].
+    pub fn msg_flags(&self) -> MsgFlags {
+        MsgFlags::from_bits_truncate(self.flags)
+    }
+
+    /// Returns true if [`MsgFlags::MORE`] is set, indicating more
+    /// descriptors follow for this message.
+    pub const fn has_more(&self) -> bool {
+        self.flags & MsgFlags::MORE.bits() != 0
+    }
+}
diff --git a/src/hyperlight_common/src/virtq/pool.rs b/src/hyperlight_common/src/virtq/pool.rs
new file mode 100644
index 000000000..e6dc2ae34
--- /dev/null
+++ b/src/hyperlight_common/src/virtq/pool.rs
@@ -0,0 +1,1724 @@
+/*
+Copyright 2026 The Hyperlight Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+//! Buffer pool implementations for virtqueue buffer management.
+//!
+//! This module provides concrete buffer allocators:
+//!
+//! - [`BufferPool`] - a two-tier bitmap pool with small and large slabs,
+//!   intended for G2H descriptors where allocation sizes vary.
+//! - [`RecyclePool`] - a fixed-size free-list recycler for H2G prefill
+//!   entries where all buffers are the same size.
+//!
+//! Both implement [`BufferProvider`] from the [`super::buffer`] module.
+//!
+//! # BufferPool design
+//!
+//! The core allocation strategy is a bitmap allocator: finding a free run is a
+//! linear search over the bitmap, implemented via `fixedbitset`'s SIMD
+//! iteration over zero bits.
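+//!
+//! A sketch of the search primitive this module builds on (using
+//! `fixedbitset`'s public API, as the allocator below does):
+//!
+//! ```ignore
+//! use fixedbitset::FixedBitSet;
+//!
+//! let mut used = FixedBitSet::with_capacity(8); // 8 slots, all free
+//! used.insert_range(0..3);                      // mark slots 0-2 allocated
+//! // `zeroes()` iterates the indices of zero (free) bits via word-wide scans.
+//! assert_eq!(used.zeroes().next(), Some(3));
+//! ```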
+//! # Two-tier layout
+//!
+//! [`BufferPool`] divides the underlying region into two slabs with different
+//! slot sizes:
+//!
+//! - The lower tier (`Slab<L>`, default `L = 256`) is intended for
+//!   *smaller allocations* - control messages, descriptor metadata, and other
+//!   small structures. Small allocations first try this tier.
+//! - The upper tier (`Slab<U>`, default `U = 4096`) uses page-sized slots
+//!   and is intended for larger data buffers.
+
+use alloc::rc::Rc;
+use core::cell::RefCell;
+use core::cmp::Ordering;
+use core::ops::Deref;
+
+use fixedbitset::FixedBitSet;
+use smallvec::SmallVec;
+
+use super::buffer::{AllocError, Allocation, BufferProvider};
+
+/// Wrapper asserting `Send + Sync` for single-threaded contexts.
+///
+/// # Safety
+///
+/// The wrapped value must only be accessed from a single thread.
+#[derive(Debug)]
+pub(super) struct SyncWrap<T>(pub(super) T);
+
+// SAFETY: The wrapped value must only be accessed from a single thread.
+unsafe impl<T> Send for SyncWrap<T> {}
+// SAFETY: The wrapped value must only be accessed from a single thread.
+unsafe impl<T> Sync for SyncWrap<T> {}
+
+impl<T: Clone> Clone for SyncWrap<T> {
+    fn clone(&self) -> Self {
+        Self(self.0.clone())
+    }
+}
+
+impl<T> Deref for SyncWrap<T> {
+    type Target = T;
+    fn deref(&self) -> &T {
+        &self.0
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct Slab<const N: usize> {
+    /// Base address of the slab
+    base_addr: u64,
+    /// Flat bitmap to track allocated/free slots
+    used_slots: FixedBitSet,
+    /// Last free allocation cache
+    last_free_run: Option<Allocation>,
+}
+
+impl<const N: usize> Slab<N> {
+    /// Create a new slab allocator over a fixed region.
+    /// The region is rounded down to a multiple of N.
+    pub fn new(base_addr: u64, region_len: usize) -> Result<Self, AllocError> {
+        let usable = region_len - (region_len % N);
+        let num_slots = usable / N;
+        let used_slots = FixedBitSet::with_capacity(num_slots);
+
+        if base_addr % (N as u64) != 0 {
+            return Err(AllocError::InvalidAlign(base_addr));
+        }
+
+        if num_slots == 0 {
+            return Err(AllocError::EmptyRegion);
+        }
+
+        Ok(Self {
+            base_addr,
+            used_slots,
+            last_free_run: None,
+        })
+    }
+
+    /// Get the address of a slot by its index
+    #[inline]
+    fn addr_of(&self, slot_idx: usize) -> Option<u64> {
+        self.base_addr
+            .checked_add((slot_idx as u64).checked_mul(N as u64)?)
+    }
+
+    /// Get the slot index for a given address
+    #[inline]
+    fn slot_of(&self, addr: u64) -> usize {
+        let off = (addr - self.base_addr) as usize;
+        off / N
+    }
+
+    /// Invalidate the last_free_run cache if it overlaps with the given allocation.
+    fn maybe_invalidate_last_run(&mut self, alloc: Allocation) {
+        if let Some(run) = &self.last_free_run {
+            let new_end = alloc.addr + alloc.len as u64;
+            let run_end = run.addr + run.len as u64;
+
+            if alloc.addr < run_end && run.addr < new_end {
+                self.last_free_run = None;
+            }
+        }
+    }
+
+    /// Find the starting index of a run of `slots_num` consecutive free slots.
+    pub fn find_slots(&mut self, slots_num: usize) -> Option<usize> {
+        debug_assert!(slots_num > 0);
+
+        // Check last free run optimization
+        if let Some(alloc) = self.last_free_run
+            && alloc.len >= slots_num * N
+        {
+            let pos = self.slot_of(alloc.addr);
+            let _ = self.last_free_run.take();
+            return Some(pos);
+        }
+
+        // Fallback to full search
+        let total = self.used_slots.len();
+        self.used_slots.zeroes().find(|&next_free| {
+            let end = next_free + slots_num;
+            end <= total && self.used_slots.count_zeroes(next_free..end) == slots_num
+        })
+    }
+
+    /// Allocate at least `len` bytes by merging consecutive slots.
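+    ///
+    /// A sketch of the rounding behavior (mirrors `test_slab_alloc_multiple_slots`
+    /// below):
+    ///
+    /// ```ignore
+    /// let mut slab = Slab::<256>::new(0x10000, 1024)?; // 4 slots of 256 bytes
+    /// let a = slab.alloc(600)?;  // rounded up to 3 slots
+    /// assert_eq!(a.len, 768);
+    /// slab.dealloc(a)?;
+    /// ```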
+    pub fn alloc(&mut self, len: usize) -> Result<Allocation, AllocError> {
+        if len == 0 {
+            return Err(AllocError::InvalidArg);
+        }
+
+        let total = self.used_slots.len();
+        let need_slots = len.div_ceil(N);
+        if need_slots > total {
+            return Err(AllocError::OutOfMemory);
+        }
+
+        let idx = self.find_slots(need_slots).ok_or(AllocError::NoSpace)?;
+        self.used_slots.insert_range(idx..idx + need_slots);
+        let addr = self.addr_of(idx).ok_or(AllocError::Overflow)?;
+
+        let alloc = Allocation {
+            addr,
+            len: need_slots * N,
+        };
+
+        self.maybe_invalidate_last_run(alloc);
+        Ok(alloc)
+    }
+
+    /// Free a previously allocated slot or multiple slots.
+    ///
+    /// `len` must be a multiple of N and `addr` must be slot-aligned relative to the base.
+    pub fn dealloc(&mut self, alloc: Allocation) -> Result<(), AllocError> {
+        let Allocation { addr, len } = alloc;
+        if len == 0 || len % N != 0 || addr < self.base_addr {
+            return Err(AllocError::InvalidFree(addr, len));
+        }
+        let alloc_slots = len / N;
+        let off = (addr - self.base_addr) as usize;
+        if off % N != 0 {
+            return Err(AllocError::InvalidFree(addr, len));
+        }
+        let start = off / N;
+        let num_slots = self.used_slots.len();
+        if start + alloc_slots > num_slots {
+            return Err(AllocError::InvalidFree(addr, len));
+        }
+
+        // Ensure all bits are set (avoid double-free)
+        if !self
+            .used_slots
+            .contains_all_in_range(start..start + alloc_slots)
+        {
+            return Err(AllocError::InvalidFree(addr, len));
+        }
+
+        // Mark as free
+        self.used_slots.remove_range(start..start + alloc_slots);
+        self.last_free_run = Some(alloc);
+
+        Ok(())
+    }
+
+    /// Try to grow a block in place by reserving adjacent free slots to the right.
+    ///
+    /// Returns Ok(None) if in-place growth is not possible. Returns Err on invalid input.
+    pub fn try_grow_inplace(
+        &mut self,
+        old_alloc: Allocation,
+        new_len: usize,
+    ) -> Result<Option<Allocation>, AllocError> {
+        let Allocation {
+            addr: old_addr,
+            len: old_len,
+        } = old_alloc;
+
+        if new_len <= old_len || old_len == 0 || old_len % N != 0 {
+            return Err(AllocError::InvalidFree(old_addr, old_len));
+        }
+
+        let old_slots = old_len / N;
+        let need_slots = new_len.div_ceil(N);
+        let off = (old_addr - self.base_addr) as usize;
+        if off % N != 0 {
+            return Err(AllocError::InvalidFree(old_addr, old_len));
+        }
+
+        let start = off / N;
+        if start + need_slots > self.used_slots.len() {
+            return Ok(None);
+        }
+        // Existing range must be allocated
+        if !self
+            .used_slots
+            .contains_all_in_range(start..start + old_slots)
+        {
+            return Err(AllocError::InvalidFree(old_addr, old_len));
+        }
+
+        // Extension must be free
+        if self
+            .used_slots
+            .count_ones(start + old_slots..start + need_slots)
+            > 0
+        {
+            return Ok(None);
+        }
+
+        // Mark extension as allocated
+        self.used_slots
+            .insert_range(start + old_slots..start + need_slots);
+
+        let alloc = Allocation {
+            addr: old_addr,
+            len: need_slots * N,
+        };
+
+        self.maybe_invalidate_last_run(alloc);
+        Ok(Some(alloc))
+    }
+
+    /// Shrink a block in place by freeing excess slots to the right.
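+    ///
+    /// A sketch (mirrors `test_slab_shrink_inplace` below):
+    ///
+    /// ```ignore
+    /// let a = slab.alloc(512)?;             // 2 slots
+    /// let s = slab.shrink_inplace(a, 256)?; // frees the second slot
+    /// assert_eq!((s.addr, s.len), (a.addr, 256));
+    /// ```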
+    pub fn shrink_inplace(
+        &mut self,
+        old_alloc: Allocation,
+        new_len: usize,
+    ) -> Result<Allocation, AllocError> {
+        let Allocation {
+            addr: old_addr,
+            len: old_len,
+        } = old_alloc;
+
+        if new_len >= old_len || old_len == 0 || old_len % N != 0 {
+            return Err(AllocError::InvalidFree(old_addr, old_len));
+        }
+
+        let old_slots = old_len / N;
+        let need_slots = new_len.div_ceil(N);
+        let off = (old_addr - self.base_addr) as usize;
+        if off % N != 0 {
+            return Err(AllocError::InvalidFree(old_addr, old_len));
+        }
+
+        let start = off / N;
+        if start + old_slots > self.used_slots.len() {
+            return Err(AllocError::InvalidFree(old_addr, old_len));
+        }
+        // Existing range must be allocated
+        if !self
+            .used_slots
+            .contains_all_in_range(start..start + old_slots)
+        {
+            return Err(AllocError::InvalidFree(old_addr, old_len));
+        }
+
+        // Free the excess slots
+        self.used_slots
+            .remove_range(start + need_slots..start + old_slots);
+
+        Ok(Allocation {
+            addr: old_addr,
+            len: need_slots * N,
+        })
+    }
+
+    /// Reallocate by trying in-place grow; otherwise reserve a new run of slots and free old.
+    /// Caller should copy the payload; this function only manages reservations.
+    pub fn resize(
+        &mut self,
+        old_alloc: Allocation,
+        new_len: usize,
+    ) -> Result<Allocation, AllocError> {
+        if new_len == 0 {
+            return Err(AllocError::InvalidArg);
+        }
+
+        match new_len.cmp(&old_alloc.len) {
+            Ordering::Greater => {
+                match self.try_grow_inplace(old_alloc, new_len) {
+                    // in-place growth succeeded
+                    Ok(Some(new_alloc)) => Ok(new_alloc),
+                    // in-place growth failed; allocate new and free old
+                    Ok(None) => {
+                        let new_alloc = self.alloc(new_len)?;
+                        self.dealloc(old_alloc)?;
+                        Ok(new_alloc)
+                    }
+                    // other errors are propagated
+                    Err(err) => Err(err),
+                }
+            }
+            Ordering::Less => self.shrink_inplace(old_alloc, new_len),
+            Ordering::Equal => Ok(old_alloc),
+        }
+    }
+
+    /// Usable size rounded up to a slot multiple.
+    pub fn usable_size(&self, _addr: usize, len: usize) -> usize {
+        if len == 0 { 0 } else { len.div_ceil(N) * N }
+    }
+
+    /// Number of free bytes in the slab.
+    pub fn free_bytes(&self) -> usize {
+        (self.used_slots.len() - self.used_slots.count_ones(..)) * N
+    }
+
+    /// Total capacity of the slab in bytes.
+    pub fn capacity(&self) -> usize {
+        self.used_slots.len() * N
+    }
+
+    /// Get the address range covered by this slab.
+    pub fn range(&self) -> core::ops::Range<u64> {
+        let end = self.base_addr + self.capacity() as u64;
+        self.base_addr..end
+    }
+
+    /// Check if an address is within this slab's range.
+    pub fn contains(&self, addr: u64) -> bool {
+        self.range().contains(&addr)
+    }
+
+    /// Get the slot size N.
+    pub const fn slot_size() -> usize {
+        N
+    }
+
+    /// Reset the slab to its initial state, with all slots free.
+    pub fn reset(&mut self) {
+        self.used_slots.clear();
+        self.last_free_run = None;
+    }
+}
+
+#[inline]
+fn align_up(val: usize, align: usize) -> usize {
+    assert!(align > 0);
+    if val == 0 {
+        return 0;
+    }
+    val.div_ceil(align) * align
+}
+
+#[derive(Debug)]
+struct Inner<const L: usize, const U: usize> {
+    lower: Slab<L>,
+    upper: Slab<U>,
+}
+
+/// Two-tier buffer pool with small and large slabs.
+#[derive(Debug, Clone)]
+pub struct BufferPool<const L: usize = 256, const U: usize = 4096> {
+    inner: SyncWrap<Rc<RefCell<Inner<L, U>>>>,
+}
+
+impl<const L: usize, const U: usize> BufferPool<L, U> {
+    /// Create a new buffer pool over a fixed region.
+    pub fn new(base_addr: u64, region_len: usize) -> Result<Self, AllocError> {
+        let inner = Inner::<L, U>::new(base_addr, region_len)?;
+        Ok(Self {
+            inner: SyncWrap(Rc::new(RefCell::new(inner))),
+        })
+    }
+}
+
+impl BufferPool {
+    /// Upper slab slot size in bytes.
+    pub const fn upper_slot_size() -> usize {
+        4096
+    }
+
+    /// Lower slab slot size in bytes.
+    pub const fn lower_slot_size() -> usize {
+        256
+    }
+}
+
+#[cfg(all(test, loom))]
+#[derive(Debug, Clone)]
+pub struct BufferPoolSync<const L: usize, const U: usize> {
+    inner: std::sync::Arc<std::sync::Mutex<Inner<L, U>>>,
+}
+
+#[cfg(all(test, loom))]
+impl<const L: usize, const U: usize> BufferPoolSync<L, U> {
+    /// Create a new buffer pool over a fixed region.
+    pub fn new(base_addr: u64, region_len: usize) -> Result<Self, AllocError> {
+        let inner = Inner::<L, U>::new(base_addr, region_len)?;
+        Ok(Self {
+            inner: std::sync::Arc::new(std::sync::Mutex::new(inner)),
+        })
+    }
+}
+
+impl<const L: usize, const U: usize> Inner<L, U> {
+    /// Create a new buffer pool over a fixed region.
+    pub fn new(base_addr: u64, region_len: usize) -> Result<Self, AllocError> {
+        const LOWER_FRACTION: usize = 8;
+
+        let lower_region = region_len / LOWER_FRACTION;
+        let upper_region = region_len - lower_region;
+
+        let mut aligned = base_addr;
+        aligned = align_up(aligned as usize, L) as u64;
+        let lower = Slab::<L>::new(aligned, lower_region)?;
+
+        // advance and align the upper base to U
+        aligned = aligned
+            .checked_add(lower.capacity() as u64)
+            .ok_or(AllocError::Overflow)?;
+
+        aligned = align_up(aligned as usize, U) as u64;
+        let upper = Slab::<U>::new(aligned, upper_region)?;
+
+        Ok(Self { lower, upper })
+    }
+
+    /// Allocate at least `len` bytes.
+    pub fn alloc(&mut self, len: usize) -> Result<Allocation, AllocError> {
+        if len <= L {
+            match self.lower.alloc(len) {
+                Ok(alloc) => return Ok(alloc),
+                Err(AllocError::NoSpace) => {}
+                Err(e) => return Err(e),
+            }
+        }
+
+        // fallback to upper slab
+        self.upper.alloc(len)
+    }
+
+    /// Free a previously allocated block.
+    pub fn dealloc(&mut self, alloc: Allocation) -> Result<(), AllocError> {
+        if self.lower.contains(alloc.addr) {
+            self.lower.dealloc(alloc)
+        } else {
+            self.upper.dealloc(alloc)
+        }
+    }
+
+    /// Reallocate by trying in-place grow; otherwise reserve a new block and free old.
+    pub fn resize(
+        &mut self,
+        old_alloc: Allocation,
+        new_len: usize,
+    ) -> Result<Allocation, AllocError> {
+        if self.lower.contains(old_alloc.addr) {
+            maybe_move(&mut self.lower, &mut self.upper, old_alloc, new_len)
+        } else {
+            maybe_move(&mut self.upper, &mut self.lower, old_alloc, new_len)
+        }
+    }
+}
+
+/// Try to realloc using the slab that owns the old allocation; if that fails,
+/// try to allocate in the other slab. The function moves allocations
+/// between slabs only when the new size crosses the tier threshold.
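+///
+/// A sketch of the routing via [`BufferPool::resize`], which delegates here
+/// (mirrors `test_pool_realloc_move_to_different_tier` below; `base` and
+/// `len` are placeholders for a suitably aligned region):
+///
+/// ```ignore
+/// let pool = BufferPool::<256, 4096>::new(base, len)?;
+/// let small = pool.alloc(128)?;          // served by the lower tier
+/// let moved = pool.resize(small, 1500)?; // crosses L: relocated to the upper tier
+/// ```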
+#[inline]
+fn maybe_move<const A: usize, const B: usize>(
+    slab: &mut Slab<A>,
+    other: &mut Slab<B>,
+    old_alloc: Allocation,
+    new_len: usize,
+) -> Result<Allocation, AllocError> {
+    let needs_move = if A < B { new_len > A } else { new_len <= B };
+    if !needs_move {
+        return slab.resize(old_alloc, new_len);
+    }
+
+    let new_alloc = other.alloc(new_len)?;
+
+    slab.dealloc(old_alloc)?;
+    Ok(new_alloc)
+}
+
+impl<const L: usize, const U: usize> BufferProvider for BufferPool<L, U> {
+    fn alloc(&self, len: usize) -> Result<Allocation, AllocError> {
+        self.inner.borrow_mut().alloc(len)
+    }
+
+    fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError> {
+        self.inner.borrow_mut().dealloc(alloc)
+    }
+
+    fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result<Allocation, AllocError> {
+        self.inner.borrow_mut().resize(old_alloc, new_len)
+    }
+
+    fn reset(&self) {
+        let mut inner = self.inner.borrow_mut();
+        inner.lower.reset();
+        inner.upper.reset();
+    }
+}
+
+#[cfg(all(test, loom))]
+impl<const L: usize, const U: usize> BufferProvider for BufferPoolSync<L, U> {
+    fn alloc(&self, len: usize) -> Result<Allocation, AllocError> {
+        self.inner.lock().expect("poisoned mutex").alloc(len)
+    }
+
+    fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError> {
+        self.inner.lock().expect("poisoned mutex").dealloc(alloc)
+    }
+
+    fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result<Allocation, AllocError> {
+        self.inner
+            .lock()
+            .expect("poisoned mutex")
+            .resize(old_alloc, new_len)
+    }
+}
+
+struct RecyclePoolInner {
+    base_addr: u64,
+    slot_size: usize,
+    count: usize,
+    free: SmallVec<[u64; 64]>,
+}
+
+/// A recycling buffer provider with fixed-size slots.
+///
+/// Unlike [`BufferPool`] which uses a bitmap allocator, this holds a
+/// fixed set of same-sized buffer addresses in a free list. Alloc and
+/// dealloc are O(1). Intended for H2G writable buffers that are
+/// pre-allocated once and recycled after each use.
+#[derive(Clone)]
+pub struct RecyclePool {
+    inner: SyncWrap<Rc<RefCell<RecyclePoolInner>>>,
+}
+
+impl RecyclePool {
+    /// Create a new recycling pool by carving `base..base+region_len` into slots of `slot_size` bytes.
+    pub fn new(base_addr: u64, region_len: usize, slot_size: usize) -> Result<Self, AllocError> {
+        if slot_size == 0 {
+            return Err(AllocError::InvalidArg);
+        }
+
+        let count = region_len / slot_size;
+        if count == 0 {
+            return Err(AllocError::EmptyRegion);
+        }
+
+        let mut free = SmallVec::with_capacity(count);
+        for i in 0..count {
+            free.push(base_addr + (i * slot_size) as u64);
+        }
+
+        let inner = RefCell::new(RecyclePoolInner {
+            base_addr,
+            slot_size,
+            count,
+            free,
+        });
+
+        Ok(Self {
+            inner: SyncWrap(Rc::new(inner)),
+        })
+    }
+
+    /// Rebuild pool state so that every address in `allocated` is removed
+    /// from the free list, matching externally known inflight state.
+    pub fn restore_allocated(&self, allocated: &[u64]) -> Result<(), AllocError> {
+        self.reset();
+
+        if allocated.is_empty() {
+            return Ok(());
+        }
+
+        let mut inner = self.inner.borrow_mut();
+
+        for &addr in allocated {
+            let pos = inner
+                .free
+                .iter()
+                .position(|&a| a == addr)
+                .ok_or(AllocError::InvalidFree(addr, inner.slot_size))?;
+
+            inner.free.swap_remove(pos);
+        }
+
+        Ok(())
+    }
+
+    /// Compute the address of slot `index`.
+    ///
+    /// Returns `None` if `index >= count`.
+    pub fn slot_addr(&self, index: usize) -> Option<u64> {
+        let inner = self.inner.borrow();
+        if index < inner.count {
+            Some(inner.base_addr + (index * inner.slot_size) as u64)
+        } else {
+            None
+        }
+    }
+
+    /// Number of free slots.
+    pub fn num_free(&self) -> usize {
+        self.inner.borrow().free.len()
+    }
+
+    /// Base address of the pool region.
+    pub fn base_addr(&self) -> u64 {
+        self.inner.borrow().base_addr
+    }
+
+    /// Slot size in bytes.
+    pub fn slot_size(&self) -> usize {
+        self.inner.borrow().slot_size
+    }
+
+    /// Number of slots in the pool.
+    pub fn count(&self) -> usize {
+        self.inner.borrow().count
+    }
+}
+
+impl BufferProvider for RecyclePool {
+    fn alloc(&self, len: usize) -> Result<Allocation, AllocError> {
+        let mut inner = self.inner.borrow_mut();
+        if len > inner.slot_size {
+            return Err(AllocError::OutOfMemory);
+        }
+
+        let addr = inner.free.pop().ok_or(AllocError::NoSpace)?;
+
+        Ok(Allocation {
+            addr,
+            len: inner.slot_size,
+        })
+    }
+
+    fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError> {
+        let mut inner = self.inner.borrow_mut();
+        let end = inner.base_addr + (inner.count * inner.slot_size) as u64;
+
+        if alloc.addr < inner.base_addr || alloc.addr >= end {
+            return Err(AllocError::InvalidFree(alloc.addr, alloc.len));
+        }
+
+        if (alloc.addr - inner.base_addr) % inner.slot_size as u64 != 0 {
+            return Err(AllocError::InvalidFree(alloc.addr, alloc.len));
+        }
+
+        if inner.free.contains(&alloc.addr) {
+            return Err(AllocError::InvalidFree(alloc.addr, alloc.len));
+        }
+
+        inner.free.push(alloc.addr);
+        Ok(())
+    }
+
+    fn resize(&self, old: Allocation, new_len: usize) -> Result<Allocation, AllocError> {
+        let inner = self.inner.borrow();
+        if new_len > inner.slot_size {
+            return Err(AllocError::OutOfMemory);
+        }
+        Ok(old)
+    }
+
+    fn reset(&self) {
+        let mut inner = self.inner.borrow_mut();
+        let base = inner.base_addr;
+        let slot = inner.slot_size;
+        let count = inner.count;
+
+        inner.free.clear();
+
+        for i in 0..count {
+            inner.free.push(base + (i * slot) as u64);
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn make_slab<const N: usize>(size: usize) -> Slab<N> {
+        let base = align_up(0x10000, N) as u64;
+        Slab::<N>::new(base, size).unwrap()
+    }
+
+    fn make_pool<const L: usize, const U: usize>(size: usize) -> BufferPool<L, U> {
+        let base = align_up(0x10000, L.max(U)) as u64;
+        BufferPool::<L, U>::new(base, size).unwrap()
+    }
+
+    fn make_recycle_pool(slot_count: usize, slot_size: usize) -> RecyclePool {
+        let base = 0x80000u64;
+        RecyclePool::new(base, slot_count * slot_size, slot_size).unwrap()
+    }
+
+    #[test]
+    fn test_slab_new_success() {
+        let slab = Slab::<256>::new(0x10000, 1024).unwrap();
+        assert_eq!(slab.capacity(), 1024);
+        assert_eq!(slab.free_bytes(), 1024);
+    }
+
+    #[test]
+    fn test_slab_new_misaligned() {
+        let result = Slab::<256>::new(0x10001, 1024);
+        assert!(matches!(result, Err(AllocError::InvalidAlign(0x10001))));
+    }
+
+    #[test]
+    fn test_slab_new_empty_region() {
+        let result = Slab::<256>::new(0x10000, 100);
+        assert!(matches!(result, Err(AllocError::EmptyRegion)));
+    }
+
+    #[test]
+    fn test_slab_alloc_single_slot() {
+        let mut slab = make_slab::<256>(1024);
+        let alloc = slab.alloc(128).unwrap();
+        assert_eq!(alloc.len, 256);
+        assert_eq!(slab.free_bytes(), 1024 - 256);
+    }
+
+    #[test]
+    fn test_slab_alloc_multiple_slots() {
+        let mut slab = make_slab::<256>(1024);
+        let alloc = slab.alloc(600).unwrap();
+        assert_eq!(alloc.len, 768); // 3 slots × 256 bytes
+        assert_eq!(slab.free_bytes(), 1024 - 768);
+    }
+
+    #[test]
+    fn test_slab_alloc_zero_length() {
+        let mut slab = make_slab::<256>(1024);
+        let result = slab.alloc(0);
+        assert!(matches!(result, Err(AllocError::InvalidArg)));
+    }
+
+    #[test]
+    fn test_slab_alloc_too_large() {
+        let mut slab = make_slab::<256>(1024);
+        let result = slab.alloc(2048);
+        assert!(matches!(result, Err(AllocError::OutOfMemory)));
+    }
+
+    #[test]
+    fn test_slab_alloc_until_full() {
+        let mut slab = make_slab::<256>(1024);
+
+        // Allocate all 4 slots
+        let _a1 = slab.alloc(256).unwrap();
+        let a2 = slab.alloc(256).unwrap();
+        let _a3 = slab.alloc(256).unwrap();
+        let _a4 = slab.alloc(256).unwrap();
+
+        assert_eq!(slab.free_bytes(), 0);
+
+        // Next allocation should fail
+        let result = slab.alloc(256);
+        assert!(matches!(result, Err(AllocError::NoSpace)));
+
+        // Free one and retry
+        slab.dealloc(a2).unwrap();
+        let a5 = slab.alloc(256).unwrap();
+        assert_eq!(a5.addr, a2.addr); // Should reuse same slot
+    }
+
+    #[test]
+    fn test_slab_free_success() {
+        let mut slab = make_slab::<256>(1024);
+        let alloc = slab.alloc(256).unwrap();
+        assert_eq!(slab.free_bytes(), 768);
+
+        slab.dealloc(alloc).unwrap();
+        assert_eq!(slab.free_bytes(), 1024);
+    }
+
+    #[test]
+    fn test_slab_free_invalid_length() {
+        let mut slab = make_slab::<256>(1024);
+        let mut alloc = slab.alloc(256).unwrap();
+        alloc.len = 100; // Invalid: not multiple of N
+
+        let result = slab.dealloc(alloc);
+        assert!(matches!(result, Err(AllocError::InvalidFree(_, 100))));
+    }
+
+    #[test]
+    fn test_slab_free_double_free() {
+        let mut slab = make_slab::<256>(1024);
+        let alloc = slab.alloc(256).unwrap();
+
+        slab.dealloc(alloc).unwrap();
+        let result = slab.dealloc(alloc);
+        assert!(matches!(result, Err(AllocError::InvalidFree(_, _))));
+    }
+
+    #[test]
+    fn test_slab_multi_slot_alloc_near_end() {
+        let mut slab = make_slab::<256>(1792); // 7 slots
+        let a0 = slab.alloc(256).unwrap();
+        let a1 = slab.alloc(256).unwrap();
+        let _a2 = slab.alloc(256).unwrap();
+        let _a3 = slab.alloc(256).unwrap();
+        let _a4 = slab.alloc(256).unwrap();
+        let _a5 = slab.alloc(256).unwrap();
+        let _a6 = slab.alloc(256).unwrap();
+
+        slab.dealloc(a0).unwrap();
+        slab.dealloc(a1).unwrap();
+
+        // Slots 0 and 1 are the only free slots, so the two-slot search
+        // must find the run at indices 0..2.
+        let run = slab.alloc(300).unwrap(); // needs 2 slots
+        assert_eq!(run.len, 512);
+    }
+
+    #[test]
+    fn test_slab_multi_slot_alloc_no_room_at_end() {
+        // Only the last slot is free but a 2-slot run is requested.
+        // find_slots must not panic when checking beyond the bitset.
+        let mut slab = make_slab::<256>(1792); // 7 slots
+        let allocs: Vec<_> = (0..7).map(|_| slab.alloc(256).unwrap()).collect();
+        // Free only the last slot (index 6)
+        slab.dealloc(allocs[6]).unwrap();
+
+        let result = slab.alloc(300); // needs 2 slots, only 1 free
+        assert!(matches!(result, Err(AllocError::NoSpace)));
+    }
+
+    #[test]
+    fn test_slab_free_invalid_address() {
+        let mut slab = make_slab::<256>(1024);
+        let alloc = Allocation {
+            addr: 0x99999,
+            len: 256,
+        };
+
+        let result = slab.dealloc(alloc);
+        assert!(matches!(result, Err(AllocError::InvalidFree(0x99999, _))));
+    }
+
+    #[test]
+    fn test_slab_cursor_optimization_lifo() {
+        let mut slab = make_slab::<256>(1024);
+
+        let a1 = slab.alloc(256).unwrap();
+        let addr1 = a1.addr;
+
+        slab.dealloc(a1).unwrap();
+
+        // Next allocation should reuse the same slot (via the last-free-run cache)
+        let a2 = slab.alloc(256).unwrap();
+        assert_eq!(a2.addr, addr1);
+    }
+
+    #[test]
+    fn test_slab_cursor_rewind_for_single_slot() {
+        let mut slab = make_slab::<256>(1024);
+
+        let _a1 = slab.alloc(256).unwrap();
+        let a2 = slab.alloc(256).unwrap();
+        let _a3 = slab.alloc(256).unwrap();
+
+        // Free the single-slot allocation at position 1
+        slab.dealloc(a2).unwrap();
+
+        // The last-free-run cache should hand slot 1 back
+        let a4 = slab.alloc(256).unwrap();
+        // Should reuse slot 1
+        assert_eq!(a4.addr, a2.addr);
+    }
+
+    #[test]
+    fn test_slab_grow_inplace_success() {
+        let mut slab = make_slab::<256>(1024);
+        let alloc = slab.alloc(256).unwrap();
+
+        // Grow from 256 to 512 (adjacent slot is free)
+        let grown = slab.try_grow_inplace(alloc, 512).unwrap();
+        assert!(grown.is_some());
+        assert_eq!(grown.unwrap().len, 512);
+        assert_eq!(grown.unwrap().addr, alloc.addr);
+    }
+
+    #[test]
+    fn test_slab_grow_inplace_blocked() {
+        let mut slab = make_slab::<256>(1024);
+        let a1 = slab.alloc(256).unwrap();
+        let _a2 = slab.alloc(256).unwrap(); // Blocks growth
+
+        // Can't grow because next slot is allocated
+        let result = slab.try_grow_inplace(a1, 512).unwrap();
+        assert!(result.is_none());
+    }
+
+    #[test]
+    fn test_slab_shrink_inplace() {
+        let mut slab = make_slab::<256>(1024);
+        let alloc = slab.alloc(512).unwrap(); // 2 slots
+
+        let shrunk = slab.shrink_inplace(alloc, 256).unwrap();
+        assert_eq!(shrunk.len, 256);
+        assert_eq!(shrunk.addr, alloc.addr);
+        assert_eq!(slab.free_bytes(), 1024 - 256);
+    }
+
+    #[test]
+    fn test_slab_realloc_grow_inplace() {
+        let mut slab = make_slab::<256>(1024);
+        let alloc = slab.alloc(256).unwrap();
+
+        let new_alloc = slab.resize(alloc, 512).unwrap();
+        assert_eq!(new_alloc.addr, alloc.addr); // Same address (in-place)
+        assert_eq!(new_alloc.len, 512);
+    }
+
+    #[test]
+    fn test_slab_realloc_grow_relocate() {
+        let mut slab = make_slab::<256>(1024);
+        let a1 = slab.alloc(256).unwrap();
+        let _a2 = slab.alloc(256).unwrap(); // Blocks growth
+
+        let new_alloc = slab.resize(a1, 512).unwrap();
+        assert_ne!(new_alloc.addr, a1.addr); // Different address (relocated)
+        assert_eq!(new_alloc.len, 512);
+    }
+
+    #[test]
+    fn test_slab_realloc_shrink() {
+        let mut slab = make_slab::<256>(1024);
+        let alloc = slab.alloc(512).unwrap();
+
+        let new_alloc = slab.resize(alloc, 256).unwrap();
+        assert_eq!(new_alloc.addr, alloc.addr);
+        assert_eq!(new_alloc.len, 256);
+    }
+
+    #[test]
+    fn test_slab_realloc_same_size() {
+        let mut slab = make_slab::<256>(1024);
+        let alloc = slab.alloc(256).unwrap();
+
+        let new_alloc = slab.resize(alloc, 256).unwrap();
+        assert_eq!(new_alloc.addr, alloc.addr);
+        assert_eq!(new_alloc.len, alloc.len);
+    }
+
+    #[test]
+    fn test_slab_fragmentation_handling() {
+        let mut slab = make_slab::<256>(1024);
+
+        // Allocate all four slots, then free the first two: [F][F][U][U]
+        let a1 = slab.alloc(256).unwrap();
+        let a2 = slab.alloc(256).unwrap();
+        let _a3 = slab.alloc(256).unwrap();
+        let _a4 = slab.alloc(256).unwrap();
+
+        slab.dealloc(a2).unwrap();
+        slab.dealloc(a1).unwrap();
+
+        // Should still be able to allocate 2-slot buffer
+        let big = slab.alloc(512).unwrap();
+        assert_eq!(big.len, 512);
+    }
+
+    #[test]
+    fn test_pool_new_success() {
+        let pool = BufferPool::<256, 4096>::new(0x10000, 1024 * 1024).unwrap();
+        assert!(pool.inner.borrow().lower.capacity() > 0);
+        assert!(pool.inner.borrow().upper.capacity() > 0);
+    }
+
+    #[test]
+    fn test_pool_alloc_small_to_lower() {
+        let pool = make_pool::<256, 4096>(1024 * 1024);
+        let alloc = pool.alloc(128).unwrap();
+
+        // Should come from lower slab
+        assert!(pool.inner.borrow().lower.contains(alloc.addr));
+        assert_eq!(alloc.len, 256);
+    }
+
+    #[test]
+    fn test_pool_alloc_large_to_upper() {
+        let pool = make_pool::<256, 4096>(1024 * 1024);
+        let alloc = pool.alloc(1500).unwrap();
+
+        // Should come from upper slab
+        assert!(pool.inner.borrow().upper.contains(alloc.addr));
+        assert_eq!(alloc.len, 4096);
+    }
+
+    #[test]
+    fn test_pool_alloc_fallback_to_upper() {
+        let pool = make_pool::<256, 4096>(1024 * 1024);
+
+        // Fill lower slab completely
+        let mut allocations = Vec::new();
+        while pool.inner.borrow().lower.free_bytes() > 0 {
+            allocations.push(pool.inner.borrow_mut().lower.alloc(256).unwrap());
+        }
+
+        // Small allocation should fall back to the upper slab
+        let alloc = pool.alloc(128).unwrap();
+        assert!(pool.inner.borrow().upper.contains(alloc.addr));
+    }
+
+    #[test]
+    fn test_pool_free_from_lower() {
+        let pool = make_pool::<256, 4096>(1024 * 1024);
+        let alloc = pool.alloc(128).unwrap();
+
+        let free_before = pool.inner.borrow().lower.free_bytes();
+        pool.dealloc(alloc).unwrap();
+        assert_eq!(
+            pool.inner.borrow().lower.free_bytes(),
+            free_before + alloc.len
+        );
+    }
+
+    #[test]
+    fn test_pool_free_from_upper() {
+        let pool = make_pool::<256, 4096>(1024 * 1024);
+        let alloc = pool.alloc(1500).unwrap();
+
+        let free_before = pool.inner.borrow().upper.free_bytes();
+        pool.dealloc(alloc).unwrap();
+        assert_eq!(
+            pool.inner.borrow().upper.free_bytes(),
+            free_before + alloc.len
+        );
+    }
+
+    #[test]
+    fn test_pool_realloc_within_same_tier() {
+        let pool = make_pool::<256, 4096>(1024 * 1024);
+        let alloc = pool.alloc(128).unwrap();
+
+        // Realloc within lower tier (128 -> 200, both fit in 256 slots)
+        let new_alloc = pool.resize(alloc, 200).unwrap();
+        assert!(pool.inner.borrow().lower.contains(new_alloc.addr));
+    }
+
+    #[test]
+    fn test_pool_realloc_move_to_different_tier() {
+        let pool = make_pool::<256, 4096>(1024 * 1024);
+        let alloc = pool.alloc(128).unwrap();
+        assert!(pool.inner.borrow().lower.contains(alloc.addr));
+
+        // Realloc to a size that needs the upper tier
+        let new_alloc = pool.resize(alloc, 1500).unwrap();
+        assert!(pool.inner.borrow().upper.contains(new_alloc.addr));
+    }
+
+    #[test]
+    fn test_pool_realloc_shrink_stays_in_tier() {
+        let pool = make_pool::<256, 4096>(1024 * 1024);
+        let alloc = pool.alloc(1500).unwrap();
+        assert!(pool.inner.borrow().upper.contains(alloc.addr));
+
+        // Shrink but stay in upper tier
+        let new_alloc = pool.resize(alloc, 1000).unwrap();
+        assert!(pool.inner.borrow().upper.contains(new_alloc.addr));
+    }
+
+    #[test]
+    fn test_pool_stress_many_allocations() {
+        let pool = make_pool::<256, 4096>(4 * 1024 * 1024);
+        let mut allocations = Vec::new();
+
+        // Allocate many buffers
+        for i in 0..100 {
+            let size = if i % 2 == 0 { 128 } else { 1500 };
+            allocations.push(pool.alloc(size).unwrap());
+        }
+
+        // Free half of them
+        for i in (0..100).step_by(2) {
+            pool.dealloc(allocations[i]).unwrap();
+        }
+
+        // Should be able to allocate again
+        for i in 0..50 {
+            let size = if i % 2 == 0 { 128 } else { 1500 };
+            let _alloc = pool.alloc(size).unwrap();
+        }
+    }
+
+    #[test]
+    fn test_pool_mixed_workload() {
+        let pool = make_pool::<256, 4096>(2 * 1024 * 1024);
+
+        // Simulate virtio-net workload
+        let desc_buf = pool.alloc(64).unwrap(); // Control message
+        let rx_buf1 = pool.alloc(1500).unwrap(); // MTU packet
+        let rx_buf2 = pool.alloc(1500).unwrap(); // MTU packet
+        let tx_buf = pool.alloc(4096).unwrap(); // Large buffer
+
+        // Free and reallocate
+        pool.dealloc(rx_buf1).unwrap();
+        let rx_buf3 = pool.alloc(1500).unwrap();
+
+        // Should reuse freed buffer (LIFO)
+        assert_eq!(rx_buf3.addr, rx_buf1.addr);
+
+        pool.dealloc(desc_buf).unwrap();
+        pool.dealloc(rx_buf2).unwrap();
+        pool.dealloc(rx_buf3).unwrap();
+        pool.dealloc(tx_buf).unwrap();
+    }
+
+    #[test]
+    fn test_pool_zero_allocation_error() {
+        let pool = make_pool::<256, 4096>(1024 * 1024);
+        let result = pool.alloc(0);
+        assert!(matches!(result, Err(AllocError::InvalidArg)));
+    }
+
+    #[test]
+    fn test_pool_too_large_allocation() {
+        let pool = make_pool::<256, 4096>(1024 * 1024);
+        let result = pool.alloc(2 * 1024 * 1024); // Larger than pool
+        assert!(matches!(result, Err(AllocError::OutOfMemory)));
+    }
+
+    #[test]
+    fn test_align_up_helper() {
+        assert_eq!(align_up(0, 256), 0);
+        assert_eq!(align_up(1, 256), 256);
+        assert_eq!(align_up(256, 256), 256);
+        assert_eq!(align_up(257, 256), 512);
+        assert_eq!(align_up(511, 256), 512);
+        assert_eq!(align_up(512, 256), 512);
+    }
+
+    #[test]
+    fn test_slab_usable_size() {
+        let slab = make_slab::<256>(1024);
+        assert_eq!(slab.usable_size(0, 0), 0);
+        assert_eq!(slab.usable_size(0, 1), 256);
+        assert_eq!(slab.usable_size(0, 256), 256);
+        assert_eq!(slab.usable_size(0, 257), 512);
+    }
+
+    #[test]
+    fn test_slab_contains() {
+        let slab = make_slab::<256>(1024);
+        let range = slab.range();
+
+        assert!(slab.contains(range.start));
+        assert!(!slab.contains(range.end)); // Exclusive end
+        assert!(!slab.contains(0));
+    }
+
+    // Edge case: allocation exactly at boundary
+    #[test]
+    fn test_pool_boundary_allocation() {
+        let pool = make_pool::<256, 4096>(1024 * 1024);
+
+        // Allocate exactly at boundary
+        let alloc = pool.alloc(256).unwrap();
+        assert!(pool.inner.borrow().lower.contains(alloc.addr));
+
+        // Allocate just over boundary
+        let alloc2 = pool.alloc(257).unwrap();
+        assert!(pool.inner.borrow().upper.contains(alloc2.addr));
+    }
+
+    // Test overflow protection
+    #[test]
+    fn test_addr_of_overflow_protection() {
+        let slab = make_slab::<4096>(8192);
+
+        // This should not panic due to overflow checks
+        let addr = slab.addr_of(usize::MAX);
+        assert!(addr.is_none());
+    }
+
+    #[test]
+    fn test_no_overlapping_allocations() {
+        let mut slab = make_slab::<4096>(32768); // 8 slots
+
+        // Allocate slots 0-1
+        let a1 = slab.alloc(8000).unwrap();
+        assert_eq!(a1.len, 8192);
+
+        // Shrink to slot 0 only
+        let a2 = slab.shrink_inplace(a1, 4000).unwrap();
+        assert_eq!(a2.len, 4096);
+
+        // Allocate at slots 1-2
+        let a3 = slab.alloc(8000).unwrap();
+        assert_eq!(a3.len, 8192);
+        let slot1_addr = a2.addr + 4096;
+        assert_eq!(a3.addr, slot1_addr);
+
+        // Free slot 0
+        slab.dealloc(a2).unwrap();
+
+        // Try to allocate 2 slots - should NOT get slots 0-1 because slot 1 is occupied!
+        let a4 = slab.alloc(8000).unwrap();
+        assert_ne!(a4.addr, a2.addr); // Should be at a different location
+
+        slab.dealloc(a3).unwrap();
+        slab.dealloc(a4).unwrap();
+    }
+
+    #[test]
+    fn test_slab_reset_returns_to_initial_state() {
+        let mut slab = make_slab::<256>(4096);
+        let initial_free = slab.free_bytes();
+        let initial_cap = slab.capacity();
+
+        // Allocate some slots
+        let _a1 = slab.alloc(256).unwrap();
+        let _a2 = slab.alloc(512).unwrap();
+        assert!(slab.free_bytes() < initial_free);
+
+        slab.reset();
+
+        assert_eq!(slab.free_bytes(), initial_free);
+        assert_eq!(slab.capacity(), initial_cap);
+        assert!(slab.last_free_run.is_none());
+        assert_eq!(slab.used_slots.count_ones(..), 0);
+
+        // Should be able to allocate the full capacity again
+        let a = slab.alloc(initial_cap).unwrap();
+        assert_eq!(a.len, initial_cap);
+    }
+
+    #[test]
+    fn test_slab_reset_matches_new() {
+        let base = align_up(0x10000, 256) as u64;
+        let region = 4096;
+
+        let fresh = Slab::<256>::new(base, region).unwrap();
+
+        let mut used = Slab::<256>::new(base, region).unwrap();
+        let _a = used.alloc(256).unwrap();
+        let _b = used.alloc(1024).unwrap();
+        used.reset();
+
+        assert_eq!(used.free_bytes(), fresh.free_bytes());
+        assert_eq!(used.capacity(), fresh.capacity());
+        assert_eq!(
+            used.used_slots.count_ones(..),
+            fresh.used_slots.count_ones(..)
+        );
+        assert!(used.last_free_run.is_none());
+        assert!(fresh.last_free_run.is_none());
+    }
+
+    #[test]
+    fn test_buffer_pool_reset_returns_to_initial_state() {
+        let pool = make_pool::<256, 4096>(0x20000);
+
+        // Allocate from both tiers
+        let a1 = pool.inner.borrow_mut().alloc(128).unwrap();
+        let a2 = pool.inner.borrow_mut().alloc(8192).unwrap();
+        assert!(a1.len > 0);
+        assert!(a2.len > 0);
+
+        pool.reset();
+
+        let inner = pool.inner.borrow();
+        assert_eq!(inner.lower.used_slots.count_ones(..), 0);
+        assert_eq!(inner.upper.used_slots.count_ones(..), 0);
+        assert!(inner.lower.last_free_run.is_none());
+        assert!(inner.upper.last_free_run.is_none());
+    }
+
+    #[test]
+    fn test_buffer_pool_reset_allows_reallocation() {
+        let pool = make_pool::<256, 4096>(0x20000);
+
+        // Fill up some allocations
+        let mut allocs = Vec::new();
+        for _ in 0..5 {
+            allocs.push(pool.inner.borrow_mut().alloc(256).unwrap());
+        }
+
+        pool.reset();
+
+        // Should be able to allocate as if fresh
+        let a = pool.inner.borrow_mut().alloc(256).unwrap();
+        assert!(a.len > 0);
+    }
+
+    #[test]
+    fn test_recycle_pool_restore_allocated_removes_from_free_list() {
+        let pool = make_recycle_pool(4, 4096);
+        assert_eq!(pool.num_free(), 4);
+
+        let addrs = [0x80000, 0x81000]; // slots 0 and 1
+        pool.restore_allocated(&addrs).unwrap();
+        assert_eq!(pool.num_free(), 2);
+
+        // Allocating should only return the two remaining slots
+        let a1 = pool.alloc(4096).unwrap();
+        let a2 = pool.alloc(4096).unwrap();
+        assert!(pool.alloc(4096).is_err());
+
+        // The allocated addresses should be the non-restored ones
+        let mut got = [a1.addr, a2.addr];
+        got.sort();
+        assert_eq!(got, [0x82000, 0x83000]);
+    }
+
+    #[test]
+    fn test_recycle_pool_restore_allocated_invalid_addr_returns_error() {
+        let pool = make_recycle_pool(4, 4096);
+        let result = pool.restore_allocated(&[0xDEAD]);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_recycle_pool_restore_allocated_then_dealloc_roundtrip() {
+        let pool = make_recycle_pool(4, 4096);
+        let addr = 0x81000u64;
+
+        pool.restore_allocated(&[addr]).unwrap();
+        assert_eq!(pool.num_free(), 3);
+
+        // Dealloc the restored address
+        pool.dealloc(Allocation { addr, len: 4096
}).unwrap();
+        assert_eq!(pool.num_free(), 4);
+    }
+
+    #[test]
+    fn test_recycle_pool_restore_allocated_all_slots() {
+        let pool = make_recycle_pool(4, 4096);
+        let addrs: Vec<u64> = (0..4).map(|i| 0x80000 + i * 4096).collect();
+
+        pool.restore_allocated(&addrs).unwrap();
+        assert_eq!(pool.num_free(), 0);
+        assert!(pool.alloc(4096).is_err());
+    }
+
+    #[test]
+    fn test_recycle_pool_restore_allocated_empty_list_is_noop() {
+        let pool = make_recycle_pool(4, 4096);
+        pool.restore_allocated(&[]).unwrap();
+        assert_eq!(pool.num_free(), 4);
+    }
+
+    #[test]
+    fn test_recycle_pool_restore_allocated_resets_first() {
+        let pool = make_recycle_pool(4, 4096);
+
+        // Allocate some slots
+        let _ = pool.alloc(4096).unwrap();
+        let _ = pool.alloc(4096).unwrap();
+        assert_eq!(pool.num_free(), 2);
+
+        // restore_allocated resets then removes - so 4 - 1 = 3
+        pool.restore_allocated(&[0x80000]).unwrap();
+        assert_eq!(pool.num_free(), 3);
+    }
+
+    #[test]
+    fn test_recycle_pool_dealloc_out_of_range() {
+        let pool = make_recycle_pool(4, 4096);
+        let _ = pool.alloc(4096).unwrap();
+
+        let bogus = Allocation {
+            addr: 0xDEAD,
+            len: 4096,
+        };
+        assert!(matches!(
+            pool.dealloc(bogus),
+            Err(AllocError::InvalidFree(0xDEAD, 4096))
+        ));
+    }
+
+    #[test]
+    fn test_recycle_pool_dealloc_misaligned() {
+        let pool = make_recycle_pool(4, 4096);
+        let _ = pool.alloc(4096).unwrap();
+
+        let misaligned = Allocation {
+            addr: 0x80001,
+            len: 4096,
+        };
+        assert!(matches!(
+            pool.dealloc(misaligned),
+            Err(AllocError::InvalidFree(0x80001, 4096))
+        ));
+    }
+
+    #[test]
+    fn test_recycle_pool_dealloc_double_free() {
+        let pool = make_recycle_pool(4, 4096);
+        let a = pool.alloc(4096).unwrap();
+        pool.dealloc(a).unwrap();
+
+        // Second dealloc should fail - address is already in the free list
+        assert!(matches!(
+            pool.dealloc(a),
+            Err(AllocError::InvalidFree(_, _))
+        ));
+    }
+
+    #[test]
+    fn test_recycle_pool_random_order_dealloc() {
+        let pool = make_recycle_pool(8, 4096);
+
+        let mut allocs: Vec<Allocation> = (0..8).map(|_| pool.alloc(4096).unwrap()).collect();
+        assert_eq!(pool.num_free(), 0);
+
+        // Dealloc in reverse order
+        allocs.reverse();
+        for a in &allocs {
+            pool.dealloc(*a).unwrap();
+        }
+        assert_eq!(pool.num_free(), 8);
+
+        // All slots should be re-allocatable
+        let reallocs: Vec<Allocation> = (0..8).map(|_| pool.alloc(4096).unwrap()).collect();
+        assert_eq!(pool.num_free(), 0);
+
+        // Verify all addresses are distinct
+        let mut addrs: Vec<u64> = reallocs.iter().map(|a| a.addr).collect();
+        addrs.sort();
+        addrs.dedup();
+        assert_eq!(addrs.len(), 8);
+    }
+
+    #[test]
+    fn test_recycle_pool_interleaved_alloc_dealloc_order() {
+        let pool = make_recycle_pool(4, 4096);
+
+        let a0 = pool.alloc(4096).unwrap();
+        let a1 = pool.alloc(4096).unwrap();
+        let a2 = pool.alloc(4096).unwrap();
+        let a3 = pool.alloc(4096).unwrap();
+        assert_eq!(pool.num_free(), 0);
+
+        // Free middle slots first (out of allocation order)
+        pool.dealloc(a2).unwrap();
+        pool.dealloc(a0).unwrap();
+        assert_eq!(pool.num_free(), 2);
+
+        // Re-alloc gets the out-of-order slots back (LIFO)
+        let b0 = pool.alloc(4096).unwrap();
+        assert_eq!(b0.addr, a0.addr);
+        let b1 = pool.alloc(4096).unwrap();
+        assert_eq!(b1.addr, a2.addr);
+
+        // Free everything in yet another order
+        pool.dealloc(a1).unwrap();
+        pool.dealloc(b0).unwrap();
+        pool.dealloc(b1).unwrap();
+        pool.dealloc(a3).unwrap();
+        assert_eq!(pool.num_free(), 4);
+
+        // All 4 original addresses should be available
+        let mut final_addrs: Vec<u64> = (0..4).map(|_| pool.alloc(4096).unwrap().addr).collect();
+        final_addrs.sort();
+        let expected: Vec<u64> = (0..4).map(|i| 0x80000 + i * 4096).collect();
+        assert_eq!(final_addrs, expected);
+    }
+
+    #[test]
+    fn test_recycle_pool_dealloc_order_independent_of_alloc_order() {
+        let pool = make_recycle_pool(6, 256);
+
+        // Allocate all
+        let allocs: Vec<Allocation> = (0..6).map(|_| pool.alloc(256).unwrap()).collect();
+
+        // Dealloc in scattered order: 4, 1, 5, 0, 3, 2
+        let order = [4, 1, 5, 0, 3, 2];
+        for &i in &order {
+            pool.dealloc(allocs[i]).unwrap();
+        }
+        assert_eq!(pool.num_free(), 6);
+
+        // Re-allocate all and verify we get back the full set
+        let mut realloc_addrs: Vec<u64> = (0..6).map(|_| pool.alloc(256).unwrap().addr).collect();
+        realloc_addrs.sort();
+
+        let mut orig_addrs: Vec<u64> = allocs.iter().map(|a| a.addr).collect();
+        orig_addrs.sort();
+
+        assert_eq!(realloc_addrs, orig_addrs);
+    }
+}
+
+#[cfg(test)]
+mod fuzz {
+    use quickcheck::{Arbitrary, Gen, QuickCheck};
+
+    use super::*;
+
+    const MAX_OPS: usize = 10;
+    const MAX_ALLOC_SIZE: usize = 8192;
+
+    #[derive(Clone, Debug)]
+    enum Op {
+        Alloc(usize),
+        Dealloc(usize),
+        Resize(usize, usize),
+    }
+
+    impl Arbitrary for Op {
+        fn arbitrary(g: &mut Gen) -> Self {
+            match u8::arbitrary(g) % 3 {
+                0 => Op::Alloc(usize::arbitrary(g) % MAX_ALLOC_SIZE + 1),
+                1 => Op::Dealloc(usize::arbitrary(g)),
+                2 => Op::Resize(
+                    usize::arbitrary(g),
+                    usize::arbitrary(g) % MAX_ALLOC_SIZE + 1,
+                ),
+                _ => unreachable!(),
+            }
+        }
+    }
+
+    #[derive(Clone, Debug)]
+    struct Scenario {
+        pool_size: usize,
+        ops: Vec<Op>,
+    }
+
+    impl Arbitrary for Scenario {
+        fn arbitrary(g: &mut Gen) -> Self {
+            let pool_size = (usize::arbitrary(g) % (4 * 1024 * 1024)) + (1024 * 1024);
+            let num_ops = usize::arbitrary(g) % MAX_OPS + 1;
+            let ops = (0..num_ops).map(|_| Op::arbitrary(g)).collect();
+
+            Scenario { pool_size, ops }
+        }
+    }
+
+    fn run_scenario(s: Scenario) -> bool {
+        let base = align_up(0x10000, 4096) as u64;
+        let pool = match BufferPool::<256, 4096>::new(base, s.pool_size) {
+            Ok(p) => p,
+            Err(_) => return true,
+        };
+
+        let mut allocations: Vec<Allocation> = Vec::new();
+
+        for op in &s.ops {
+            match op {
+                Op::Alloc(size) => match pool.alloc(*size) {
+                    Ok(alloc) => {
+                        assert!(alloc.len >= *size);
+                        allocations.push(alloc);
+                    }
+                    Err(AllocError::NoSpace | AllocError::OutOfMemory) => {}
+                    Err(_) => {
+                        return false;
+                    }
+                },
+                Op::Dealloc(idx) => {
+                    if allocations.is_empty() {
+                        continue;
+                    }
+
+                    let idx = idx % allocations.len();
+                    let alloc = allocations.swap_remove(idx);
+
+                    match pool.dealloc(alloc) {
+                        Ok(_) => {}
+                        Err(_) => return false,
+                    }
+                }
+                Op::Resize(idx, new_size) => {
+                    if allocations.is_empty() {
+                        continue;
+                    }
+
+                    let idx = idx % allocations.len();
+                    let old_alloc = allocations[idx];
+
+                    match pool.resize(old_alloc, *new_size) {
+                        Ok(new_alloc) => {
+                            assert!(new_alloc.len >= *new_size);
+                            allocations[idx] = new_alloc;
+                        }
+                        Err(AllocError::NoSpace | AllocError::OutOfMemory) => {}
+                        Err(_) => return false,
+                    }
+                }
+            }
+
+            if check_pool_invariants(&pool, &allocations).is_err() {
+                return false;
+            }
+        }
+
+        // Cleanup
+        for alloc in &allocations {
+            if pool.dealloc(*alloc).is_err() {
+                return false;
+            }
+        }
+
+        check_pool_invariants(&pool, &allocations).is_ok()
+    }
+
+    fn check_slab_invariants<const N: usize>(slab: &Slab<N>) -> Result<(), &'static str> {
+        // Check that number of used + free slots equals total
+        let used = slab.used_slots.count_ones(..);
+        let free = slab.used_slots.count_zeroes(..);
+        if used + free != slab.used_slots.len() {
+            return Err("used + free != total slots");
+        }
+
+        let expected_free = free * N;
+        if slab.free_bytes() != expected_free {
+            return Err("free_bytes doesn't match bitmap");
+        }
+
+        if let Some(alloc) = slab.last_free_run {
+            if alloc.len == 0 || alloc.len % N != 0 {
+                return Err("last_free_run has invalid length");
+            }
+            if !slab.contains(alloc.addr) {
+                return Err("last_free_run addr outside range");
+            }
+        }
+
+        Ok(())
+    }
+
+    fn check_pool_invariants<const L: usize, const U: usize>(
+        pool: &BufferPool<L, U>,
+        allocations: &[Allocation],
+    ) -> Result<(), &'static str> {
+        check_slab_invariants(&pool.inner.borrow().lower)?;
+        check_slab_invariants(&pool.inner.borrow().upper)?;
+
+        if pool.inner.borrow().lower.range().end > pool.inner.borrow().upper.range().start {
+            return Err("lower and upper ranges overlap");
+        }
+
+        let mut seen = std::collections::HashSet::new();
+
+        for alloc in allocations {
+            if !pool.inner.borrow().lower.contains(alloc.addr)
+                && !pool.inner.borrow().upper.contains(alloc.addr)
+            {
+                return Err("allocation address outside pool ranges");
+            }
+
+            if alloc.len % L != 0 && alloc.len % U != 0 {
+                return Err("allocation length not aligned to any tier");
+            }
+
+            if !seen.insert(alloc.addr) {
+                return Err("duplicate allocation address in tracking");
+            }
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn prop_allocator_invariants() {
+        #[cfg(miri)]
+        let tests = 10;
+        #[cfg(not(miri))]
+        let tests = 1000;
+
+        QuickCheck::new()
+            .tests(tests)
+            .quickcheck(run_scenario as fn(Scenario) -> bool);
+    }
+}
diff --git a/src/hyperlight_common/src/virtq/producer.rs b/src/hyperlight_common/src/virtq/producer.rs
new file mode 100644
index 000000000..7d3e35ab2
--- /dev/null
+++ b/src/hyperlight_common/src/virtq/producer.rs
@@ -0,0 +1,1446 @@
+/*
+Copyright 2026 The Hyperlight Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+use alloc::collections::VecDeque;
+use alloc::vec;
+use alloc::vec::Vec;
+
+use bytes::Bytes;
+use smallvec::SmallVec;
+
+use super::*;
+
+/// A completion received by the driver (producer) side.
+///
+/// Contains the completion data and metadata about the completed entry.
+/// The `data` field is a zero-copy [`Bytes`] backed by a shared-memory
+/// pool allocation that is returned when the last `Bytes` clone is dropped.
+#[derive(Debug)]
+pub struct RecvCompletion {
+    /// Token identifying which entry this completion corresponds to.
+    pub token: Token,
+    /// Completion data from the device.
+    pub data: Bytes,
+    /// Whether this entry is oneshot, so there is no writable completion
+    /// buffer. Oneshot entries are fire-and-forget: the producer does not
+    /// expect any response data from the device.
+    pub oneshot: bool,
+}
+
+/// Allocation tracking for an in-flight descriptor chain.
+///
+/// Each variant corresponds to a buffer layout submitted by the driver
+/// (guest/producer) and consumed by the device (host/consumer).
+/// "Readable" and "writable" are from the device's perspective, following
+/// the virtio convention.
+#[derive(Debug, Clone, Copy)]
+pub(crate) enum Inflight {
+    /// Driver sends data, device only acknowledges (fire-and-forget).
+    /// The readable buffer carries the entry; no writable buffer for a
+    /// device response.
+    ReadOnly { entry: Allocation },
+    /// Driver pre-posts a writable buffer for the device to fill.
+    /// No readable entry - the device writes a response into the
+    /// completion buffer unprompted (e.g. event delivery).
+    WriteOnly { completion: Allocation },
+    /// Bidirectional: driver sends an entry, device writes a response.
+    /// The readable buffer carries the entry, the writable buffer
+    /// receives the completion (typical request/response pattern).
+    ReadWrite {
+        entry: Allocation,
+        completion: Allocation,
+    },
+}
+
+impl Inflight {
+    fn entry(&self) -> Option<Allocation> {
+        match self {
+            Inflight::ReadOnly { entry } => Some(*entry),
+            Inflight::ReadWrite { entry, .. } => Some(*entry),
+            Inflight::WriteOnly { .. } => None,
+        }
+    }
+
+    fn completion(&self) -> Option<Allocation> {
+        match self {
+            Inflight::WriteOnly { completion } => Some(*completion),
+            Inflight::ReadWrite { completion, .. } => Some(*completion),
+            Inflight::ReadOnly { .. } => None,
+        }
+    }
+
+    fn try_into_chain(self, entry_len: usize) -> Result<BufferChain, VirtqError> {
+        if let Some(entry) = self.entry()
+            && entry_len > entry.len
+        {
+            return Err(VirtqError::EntryTooLarge);
+        }
+
+        Ok(match self {
+            Inflight::ReadOnly { entry } => BufferChainBuilder::new()
+                .readable(entry.addr, entry_len as u32)
+                .build()?,
+            Inflight::WriteOnly { completion } => BufferChainBuilder::new()
+                .writable(completion.addr, completion.len as u32)
+                .build()?,
+            Inflight::ReadWrite { entry, completion } => BufferChainBuilder::new()
+                .readable(entry.addr, entry_len as u32)
+                .writable(completion.addr, completion.len as u32)
+                .build()?,
+        })
+    }
+}
+
+/// A high-level virtqueue producer (driver side).
+///
+/// The producer sends entries to the consumer (device), and receives completions.
+/// This is typically used on the driver/guest side.
+///
+/// # Example
+///
+/// ```ignore
+/// let mut producer = VirtqProducer::new(layout, mem, notifier, pool);
+///
+/// // Build and submit an entry
+/// let mut se = producer.chain().entry(64).completion(64).build()?;
+/// se.write_all(b"hello")?;
+/// let token = producer.submit(se)?;
+///
+/// // Later, poll for completion
+/// if let Some(cqe) = producer.poll()? {
+///     assert_eq!(cqe.token, token);
+///     println!("Got completion: {:?}", cqe.data);
+/// }
+/// ```
+pub struct VirtqProducer<M, N, P> {
+    inner: RingProducer<M>,
+    notifier: N,
+    pool: P,
+    next_token: u32,
+    inflight: Vec<Option<(Token, Inflight)>>,
+    pending: VecDeque<RecvCompletion>,
+}
+
+impl<M, N, P> VirtqProducer<M, N, P>
+where
+    M: MemOps + Clone,
+    N: Notifier,
+    P: BufferProvider + Clone,
+{
+    /// Create a new virtqueue producer.
+    ///
+    /// # Arguments
+    ///
+    /// * `layout` - Ring memory layout (descriptor table and event suppression addresses)
+    /// * `mem` - Memory operations implementation for reading/writing to shared memory
+    /// * `notifier` - Callback for notifying the device (consumer) about new entries
+    /// * `pool` - Buffer allocator for entry/completion data
+    pub fn new(layout: Layout, mem: M, notifier: N, pool: P) -> Self {
+        let inner = RingProducer::new(layout, mem);
+        let ring_len = inner.len();
+        let inflight = vec![None; ring_len];
+
+        Self {
+            inner,
+            pool,
+            notifier,
+            inflight,
+            next_token: 0,
+            pending: VecDeque::with_capacity(ring_len),
+        }
+    }
+
+    /// Poll for a single completion from the device.
+    ///
+    /// Returns buffered completions from prior [`reclaim`](Self::reclaim)
+    /// calls first, then checks the ring for new completions.
+    ///
+    /// Returns `Ok(Some(completion))` if a completion is available, `Ok(None)` if no
+    /// completions are ready (would block), or an error if the device misbehaved.
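+    ///
+    /// A typical polling loop looks like this (a sketch; `handle_completion`
+    /// is a placeholder for application code, and [`drain`](Self::drain)
+    /// wraps this same pattern):
+    ///
+    /// ```ignore
+    /// while let Some(cqe) = producer.poll()? {
+    ///     handle_completion(cqe.token, &cqe.data);
+    /// }
+    /// ```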
+    ///
+    /// The returned [`RecvCompletion::data`] is a zero-copy [`Bytes`] backed by the
+    /// shared-memory allocation via [`BufferOwner`]. The pool allocation is
+    /// held alive as long as any `Bytes` clone exists, and is returned to the
+    /// pool when the last clone is dropped.
+    ///
+    /// # Errors
+    ///
+    /// - [`VirtqError::InvalidState`] - Device returned invalid descriptor ID or
+    ///   wrote more data than the completion buffer capacity
+    pub fn poll(&mut self) -> Result<Option<RecvCompletion>, VirtqError>
+    where
+        M: Send + 'static,
+        P: Send + 'static,
+    {
+        if let Some(cqe) = self.pending.pop_front() {
+            return Ok(Some(cqe));
+        }
+        self.poll_ring()
+    }
+
+    /// Reclaim ring slots and pool entries from completed descriptors.
+    ///
+    /// Processes all available used entries from the ring: frees entry
+    /// buffer allocations immediately, and buffers completion data for
+    /// later retrieval via [`poll`](Self::poll).
+    ///
+    /// Completions with empty data from read-only/oneshot entries are
+    /// discarded immediately.
+    ///
+    /// Use this to free resources under backpressure without losing
+    /// completion data. Returns the number of entries reclaimed.
+    pub fn reclaim(&mut self) -> Result<usize, VirtqError>
+    where
+        M: Send + 'static,
+        P: Send + 'static,
+    {
+        let mut count = 0;
+        while let Some(cqe) = self.poll_ring()? {
+            if !cqe.oneshot {
+                debug_assert!(self.pending.len() < self.inflight.len());
+                debug_assert!(!cqe.data.is_empty());
+                self.pending.push_back(cqe);
+            }
+            count += 1;
+        }
+        Ok(count)
+    }
+
+    /// Poll one completion directly from the ring (bypassing the pending buffer).
+    fn poll_ring(&mut self) -> Result<Option<RecvCompletion>, VirtqError>
+    where
+        M: Send + 'static,
+        P: Send + 'static,
+    {
+        let used = match self.inner.poll_used() {
+            Ok(u) => u,
+            Err(RingError::WouldBlock) => return Ok(None),
+            Err(e) => return Err(e.into()),
+        };
+
+        let id = used.id as usize;
+        let (token, inf) = self
+            .inflight
+            .get_mut(id)
+            .ok_or(VirtqError::InvalidState)?
+            .take()
+            .ok_or(VirtqError::InvalidState)?;
+
+        // the token's descriptor ID must match the ring's
+        debug_assert_eq!(
+            token.1, used.id,
+            "ring returned desc_id={} but inflight slot {} has token with desc_id={}",
+            used.id, id, token.1,
+        );
+
+        let written = used.len as usize;
+
+        // Free entry buffers (request data no longer needed)
+        if let Some(entry) = inf.entry() {
+            self.pool.dealloc(entry)?;
+        }
+
+        let completion_guard = inf
+            .completion()
+            .map(|buf| PoolAlloc::new(self.pool.clone(), buf));
+
+        // Read completion data
+        let has_completion = completion_guard.is_some();
+        let data = match completion_guard {
+            Some(buf) if written > buf.allocation().len => {
+                // This is a protocol violation
+                return Err(VirtqError::InvalidState);
+            }
+            Some(buf) => Bytes::from_owner(
+                buf.into_buffer_owner(self.inner.mem().clone(), written)
+                    .map_err(|_| VirtqError::MemoryReadError)?,
+            ),
+            None => Bytes::new(),
+        };
+
+        Ok(Some(RecvCompletion {
+            token,
+            data,
+            oneshot: !has_completion,
+        }))
+    }
+
+    /// Drain all available completions, calling the provided closure for each.
+    ///
+    /// This is a convenience method that repeatedly calls [`poll`](Self::poll)
+    /// until no more completions are available.
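+    /// Unlike [`reclaim`](Self::reclaim), which buffers completion data for
+    /// later [`poll`](Self::poll) calls, the closure receives each completion
+    /// as soon as it is popped.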
+    ///
+    /// # Arguments
+    ///
+    /// * `f` - Closure called for each completion with its token and data
+    ///
+    /// # Example
+    ///
+    /// ```ignore
+    /// producer.drain(|token, data| {
+    ///     println!("Got {:?}: {} bytes", token, data.len());
+    /// })?;
+    /// ```
+    pub fn drain(&mut self, mut f: impl FnMut(Token, Bytes)) -> Result<(), VirtqError>
+    where
+        M: Send + 'static,
+        P: Send + 'static,
+    {
+        while let Some(cqe) = self.poll()? {
+            f(cqe.token, cqe.data);
+        }
+
+        Ok(())
+    }
+
+    /// Begin building a descriptor chain for submission.
+    ///
+    /// Returns a [`ChainBuilder`] that allocates buffers from the pool.
+    pub fn chain(&self) -> ChainBuilder<M, P> {
+        ChainBuilder::new(self.inner.mem().clone(), self.pool.clone())
+    }
+
+    /// Begin a batch of submissions.
+    ///
+    /// Entries submitted through the returned [`SubmitBatch`] are published to
+    /// the ring immediately, but the consumer is notified at most once when
+    /// [`SubmitBatch::finish`] is called. This mirrors the virtio pattern of
+    /// adding multiple buffers and then kicking the queue once.
+    pub fn batch(&mut self) -> SubmitBatch<'_, M, N, P> {
+        SubmitBatch::new(self)
+    }
+
+    /// Submit a [`SendEntry`] to the ring.
+    ///
+    /// Publishes the descriptor chain, stores the in-flight tracking state,
+    /// and notifies the consumer if event suppression allows. Notifications
+    /// are layout-neutral; use [`batch`](Self::batch) when a higher-level
+    /// protocol wants to publish multiple entries and kick once.
+    ///
+    /// # Errors
+    ///
+    /// - [`VirtqError::EntryTooLarge`] - written exceeds entry buffer capacity
+    /// - [`VirtqError::RingError`] - ring is full
+    /// - [`VirtqError::InvalidState`] - descriptor ID collision
+    pub fn submit(&mut self, entry: SendEntry<M, P>) -> Result<Token, VirtqError> {
+        let cursor_before = self.inner.avail_cursor();
+        let token = self.publish(entry)?;
+        self.notify_since(cursor_before)?;
+        Ok(token)
+    }
+
+    fn publish(&mut self, mut entry: SendEntry<M, P>) -> Result<Token, VirtqError> {
+        let written = entry.written;
+        let inflight = *entry.inflight.as_ref().ok_or(VirtqError::InvalidState)?;
+
+        let chain = inflight.try_into_chain(written)?;
+        let id = self.inner.submit_available(&chain)?;
+        let inflight = entry.inflight.take().ok_or(VirtqError::InvalidState)?;
+
+        let token = Token(self.next_token, id);
+        self.next_token = self.next_token.wrapping_add(1);
+
+        let slot = self
+            .inflight
+            .get_mut(id as usize)
+            .ok_or(VirtqError::InvalidState)?;
+
+        if slot.is_some() {
+            return Err(VirtqError::InvalidState);
+        }
+
+        *slot = Some((token, inflight));
+
+        Ok(token)
+    }
+
+    fn notify_since(&mut self, cursor: RingCursor) -> Result<bool, VirtqError> {
+        let should_notify = self.inner.should_notify_since(cursor)?;
+        if should_notify {
+            self.notify_now();
+        }
+        Ok(should_notify)
+    }
+
+    fn notify_now(&self) {
+        self.notifier.notify(QueueStats {
+            num_free: self.inner.num_free(),
+            num_inflight: self.inner.num_inflight(),
+        });
+    }
+
+    /// Signal backpressure to the consumer.
+    ///
+    /// Bypasses event suppression. Call this when submit fails with a
+    /// backpressure error and the consumer needs to drain.
+    #[inline]
+    pub fn notify_backpressure(&self) {
+        self.notify_now();
+    }
+
+    /// Get the current used cursor position.
+    ///
+    /// Useful for setting up descriptor-based event suppression.
+    #[inline]
+    pub fn used_cursor(&self) -> RingCursor {
+        self.inner.used_cursor()
+    }
+
+    /// Number of free (unsubmitted) descriptors in the ring.
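+    /// (That is, slots not currently tracking an in-flight entry.)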
+    #[inline]
+    pub fn num_free(&self) -> usize {
+        self.inner.num_free()
+    }
+
+    /// Configure event suppression for used buffer notifications.
+    ///
+    /// This controls when the device (consumer) signals us about completed buffers:
+    ///
+    /// - [`SuppressionKind::Enable`]: Always signal (default) - good for latency
+    /// - [`SuppressionKind::Disable`]: Never signal - caller must poll
+    /// - [`SuppressionKind::Descriptor`]: Signal only at specific cursor position
+    ///
+    /// # Example: Completion Batching
+    ///
+    /// ```ignore
+    /// // Submit entries, then suppress notifications until all complete
+    /// let mut se = producer.chain().entry(64).completion(128).build()?;
+    /// se.write_all(b"entry1")?;
+    /// producer.submit(se)?;
+    /// let cursor = producer.used_cursor();
+    /// producer.set_used_suppression(SuppressionKind::Descriptor(cursor))?;
+    /// // Device will notify only after reaching that cursor position
+    /// ```
+    pub fn set_used_suppression(&mut self, kind: SuppressionKind) -> Result<(), VirtqError> {
+        match kind {
+            SuppressionKind::Enable => self.inner.enable_used_notifications()?,
+            SuppressionKind::Disable => self.inner.disable_used_notifications()?,
+            SuppressionKind::Descriptor(cursor) => self
+                .inner
+                .enable_used_notifications_desc(cursor.head(), cursor.wrap())?,
+        }
+        Ok(())
+    }
+
+    /// Reset ring, inflight, and pool state to initial values.
+    ///
+    /// # Safety
+    ///
+    /// All [`RecvCompletion`]s (and their backing [`Bytes`]) from previous `poll()`
+    /// calls must have been dropped before calling this. Outstanding completions
+    /// hold pool allocations via `BufferOwner`; resetting the pool while they exist
+    /// would cause double-free on drop.
+    ///
+    /// TODO(virtq): find a way to allow guest to keep completions across resets.
+    pub fn reset(&mut self) {
+        self.pool.reset();
+        self.inner.reset();
+        self.pending.clear();
+        self.inflight.fill(None);
+    }
+
+    /// Replace the pool and reset ring, inflight, and pending state.
+    ///
+    /// Use this when restoring from a snapshot where the pool has been
+    /// relocated or recreated.
+    ///
+    /// # Safety
+    ///
+    /// Same as [`reset`](Self::reset) - all outstanding completions
+    /// must have been dropped.
+    pub fn reset_with_pool(&mut self, pool: P) {
+        self.pool = pool;
+        self.inner.reset();
+        self.pending.clear();
+        self.inflight.fill(None);
+    }
+}
+
+/// A scoped batch of producer submissions.
+///
+/// Submissions are published immediately, while notification is delayed until
+/// [`finish`](Self::finish). `finish` is explicit because the event-suppression
+/// check can fail; dropping a batch does not notify.
+#[must_use = "call finish to notify the consumer about batched submissions"]
+pub struct SubmitBatch<'a, M, N, P> {
+    producer: &'a mut VirtqProducer<M, N, P>,
+    notify_from: Option<RingCursor>,
+}
+
+impl<'a, M, N, P> SubmitBatch<'a, M, N, P>
+where
+    M: MemOps + Clone,
+    N: Notifier,
+    P: BufferProvider + Clone,
+{
+    fn new(producer: &'a mut VirtqProducer<M, N, P>) -> Self {
+        Self {
+            producer,
+            notify_from: None,
+        }
+    }
+
+    /// Begin building a descriptor chain for this batch.
+    pub fn chain(&self) -> ChainBuilder<M, P> {
+        self.producer.chain()
+    }
+
+    /// Publish an entry as part of this batch without notifying yet.
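+    ///
+    /// A sketch of batched submission (mirrors the pattern exercised by this
+    /// module's tests; error handling elided):
+    ///
+    /// ```ignore
+    /// let mut batch = producer.batch();
+    /// let mut se = batch.chain().entry(64).build()?;
+    /// se.write_all(b"payload")?;
+    /// batch.submit(se)?;
+    /// batch.finish()?; // at most one notification for the whole batch
+    /// ```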
+    pub fn submit(&mut self, entry: SendEntry<M, P>) -> Result<Token, VirtqError> {
+        let cursor_before = self.producer.inner.avail_cursor();
+        let token = self.producer.publish(entry)?;
+        if self.notify_from.is_none() {
+            self.notify_from = Some(cursor_before);
+        }
+        Ok(token)
+    }
+
+    /// Finish the batch and notify the consumer once if event suppression
+    /// requires it for the whole published range.
+    ///
+    /// Returns `true` if a notification was sent.
+    pub fn finish(mut self) -> Result<bool, VirtqError> {
+        let Some(notify_from) = self.notify_from.take() else {
+            return Ok(false);
+        };
+
+        self.producer.notify_since(notify_from)
+    }
+}
+
+/// Snapshot restore support for producers backed by [`RecyclePool`].
+impl<M, N> VirtqProducer<M, N, RecyclePool>
+where
+    M: MemOps + Clone,
+    N: Notifier,
+{
+    /// Replace the pool and reconstruct producer state from a prefilled ring.
+    ///
+    /// The host prefills the H2G ring with `min(ring_size, pool_count)`
+    /// descriptors during restore (`restore_h2g_prefill`), writing
+    /// descriptors in forward order: position i gets
+    /// `addr = pool_base + i * slot_size`.
+    ///
+    /// Any descriptors already consumed by the host and marked used
+    /// will be discovered naturally by `poll_used()` after restore.
+    pub fn restore_from_ring(&mut self, pool: RecyclePool) -> Result<(), VirtqError> {
+        self.reset_with_pool(pool);
+
+        let ring_size = self.inner.len();
+        let pool_count = self.pool.count();
+        let prefill_count = core::cmp::min(ring_size, pool_count);
+        let slot_size = self.pool.slot_size();
+
+        let mut ids = SmallVec::<[u16; 64]>::new();
+
+        // Scan descriptors to discover in-flight IDs and set up the inflight table
+        for pos in 0..prefill_count as u16 {
+            let desc_base = self
+                .inner
+                .desc_table()
+                .desc_addr(pos)
+                .ok_or(VirtqError::RingError(RingError::InvalidState))?;
+
+            let id = self
+                .inner
+                .mem()
+                .read_val::<u16>(desc_base + Descriptor::ID_OFFSET as u64)
+                .map_err(|_| VirtqError::MemoryReadError)?;
+
+            if (id as usize) >= ring_size {
+                return Err(VirtqError::InvalidState);
+            }
+
+            if self.inflight[id as usize].is_some() {
+                return Err(VirtqError::InvalidState);
+            }
+
+            let addr = self
+                .pool
+                .slot_addr(pos as usize)
+                .ok_or(VirtqError::InvalidState)?;
+
+            let token = Token(self.next_token, id);
+            self.next_token = self.next_token.wrapping_add(1);
+
+            self.inflight[id as usize] = Some((
+                token,
+                Inflight::WriteOnly {
+                    completion: Allocation {
+                        addr,
+                        len: slot_size,
+                    },
+                },
+            ));
+
+            ids.push(id);
+        }
+
+        self.inner.reset_prefilled(&ids);
+
+        let addrs: SmallVec<[u64; 64]> = (0..prefill_count)
+            .map(|i| self.pool.slot_addr(i).ok_or(VirtqError::InvalidState))
+            .collect::<Result<_, _>>()?;
+
+        self.pool
+            .restore_allocated(&addrs)
+            .map_err(|_| VirtqError::InvalidState)?;
+
+        debug_assert!(
+            self.inflight.iter().filter(|s| s.is_some()).count() == prefill_count,
+            "restore_from_ring: expected {} inflight entries, found {}",
+            prefill_count,
+            self.inflight.iter().filter(|s| s.is_some()).count()
+        );
+
+        Ok(())
+    }
+}
+
+/// Builder for configuring a descriptor chain's buffer layout.
+///
+/// If dropped without building, no resources are leaked (allocations are
+/// deferred to [`build`](Self::build)).
+#[must_use = "call .build() to create a SendEntry"]
+pub struct ChainBuilder<M, P> {
+    mem: M,
+    pool: P,
+    entry_cap: Option<usize>,
+    cqe_cap: Option<usize>,
+}
+
+impl<M, P: BufferProvider + Clone> ChainBuilder<M, P> {
+    fn new(mem: M, pool: P) -> Self {
+        Self {
+            mem,
+            pool,
+            entry_cap: None,
+            cqe_cap: None,
+        }
+    }
+
+    fn alloc(&self, size: usize) -> Result<PoolAlloc<P>, VirtqError> {
+        Ok(PoolAlloc::allocate(self.pool.clone(), size)?)
+    }
+
+    /// Request an entry buffer of `cap` bytes.
+    ///
+    /// The entry holds data sent from the driver to the consumer (device).
+    /// The actual allocation is deferred to [`build`](Self::build).
+    pub fn entry(mut self, cap: usize) -> Self {
+        self.entry_cap = Some(cap);
+        self
+    }
+
+    /// Request a completion buffer of `cap` bytes.
+    ///
+    /// The completion buffer is filled by the consumer and returned via
+    /// [`VirtqProducer::poll`] as [`RecvCompletion`].
+    pub fn completion(mut self, cap: usize) -> Self {
+        self.cqe_cap = Some(cap);
+        self
+    }
+
+    /// Allocate buffers and return a [`SendEntry`] for writing.
+    ///
+    /// # Errors
+    ///
+    /// - [`VirtqError::InvalidState`] - Neither entry nor completion requested
+    /// - [`VirtqError::Alloc`] - Pool exhausted
+    pub fn build(self) -> Result<SendEntry<M, P>, VirtqError> {
+        if self.entry_cap.is_none() && self.cqe_cap.is_none() {
+            return Err(VirtqError::InvalidState);
+        }
+
+        let entry_alloc = self.entry_cap.map(|cap| self.alloc(cap)).transpose()?;
+        let completion_alloc = self.cqe_cap.map(|cap| self.alloc(cap)).transpose()?;
+
+        let inflight = match (entry_alloc, completion_alloc) {
+            (Some(entry), Some(cqe)) => Inflight::ReadWrite {
+                entry: entry.into_raw(),
+                completion: cqe.into_raw(),
+            },
+            (Some(entry), None) => Inflight::ReadOnly {
+                entry: entry.into_raw(),
+            },
+            (None, Some(cqe)) => Inflight::WriteOnly {
+                completion: cqe.into_raw(),
+            },
+            (None, None) => unreachable!(),
+        };
+
+        Ok(SendEntry {
+            mem: self.mem,
+            pool: self.pool,
+            inflight: Some(inflight),
+            written: 0,
+        })
+    }
+}
+
+/// A configured entry ready for writing and submission.
+///
+/// Created by [`ChainBuilder::build`]. Write data into the entry buffer
+/// with [`write_all`](Self::write_all),
+/// or use [`buf_mut`](Self::buf_mut) for zero-copy direct access.
+/// Then submit via [`VirtqProducer::submit`].
+///
+/// # Examples
+///
+/// ```ignore
+/// let mut se = producer.chain().entry(64).completion(128).build()?;
+/// se.write_all(b"header")?;
+/// se.write_all(b" body")?;
+/// let tok = producer.submit(se)?;
+///
+/// // Zero-copy direct access
+/// let mut se = producer.chain().entry(128).build()?;
+/// let buf = se.buf_mut()?;
+/// let n = serialize_into(buf);
+/// se.set_written(n)?;
+/// let tok = producer.submit(se)?;
+/// ```
+///
+/// If dropped without submitting, allocated buffers are returned to the pool.
+#[must_use = "dropping without submitting deallocates the buffers"]
+pub struct SendEntry<M, P: BufferProvider> {
+    mem: M,
+    pool: P,
+    written: usize,
+    inflight: Option<Inflight>,
+}
+
+impl<M: MemOps, P: BufferProvider> SendEntry<M, P> {
+    fn entry(&self) -> Result<Allocation, VirtqError> {
+        self.inflight
+            .as_ref()
+            .and_then(|i| i.entry())
+            .ok_or(VirtqError::NoReadableBuffer)
+    }
+
+    /// Total entry buffer capacity in bytes.
+    ///
+    /// Returns 0 when there are no entry buffers.
+    pub fn capacity(&self) -> usize {
+        self.inflight
+            .as_ref()
+            .and_then(|i| i.entry())
+            .map_or(0, |a| a.len)
+    }
+
+    /// Number of bytes written so far via [`write_all`](Self::write_all)
+    /// or [`set_written`](Self::set_written).
+    pub fn written(&self) -> usize {
+        self.written
+    }
+
+    /// Set the write cursor to an explicit byte count.
+    ///
+    /// Use this after [`buf_mut`](Self::buf_mut) where you wrote directly
+    /// into the buffer. The value tells the consumer how many bytes of
+    /// the entry buffer are valid.
+    ///
+    /// # Errors
+    ///
+    /// - [`VirtqError::EntryTooLarge`] - `written` exceeds entry buffer capacity
+    pub fn set_written(&mut self, written: usize) -> Result<(), VirtqError> {
+        if written > self.capacity() {
+            return Err(VirtqError::EntryTooLarge);
+        }
+
+        self.written = written;
+        Ok(())
+    }
+
+    /// Remaining writable capacity in the entry buffer.
+    pub fn remaining(&self) -> usize {
+        self.capacity() - self.written
+    }
+
+    /// Write the entire buffer into the entry.
+    ///
+    /// Appends at the current write position. Uses [`MemOps::write`]
+    /// (volatile on host side).
+    ///
+    /// # Errors
+    ///
+    /// - [`VirtqError::EntryTooLarge`] - buf exceeds remaining capacity
+    /// - [`VirtqError::NoReadableBuffer`] - no entry buffer allocated
+    /// - [`VirtqError::MemoryWriteError`] - underlying write failed
+    pub fn write_all(&mut self, buf: &[u8]) -> Result<(), VirtqError> {
+        let alloc = self.entry()?;
+
+        if buf.len() > self.remaining() {
+            return Err(VirtqError::EntryTooLarge);
+        }
+
+        let addr = alloc.addr + self.written as u64;
+        self.mem
+            .write(addr, buf)
+            .map_err(|_| VirtqError::MemoryWriteError)?;
+
+        self.written += buf.len();
+        Ok(())
+    }
+
+    /// Zero-copy access to the full entry buffer in shared memory.
+    ///
+    /// Returns `&mut [u8]` pointing directly into the allocated buffer.
+    /// This is safe on the guest side (producer). After writing, call
+    /// [`set_written`](Self::set_written) to record how many bytes are valid.
+    ///
+    /// **Note**: This bypasses the write cursor. Use either `buf_mut()` +
+    /// `set_written(n)` or the `write_all` method, not both.
+    ///
+    /// # Errors
+    ///
+    /// - [`VirtqError::NoReadableBuffer`] - no entry buffer allocated
+    /// - [`VirtqError::MemoryWriteError`] - failed to access shared memory
+    pub fn buf_mut(&mut self) -> Result<&mut [u8], VirtqError> {
+        let alloc = self.entry()?;
+        unsafe {
+            self.mem
+                .as_mut_slice(alloc.addr, alloc.len)
+                .map_err(|_| VirtqError::MemoryWriteError)
+        }
+    }
+}
+
+impl<M, P: BufferProvider> Drop for SendEntry<M, P> {
+    fn drop(&mut self) {
+        let inf = match self.inflight.take() {
+            Some(i) => i,
+            None => return, // already submitted
+        };
+        if let Some(a) = inf.entry() {
+            let _ = self.pool.dealloc(a);
+        }
+        if let Some(a) = inf.completion() {
+            let _ = self.pool.dealloc(a);
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::virtq::ring::tests::{OwnedRing, TestMem, make_consumer, make_producer, make_ring};
+    use crate::virtq::test_utils::*;
+
+    type RecycleProducer = VirtqProducer<TestMem, TestNotifier, RecyclePool>;
+
+    const SLOT_SIZE: usize = 4096;
+
+    fn make_recycle_producer(ring: &OwnedRing, slot_count: usize) -> RecycleProducer {
+        let layout = ring.layout();
+        let mem = ring.mem();
+        let pool = make_pool(ring, slot_count);
+        let notifier = TestNotifier::new();
+
+        VirtqProducer::new(layout, mem, notifier, pool)
+    }
+
+    fn make_pool(ring: &OwnedRing, slot_count: usize) -> RecyclePool {
+        let mem = ring.mem();
+        let pool_base = mem.base_addr() + Layout::query_size(ring.len()) as u64 + 0x100;
+        RecyclePool::new(pool_base, slot_count * SLOT_SIZE, SLOT_SIZE).unwrap()
+    }
+
+    #[derive(Clone)]
+    struct NoDirectSliceMem(TestMem);
+
+    // SAFETY: Delegates all non-slice memory operations to TestMem. Direct
+    // slices are intentionally unsupported to exercise producer error handling.
+    unsafe impl MemOps for NoDirectSliceMem {
+        type Error = ();
+
+        fn read(&self, addr: u64, dst: &mut [u8]) -> Result<(), Self::Error> {
+            self.0.read(addr, dst).map_err(|e| match e {})
+        }
+
+        fn write(&self, addr: u64, src: &[u8]) -> Result<(), Self::Error> {
+            self.0.write(addr, src).map_err(|e| match e {})
+        }
+
+        fn load_acquire(&self, addr: u64) -> Result<u16, Self::Error> {
+            self.0.load_acquire(addr).map_err(|e| match e {})
+        }
+
+        fn store_release(&self, addr: u64, val: u16) -> Result<(), Self::Error> {
+            self.0.store_release(addr, val).map_err(|e| match e {})
+        }
+
+        unsafe fn as_slice(&self, _addr: u64, _len: usize) -> Result<&[u8], Self::Error> {
+            Err(())
+        }
+
+        unsafe fn as_mut_slice(&self, _addr: u64, _len: usize) -> Result<&mut [u8], Self::Error> {
+            Err(())
+        }
+    }
+
+    #[test]
+    fn test_chain_readwrite_build() {
+        let ring = make_ring(16);
+        let (producer, _consumer, _notifier) = make_test_producer(&ring);
+
+        let se = producer.chain().entry(64).completion(128).build().unwrap();
+        assert_eq!(se.capacity(), 64);
+        assert_eq!(se.written(), 0);
+        assert_eq!(se.remaining(), 64);
+    }
+
+    #[test]
+    fn test_chain_entry_only_build() {
+        let ring = make_ring(16);
+        let (producer, _consumer, _notifier) = make_test_producer(&ring);
+
+        let se = producer.chain().entry(32).build().unwrap();
+        assert_eq!(se.capacity(), 32);
+    }
+
+    #[test]
+    fn test_chain_completion_only_build() {
+        let ring = make_ring(16);
+        let (producer, _consumer, _notifier) = make_test_producer(&ring);
+
+        let se = producer.chain().completion(64).build().unwrap();
+        assert_eq!(se.capacity(), 0);
+    }
+
+    #[test]
+    fn test_chain_empty_build_fails() {
+        let ring = make_ring(16);
+        let (producer, _consumer, _notifier) = make_test_producer(&ring);
+
+        let result = producer.chain().build();
+        assert!(matches!(result, Err(VirtqError::InvalidState)));
+    }
+
+    #[test]
+    fn test_send_entry_write_all_and_submit() {
+        let ring = make_ring(16);
+        let (mut producer, mut consumer, _notifier) = make_test_producer(&ring);
+
+        let mut se = producer.chain().entry(64).completion(128).build().unwrap();
+
+        se.write_all(b"hello").unwrap();
+        se.write_all(b" world").unwrap();
+        assert_eq!(se.written(), 11);
+        assert_eq!(se.remaining(), 53);
+        let tok = producer.submit(se).unwrap();
+
+        let (entry, completion) = consumer.poll(1024).unwrap().unwrap();
+        assert_eq!(entry.token(), tok);
+        assert_eq!(entry.data().as_ref(), b"hello world");
+        consumer.complete(completion).unwrap();
+    }
+
+    #[test]
+    fn test_send_entry_buf_mut() {
+        let ring = make_ring(16);
+        let (mut producer, mut consumer, _notifier) = make_test_producer(&ring);
+
+        let mut se = producer.chain().entry(64).completion(128).build().unwrap();
+        let buf = se.buf_mut().unwrap();
+        assert_eq!(buf.len(), 64);
+        buf[..5].copy_from_slice(b"hello");
+        se.set_written(5).unwrap();
+        let _tok = producer.submit(se).unwrap();
+
+        let (entry, completion) = consumer.poll(1024).unwrap().unwrap();
+        assert_eq!(entry.data().as_ref(), b"hello");
+        consumer.complete(completion).unwrap();
+    }
+
+    #[test]
+    fn test_send_entry_write_too_large() {
+        let ring = make_ring(16);
+        let (producer, _consumer, _notifier) = make_test_producer(&ring);
+
+        let mut se = producer.chain().entry(4).build().unwrap();
+        let err = se.write_all(b"too long").unwrap_err();
+        assert!(matches!(err, VirtqError::EntryTooLarge));
+    }
+
+    #[test]
+    fn test_writeonly_has_no_entry_buffer() {
+        let ring = make_ring(16);
+        let (producer, _consumer, _notifier) = make_test_producer(&ring);
+
+        let mut se =
producer.chain().completion(32).build().unwrap(); + let err = se.write_all(b"data").unwrap_err(); + assert!(matches!(err, VirtqError::NoReadableBuffer)); + } + + #[test] + fn test_drop_chain_builder_deallocs() { + let ring = make_ring(16); + let (mut producer, _consumer, _notifier) = make_test_producer(&ring); + + { + let _builder = producer.chain().entry(64).completion(128); + // dropped without build + } + + // Ring should still be fully usable + let se = producer.chain().entry(64).completion(128).build().unwrap(); + let tok = producer.submit(se).unwrap(); + assert!(tok.1 < 16); + } + + #[test] + fn test_drop_send_entry_deallocs() { + let ring = make_ring(16); + let (mut producer, _consumer, _notifier) = make_test_producer(&ring); + + { + let _se = producer.chain().entry(64).completion(128).build().unwrap(); + // dropped without submit + } + + // Ring should still be fully usable + let se = producer.chain().entry(64).completion(128).build().unwrap(); + let tok = producer.submit(se).unwrap(); + assert!(tok.1 < 16); + } + + #[test] + fn test_submit_notifies() { + let ring = make_ring(16); + let (mut producer, _consumer, notifier) = make_test_producer(&ring); + + let initial_count = notifier.notification_count(); + + let mut se = producer.chain().entry(64).completion(128).build().unwrap(); + se.write_all(b"hello").unwrap(); + producer.submit(se).unwrap(); + + assert!(notifier.notification_count() > initial_count); + } + + #[test] + fn test_submit_read_only_notifies_by_default() { + let ring = make_ring(16); + let (mut producer, _consumer, notifier) = make_test_producer(&ring); + + let initial_count = notifier.notification_count(); + + let mut se = producer.chain().entry(64).build().unwrap(); + se.write_all(b"fire-and-forget").unwrap(); + producer.submit(se).unwrap(); + + assert!(notifier.notification_count() > initial_count); + } + + #[test] + fn test_submit_write_only_notifies_by_default() { + let ring = make_ring(16); + let (mut producer, _consumer, notifier) = make_test_producer(&ring); + + let initial_count = notifier.notification_count(); + + let se = producer.chain().completion(128).build().unwrap(); + producer.submit(se).unwrap(); + + assert!(notifier.notification_count() > initial_count); + } + + #[test] + fn test_batch_notifies_once_on_finish() { + let ring = make_ring(16); + let (mut producer, mut consumer, notifier) = make_test_producer(&ring); + + let initial_count = notifier.notification_count(); + + let mut batch = producer.batch(); + + let mut first = batch.chain().entry(64).build().unwrap(); + first.write_all(b"first").unwrap(); + batch.submit(first).unwrap(); + + let mut second = batch.chain().entry(64).build().unwrap(); + second.write_all(b"second").unwrap(); + batch.submit(second).unwrap(); + + assert_eq!(notifier.notification_count(), initial_count); + assert!(batch.finish().unwrap()); + assert_eq!(notifier.notification_count(), initial_count + 1); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.data().as_ref(), b"first"); + consumer.complete(completion).unwrap(); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.data().as_ref(), b"second"); + consumer.complete(completion).unwrap(); + } + + #[test] + fn test_batch_finish_notifies_from_batch_start_cursor() { + let ring = make_ring(16); + let (mut producer, mut consumer, notifier) = make_test_producer(&ring); + + let cursor = consumer.avail_cursor(); + consumer + .set_avail_suppression(SuppressionKind::Descriptor(cursor)) + .unwrap(); + + let mut 
batch = producer.batch(); + + let mut first = batch.chain().entry(64).build().unwrap(); + first.write_all(b"first").unwrap(); + batch.submit(first).unwrap(); + assert_eq!(notifier.notification_count(), 0); + + let mut second = batch.chain().entry(64).completion(64).build().unwrap(); + second.write_all(b"second").unwrap(); + batch.submit(second).unwrap(); + + assert!(batch.finish().unwrap()); + + assert_eq!(notifier.notification_count(), 1); + } + + #[test] + fn test_empty_batch_finish_does_not_notify() { + let ring = make_ring(16); + let (mut producer, _consumer, notifier) = make_test_producer(&ring); + + let batch = producer.batch(); + assert!(!batch.finish().unwrap()); + assert_eq!(notifier.notification_count(), 0); + } + + #[test] + fn test_set_written_too_large() { + let ring = make_ring(16); + let (producer, _consumer, _notifier) = make_test_producer(&ring); + + let mut se = producer.chain().entry(32).completion(64).build().unwrap(); + let err = se.set_written(64).unwrap_err(); + assert!(matches!(err, VirtqError::EntryTooLarge)); + } + + #[test] + fn test_write_only_round_trip() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let se = producer.chain().completion(32).build().unwrap(); + let token = producer.submit(se).unwrap(); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.token(), token); + assert!(entry.data().is_empty()); + + if let SendCompletion::Writable(mut wc) = completion { + wc.write_all(b"filled-by-consumer").unwrap(); + consumer.complete(wc.into()).unwrap(); + } else { + panic!("expected Writable"); + } + + let cqe = producer.poll().unwrap().unwrap(); + assert_eq!(cqe.token, token); + assert_eq!(cqe.data.len(), b"filled-by-consumer".len()); + assert_eq!(&cqe.data[..], b"filled-by-consumer"); + } + + #[test] + fn test_read_only_round_trip() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let mut se = producer.chain().entry(32).build().unwrap(); + se.write_all(b"fire-and-forget").unwrap(); + let token = producer.submit(se).unwrap(); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.token(), token); + assert_eq!(entry.data().as_ref(), b"fire-and-forget"); + assert!(matches!(completion, SendCompletion::Ack(_))); + consumer.complete(completion).unwrap(); + + let cqe = producer.poll().unwrap().unwrap(); + assert_eq!(cqe.token, token); + assert_eq!(cqe.data.len(), 0); + assert!(cqe.data.is_empty()); + } + + #[test] + fn test_readwrite_round_trip() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let mut se = producer.chain().entry(64).completion(128).build().unwrap(); + se.write_all(b"request data").unwrap(); + let token = producer.submit(se).unwrap(); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.data().as_ref(), b"request data"); + if let SendCompletion::Writable(mut wc) = completion { + wc.write_all(b"response data").unwrap(); + consumer.complete(wc.into()).unwrap(); + } else { + panic!("expected Writable"); + } + + let cqe = producer.poll().unwrap().unwrap(); + assert_eq!(cqe.token, token); + assert_eq!(&cqe.data[..], b"response data"); + } + + #[test] + fn test_poll_completion_requires_direct_slice() { + let ring = make_ring(16); + let layout = ring.layout(); + let test_mem = ring.mem(); + let pool_base = test_mem.base_addr() + Layout::query_size(ring.len()) as u64 + 0x100; + 
let pool = TestPool::new(pool_base, 0x8000); + let notifier = TestNotifier::new(); + let mem = NoDirectSliceMem(test_mem); + let mut producer = VirtqProducer::new(layout, mem.clone(), notifier.clone(), pool); + let mut consumer = VirtqConsumer::new(layout, mem, notifier); + + let mut se = producer.chain().entry(64).completion(128).build().unwrap(); + se.write_all(b"request data").unwrap(); + producer.submit(se).unwrap(); + + let (_entry, completion) = consumer.poll(1024).unwrap().unwrap(); + if let SendCompletion::Writable(mut wc) = completion { + wc.write_all(b"response data").unwrap(); + consumer.complete(wc.into()).unwrap(); + } else { + panic!("expected Writable"); + } + + assert!(matches!(producer.poll(), Err(VirtqError::MemoryReadError))); + } + + #[test] + fn test_virtq_producer_reset() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + // Submit and complete a round trip + let mut se = producer.chain().entry(32).completion(64).build().unwrap(); + se.write_all(b"hello").unwrap(); + producer.submit(se).unwrap(); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.data().as_ref(), b"hello"); + consumer.complete(completion).unwrap(); + let _ = producer.poll().unwrap().unwrap(); + + // Now reset + producer.reset(); + + // All inflight slots should be None + assert!(producer.inflight.iter().all(|s| s.is_none())); + // Ring state should be back to initial + assert_eq!(producer.inner.num_free(), producer.inner.len()); + } + + #[test] + fn test_virtq_producer_reset_clears_inflight() { + let ring = make_ring(16); + let (mut producer, _consumer, _notifier) = make_test_producer(&ring); + + // Submit without completing + let se = producer.chain().completion(64).build().unwrap(); + producer.submit(se).unwrap(); + + assert!(producer.inflight.iter().any(|s| s.is_some())); + + producer.reset(); + + assert!(producer.inflight.iter().all(|s| s.is_none())); + assert_eq!(producer.inner.num_free(), producer.inner.len()); + } + + #[test] + fn test_restore_from_ring_requires_full_prefill() { + let ring = make_ring(8); + let mut producer = make_recycle_producer(&ring, 8); + + // Ring has no prefilled descriptors - restore should fail + // because IDs read from zeroed memory will all be 0 (duplicate) + assert!(producer.restore_from_ring(make_pool(&ring, 8)).is_err()); + } + + #[test] + fn test_restore_from_ring_partial_prefill_fails() { + let ring = make_ring(8); + let producer = make_recycle_producer(&ring, 8); + let pool_base = producer.pool.base_addr(); + + // Simulate host prefill: write only one descriptor + let mut writer = make_producer(&ring); + writer + .submit_one(pool_base, SLOT_SIZE as u32, true) + .unwrap(); + + // Restore should fail because only 1 of 8 positions has a + // valid unique ID - remaining positions have id=0 (duplicate) + let mut restored = make_recycle_producer(&ring, 8); + assert!(restored.restore_from_ring(make_pool(&ring, 8)).is_err()); + } + + #[test] + fn test_restore_from_ring_full_prefill() { + let depth = 8usize; + let ring = make_ring(depth); + let producer = make_recycle_producer(&ring, depth); + let pool_base = producer.pool.base_addr(); + + // Simulate host prefill: write all descriptors + let mut writer = make_producer(&ring); + for i in 0..depth { + let addr = pool_base + (i * SLOT_SIZE) as u64; + writer.submit_one(addr, SLOT_SIZE as u32, true).unwrap(); + } + + let mut restored = make_recycle_producer(&ring, depth); + restored.restore_from_ring(make_pool(&ring, 
depth)).unwrap(); + + // All inflight slots should be populated + let inflight_count = restored.inflight.iter().filter(|s| s.is_some()).count(); + assert_eq!(inflight_count, depth); + + // Pool should be fully allocated + assert_eq!(restored.pool.num_free(), 0); + } + + #[test] + fn test_restore_from_ring_forward_order() { + let depth = 4usize; + let ring = make_ring(depth); + let producer = make_recycle_producer(&ring, depth); + let pool_base = producer.pool.base_addr(); + + // Forward order prefill + let mut writer = make_producer(&ring); + for i in 0..depth { + writer + .submit_one(pool_base + (i * SLOT_SIZE) as u64, SLOT_SIZE as u32, true) + .unwrap(); + } + + let mut restored = make_recycle_producer(&ring, depth); + restored.restore_from_ring(make_pool(&ring, depth)).unwrap(); + } + + #[test] + fn test_restore_from_ring_reverse_order() { + let depth = 4usize; + let ring = make_ring(depth); + let producer = make_recycle_producer(&ring, depth); + let pool_base = producer.pool.base_addr(); + + // Reverse order prefill (current host behavior) + let mut writer = make_producer(&ring); + for i in (0..depth).rev() { + writer + .submit_one(pool_base + (i * SLOT_SIZE) as u64, SLOT_SIZE as u32, true) + .unwrap(); + } + + let mut restored = make_recycle_producer(&ring, depth); + restored.restore_from_ring(make_pool(&ring, depth)).unwrap(); + } + + #[test] + fn test_restore_from_ring_pool_state_correct() { + let depth = 8usize; + let ring = make_ring(depth); + let producer = make_recycle_producer(&ring, depth); + let pool_base = producer.pool.base_addr(); + + // Full prefill + let mut writer = make_producer(&ring); + for i in 0..depth { + writer + .submit_one(pool_base + (i * SLOT_SIZE) as u64, SLOT_SIZE as u32, true) + .unwrap(); + } + + let mut restored = make_recycle_producer(&ring, depth); + restored.restore_from_ring(make_pool(&ring, depth)).unwrap(); + // All slots are allocated after full-prefill restore + assert_eq!(restored.pool.num_free(), 0); + } + + #[test] + fn test_restore_from_ring_idempotent() { + let depth = 4usize; + let ring = make_ring(depth); + let producer = make_recycle_producer(&ring, depth); + let pool_base = producer.pool.base_addr(); + + let mut writer = make_producer(&ring); + for i in 0..depth { + writer + .submit_one(pool_base + (i * SLOT_SIZE) as u64, SLOT_SIZE as u32, true) + .unwrap(); + } + + let mut restored = make_recycle_producer(&ring, depth); + restored.restore_from_ring(make_pool(&ring, depth)).unwrap(); + restored.restore_from_ring(make_pool(&ring, depth)).unwrap(); + assert_eq!(restored.pool.num_free(), 0); + } + + #[test] + fn test_restore_from_ring_then_poll_used() { + let depth = 4usize; + let ring = make_ring(depth); + let producer = make_recycle_producer(&ring, depth); + let pool_base = producer.pool.base_addr(); + + // Simulate host prefill + let mut writer = make_producer(&ring); + for i in 0..depth { + writer + .submit_one(pool_base + (i * SLOT_SIZE) as u64, SLOT_SIZE as u32, true) + .unwrap(); + } + + // Restore producer and use ring-level consumer to complete one entry + let mut restored = make_recycle_producer(&ring, depth); + restored.restore_from_ring(make_pool(&ring, depth)).unwrap(); + + // Ring-level consumer reads available descriptors + let mut consumer = make_consumer(&ring); + let (id, chain) = consumer.poll_available().unwrap(); + let writable = chain.writables(); + assert_eq!(writable.len(), 1); + + // Write some data into the writable buffer + let payload = b"test payload"; + consumer.mem().write(writable[0].addr, payload).unwrap(); 
+ consumer.submit_used(id, payload.len() as u32).unwrap(); + + // Producer polls for the completion + let cqe = restored.poll().unwrap().unwrap(); + assert_eq!(&cqe.data[..payload.len()], payload); + + // Pool slot should be returned after data is dropped + drop(cqe); + assert_eq!(restored.pool.num_free(), 1); + } +} diff --git a/src/hyperlight_common/src/virtq/ring.rs b/src/hyperlight_common/src/virtq/ring.rs index 66791d2f0..391a58a0b 100644 --- a/src/hyperlight_common/src/virtq/ring.rs +++ b/src/hyperlight_common/src/virtq/ring.rs @@ -406,7 +406,7 @@ impl RingCursor { /// Advance by n positions using modular arithmetic. #[inline] - fn advance_by(&mut self, n: u16) { + pub(crate) fn advance_by(&mut self, n: u16) { debug_assert!(self.head.checked_add(n).is_some()); let new = self.head + n; let wraps = new / self.size; @@ -910,7 +910,6 @@ impl RingProducer { self.id_num.iter_mut().for_each(|n| *n = 0); self.event_flags_shadow = EventFlags::ENABLE; } - /// Reset the ring to the "N slots submitted, none completed" state. /// /// `ids` contains the descriptor IDs that are in-flight. @@ -918,7 +917,7 @@ impl RingProducer { pub fn reset_prefilled(&mut self, ids: &[u16]) { let size = self.desc_table.len(); let count = ids.len(); - assert!(count <= size); + debug_assert!(count <= size); let wrapped = count >= size; self.avail_cursor.head = if wrapped { 0 } else { count as u16 }; @@ -929,15 +928,11 @@ impl RingProducer { self.id_num.iter_mut().for_each(|n| *n = 0); for &id in ids { - assert!((id as usize) < size); - assert_eq!(self.id_num[id as usize], 0); self.id_num[id as usize] = 1; } self.num_free = size - count; self.id_free.clear(); - self.id_free - .extend((0..size as u16).filter(|id| self.id_num[*id as usize] == 0)); } } @@ -3258,7 +3253,6 @@ pub(crate) mod tests { consumer.reset(); assert_eq!(consumer.num_inflight, 0); } - #[test] fn test_reset_prefilled_sets_cursors() { let ring = make_ring(8); @@ -3300,10 +3294,7 @@ pub(crate) mod tests { assert!(producer.used_cursor.wrap()); assert_eq!(producer.num_free, 4); - assert_eq!(producer.id_free.len(), 4); - for &id in &[0, 1, 2, 4] { - assert!(producer.id_free.contains(&id)); - } + assert!(producer.id_free.is_empty()); // Only the specified IDs are in-flight for &id in &[5, 6, 7, 3] { assert_eq!(producer.id_num[id as usize], 1); @@ -3313,19 +3304,6 @@ pub(crate) mod tests { } } - #[test] - fn test_reset_prefilled_partial_then_submit() { - let ring = make_ring(8); - let mut producer = make_producer(&ring); - producer.reset_prefilled(&[4, 5, 6, 7]); - - let id = producer.submit_one(0x8000, 128, false).unwrap(); - - assert!([0, 1, 2, 3].contains(&id)); - assert_eq!(producer.num_free, 3); - assert_eq!(producer.id_num[id as usize], 1); - } - #[test] fn test_reset_prefilled_then_poll_used() { let ring = make_ring(4); diff --git a/src/hyperlight_common/src/virtq/stream.rs b/src/hyperlight_common/src/virtq/stream.rs new file mode 100644 index 000000000..8dbc7da27 --- /dev/null +++ b/src/hyperlight_common/src/virtq/stream.rs @@ -0,0 +1,579 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Stream table and flow-control primitives for virtqueue byte streams. +//! +//! This module provides the `no_std`-friendly core types used on both +//! the guest and host sides: +//! +//! - [`StreamId`], [`StreamHandle`]: identification and cross-boundary +//! handle shape. +//! - [`StreamDirection`]: G2H vs H2G per-direction ID spaces. +//! - [`StreamTable`]: per-direction table of open streams, with a +//! generation counter (bumped on sandbox reset) that gates stale +//! messages. +//! - Credit accounting: [`WriterCredit`] / [`ReaderCredit`] wrappers +//! plus signed-arithmetic helpers that mirror virtio-vsock's +//! wrap-safe math. +//! +//! Only the data model and arithmetic live here. Wiring to the +//! producer/consumer virtqueues lives in the guest and host crates +//! (Phase 2/3). + +use alloc::vec::Vec; + +use crate::virtq::msg::{STREAM_GEN_MASK, STREAM_ID_MAX}; + +/// Default initial credit (bytes) a newly-created local stream end +/// advertises to its peer. One BufferPool slot's worth for MVP. +pub const STREAM_INITIAL_CREDIT: u32 = 4096; + +/// Direction a stream flows over the virtqueue pair. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum StreamDirection { + /// Guest produces, host consumes. + Guest2Host, + /// Host produces, guest consumes. + Host2Guest, +} + +/// 12-bit stream id within a given direction and generation. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct StreamId(u16); + +impl StreamId { + /// Construct from a raw 12-bit value. Returns `None` if out of range. + pub const fn from_u16(id: u16) -> Option<Self> { + if id > STREAM_ID_MAX { + None + } else { + Some(Self(id)) + } + } + + pub const fn as_u16(self) -> u16 { + self.0 + } +} + +/// Opaque handle identifying a stream across the VM boundary. +/// +/// Carries the minimum information needed to route messages and to +/// reject stale traffic after sandbox reset. The `initial_credit` is +/// the buffer capacity the receiver advertises at handle-transfer time +/// (bootstraps the peer's `buf_alloc`). `None` means "use the default". +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct StreamHandle { + pub direction: StreamDirection, + pub generation: u8, + pub stream_id: StreamId, + pub initial_credit: Option<u32>, +} + +impl StreamHandle { + pub const fn new( + direction: StreamDirection, + generation: u8, + stream_id: StreamId, + initial_credit: Option<u32>, + ) -> Self { + Self { + direction, + generation, + stream_id, + initial_credit, + } + } +} + +/// Writer-side credit accounting. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub struct WriterCredit { + /// Total bytes this writer has transmitted. + pub tx_cnt: u32, + /// Last `fwd_cnt` advertised by the reader. + pub peer_fwd_cnt: u32, + /// Last `buf_alloc` advertised by the reader. + pub peer_buf_alloc: u32, +} + +impl WriterCredit { + /// Initial writer state. `peer_buf_alloc` should be seeded to the + /// reader's bootstrap credit. + pub const fn new(peer_buf_alloc: u32) -> Self { + Self { + tx_cnt: 0, + peer_fwd_cnt: 0, + peer_buf_alloc, + } + } + + /// Bytes the writer may transmit right now without overrunning the + /// reader's advertised buffer. Uses signed 64-bit arithmetic so + /// that a peer shrinking its buffer below the in-flight count + /// returns 0 instead of underflowing.
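+ ///
+ /// A minimal sketch of the accounting (mirrors the unit tests at
+ /// the bottom of this module):
+ ///
+ /// ```ignore
+ /// let mut w = WriterCredit::new(100); // reader advertised 100 bytes
+ /// assert_eq!(w.available(), 100);
+ /// w.record_sent(60); // 60 bytes now in flight
+ /// assert_eq!(w.available(), 40);
+ /// w.observe_peer(50, 100); // reader reports 50 bytes consumed
+ /// assert_eq!(w.available(), 90);
+ /// ```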
+ pub fn available(&self) -> u32 { + available_credit(self.peer_buf_alloc, self.tx_cnt, self.peer_fwd_cnt) + } + + /// Record that `n` bytes have been placed on the wire. The wrapping + /// add mirrors the wrap-safe counter math; the writer is expected to + /// have clamped `n` to `available()` beforehand. + pub fn record_sent(&mut self, n: u32) { + self.tx_cnt = self.tx_cnt.wrapping_add(n); + } + + /// Absorb a preamble observed from the reader. + /// + /// Only monotonically-newer counters are adopted: `fwd_cnt` uses + /// signed wrap-safe comparison; `buf_alloc` is whatever the reader + /// last said (it can grow or shrink). + pub fn observe_peer(&mut self, peer_fwd_cnt: u32, peer_buf_alloc: u32) { + if wrap_gt(peer_fwd_cnt, self.peer_fwd_cnt) { + self.peer_fwd_cnt = peer_fwd_cnt; + } + self.peer_buf_alloc = peer_buf_alloc; + } +} + +/// Reader-side credit accounting. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub struct ReaderCredit { + /// Total bytes this reader has delivered to the consumer. + pub fwd_cnt: u32, + /// Buffer capacity advertised to the writer. To limit credit-update + /// traffic, it is re-advertised only once a threshold of space has + /// been freed. + pub buf_alloc: u32, +} + +impl ReaderCredit { + pub const fn new(buf_alloc: u32) -> Self { + Self { + fwd_cnt: 0, + buf_alloc, + } + } + + /// Record that `n` bytes have been consumed by the reader. + pub fn record_consumed(&mut self, n: u32) { + self.fwd_cnt = self.fwd_cnt.wrapping_add(n); + } +} + +/// Signed wrap-safe "strictly greater than". Returns `true` if `a` is +/// a newer counter value than `b`, tolerating u32 wraparound. +pub fn wrap_gt(a: u32, b: u32) -> bool { + (a.wrapping_sub(b) as i32) > 0 +} + +/// Available credit (in bytes) given the reader's advertised buffer +/// and the writer's current tx/peer-fwd counters. Implements the +/// vsock-equivalent calculation with signed arithmetic to survive +/// peer_buf_alloc shrinking. +pub fn available_credit(peer_buf_alloc: u32, tx_cnt: u32, peer_fwd_cnt: u32) -> u32 { + let in_flight = tx_cnt.wrapping_sub(peer_fwd_cnt) as i64; + let avail = peer_buf_alloc as i64 - in_flight; + if avail <= 0 { + 0 + } else if avail > u32::MAX as i64 { + u32::MAX + } else { + avail as u32 + } +} + +/// Lifecycle state of a single stream entry. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamLifecycle { + /// Open - data and control messages flow normally. + Open, + /// Writer has sent StreamEnd; reader may still drain buffered + /// chunks, but no new data will arrive. + WriterClosed, + /// Reader has sent Cancel; writer must stop producing. + Cancelled, + /// Fully closed; entry will be reaped on the next sweep. + Closed, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum EndRole { + /// This side is the writer for the stream (peer reads). + Writer, + /// This side is the reader for the stream (peer writes). + Reader, +} + +#[derive(Debug)] +struct StreamEntry { + id: StreamId, + #[allow(dead_code)] // Used in later phases for role-based routing. + role: EndRole, + writer: WriterCredit, + reader: ReaderCredit, + state: StreamLifecycle, +} + +/// Per-direction stream table. +/// +/// - Allocates [`StreamId`]s monotonically (no ID reuse within a +/// generation) to avoid races with late Cancel/CreditUpdate messages. +/// - Tracks a `generation` counter bumped on sandbox reset; inbound +/// messages carrying a mismatched generation are considered stale +/// and dropped by the routing layer.
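+///
+/// A sketch of a typical writer-end lifecycle (hedged; the real call
+/// sites live in the guest/host wiring added in later phases):
+///
+/// ```ignore
+/// let mut table = StreamTable::new(StreamDirection::Guest2Host);
+/// let id = table.open_writer(STREAM_INITIAL_CREDIT, STREAM_INITIAL_CREDIT)?;
+/// table.record_sent(id, 100)?; // 100 bytes put on the wire
+/// table.mark_writer_closed(id)?; // StreamEnd sent
+/// table.close(id)?; // entry reaped on the next sweep
+/// ```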
+#[derive(Debug)] +pub struct StreamTable { + direction: StreamDirection, + generation: u8, + next_id: u16, + entries: Vec<StreamEntry>, +} + +/// Errors produced by [`StreamTable`] mutations. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamTableError { + /// All stream IDs for the current generation have been used. + IdSpaceExhausted, + /// Referenced stream id is not present in the table. + Unknown(StreamId), +} + +impl StreamTable { + pub fn new(direction: StreamDirection) -> Self { + Self { + direction, + generation: 0, + next_id: 0, + entries: Vec::new(), + } + } + + pub fn direction(&self) -> StreamDirection { + self.direction + } + + pub fn generation(&self) -> u8 { + self.generation + } + + pub fn len(&self) -> usize { + self.entries.iter().filter(|e| !e.is_reaped()).count() + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Bump the generation and drop all open streams. Called on sandbox + /// reset/restore. Returns the new generation. + pub fn reset(&mut self) -> u8 { + self.generation = (self.generation.wrapping_add(1)) & (STREAM_GEN_MASK as u8); + self.next_id = 0; + self.entries.clear(); + self.generation + } + + /// Allocate a new stream id and register it as a locally-owned end. + /// `role` indicates whether this side is the writer (peer reads) or + /// the reader (peer writes). + fn allocate( + &mut self, + role: EndRole, + local_buf_alloc: u32, + peer_buf_alloc: u32, + ) -> Result<StreamId, StreamTableError> { + if self.next_id > STREAM_ID_MAX { + return Err(StreamTableError::IdSpaceExhausted); + } + let id = StreamId(self.next_id); + self.next_id += 1; + self.entries.push(StreamEntry { + id, + role, + writer: WriterCredit::new(peer_buf_alloc), + reader: ReaderCredit::new(local_buf_alloc), + state: StreamLifecycle::Open, + }); + Ok(id) + } + + /// Register a locally-owned writer end. Typically called by the + /// side that is producing into the stream. + pub fn open_writer( + &mut self, + local_buf_alloc: u32, + peer_buf_alloc: u32, + ) -> Result<StreamId, StreamTableError> { + self.allocate(EndRole::Writer, local_buf_alloc, peer_buf_alloc) + } + + /// Register a locally-owned reader end. + pub fn open_reader( + &mut self, + local_buf_alloc: u32, + peer_buf_alloc: u32, + ) -> Result<StreamId, StreamTableError> { + self.allocate(EndRole::Reader, local_buf_alloc, peer_buf_alloc) + } + + fn find_mut(&mut self, id: StreamId) -> Result<&mut StreamEntry, StreamTableError> { + self.entries + .iter_mut() + .find(|e| e.id == id && !e.is_reaped()) + .ok_or(StreamTableError::Unknown(id)) + } + + fn find(&self, id: StreamId) -> Result<&StreamEntry, StreamTableError> { + self.entries + .iter() + .find(|e| e.id == id && !e.is_reaped()) + .ok_or(StreamTableError::Unknown(id)) + } + + pub fn lifecycle(&self, id: StreamId) -> Result<StreamLifecycle, StreamTableError> { + self.find(id).map(|e| e.state) + } + + pub fn writer_credit(&self, id: StreamId) -> Result<WriterCredit, StreamTableError> { + self.find(id).map(|e| e.writer) + } + + pub fn reader_credit(&self, id: StreamId) -> Result<ReaderCredit, StreamTableError> { + self.find(id).map(|e| e.reader) + } + + /// Apply a preamble observed from the peer for this stream. Updates + /// writer-side peer counters (only fields the peer-as-reader + /// advertises are meaningful). + pub fn observe_peer( + &mut self, + id: StreamId, + peer_fwd_cnt: u32, + peer_buf_alloc: u32, + ) -> Result<(), StreamTableError> { + let entry = self.find_mut(id)?; + entry.writer.observe_peer(peer_fwd_cnt, peer_buf_alloc); + Ok(()) + } + + /// Record that the local writer end has put `n` bytes on the wire.
+ pub fn record_sent(&mut self, id: StreamId, n: u32) -> Result<(), StreamTableError> { + let entry = self.find_mut(id)?; + entry.writer.record_sent(n); + Ok(()) + } + + /// Record that the local reader end has consumed `n` bytes. + pub fn record_consumed(&mut self, id: StreamId, n: u32) -> Result<(), StreamTableError> { + let entry = self.find_mut(id)?; + entry.reader.record_consumed(n); + Ok(()) + } + + /// Mark the writer as closed (sent StreamEnd). + pub fn mark_writer_closed(&mut self, id: StreamId) -> Result<(), StreamTableError> { + let entry = self.find_mut(id)?; + entry.state = match entry.state { + StreamLifecycle::Open => StreamLifecycle::WriterClosed, + other => other, + }; + Ok(()) + } + + /// Mark the stream as cancelled. + pub fn mark_cancelled(&mut self, id: StreamId) -> Result<(), StreamTableError> { + let entry = self.find_mut(id)?; + entry.state = StreamLifecycle::Cancelled; + Ok(()) + } + + /// Fully close and reap the entry. + pub fn close(&mut self, id: StreamId) -> Result<(), StreamTableError> { + let entry = self.find_mut(id)?; + entry.state = StreamLifecycle::Closed; + Ok(()) + } + + /// Returns true if a message with the given generation should be + /// accepted for this table. + pub fn accepts_generation(&self, generation: u8) -> bool { + generation == self.generation + } +} + +impl StreamEntry { + fn is_reaped(&self) -> bool { + matches!(self.state, StreamLifecycle::Closed) + } + + #[cfg(test)] + fn role(&self) -> EndRole { + self.role + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn available_credit_basic() { + assert_eq!(available_credit(100, 0, 0), 100); + assert_eq!(available_credit(100, 30, 0), 70); + assert_eq!(available_credit(100, 100, 0), 0); + assert_eq!(available_credit(100, 100, 30), 30); + } + + #[test] + fn available_credit_peer_shrink() { + // peer shrinks buf_alloc below in-flight bytes + assert_eq!(available_credit(10, 50, 30), 0); + } + + #[test] + fn available_credit_wrap() { + // tx_cnt has wrapped past peer_fwd_cnt + let tx = 5u32; + let peer_fwd = u32::MAX - 10; + let in_flight = tx.wrapping_sub(peer_fwd); + assert_eq!(in_flight, 16); + assert_eq!(available_credit(100, tx, peer_fwd), 84); + } + + #[test] + fn wrap_gt_near_zero_and_max() { + assert!(wrap_gt(10, 5)); + assert!(!wrap_gt(5, 10)); + // 1 is "newer" than u32::MAX despite smaller absolute value + assert!(wrap_gt(1, u32::MAX)); + assert!(!wrap_gt(u32::MAX, 1)); + } + + #[test] + fn writer_observe_monotonic() { + let mut w = WriterCredit::new(100); + w.observe_peer(10, 200); + assert_eq!(w.peer_fwd_cnt, 10); + assert_eq!(w.peer_buf_alloc, 200); + // Older fwd_cnt is ignored + w.observe_peer(5, 150); + assert_eq!(w.peer_fwd_cnt, 10); + assert_eq!(w.peer_buf_alloc, 150); + } + + #[test] + fn writer_record_sent_and_available() { + let mut w = WriterCredit::new(100); + assert_eq!(w.available(), 100); + w.record_sent(60); + assert_eq!(w.available(), 40); + w.observe_peer(50, 100); + assert_eq!(w.available(), 90); + } + + #[test] + fn table_allocates_and_closes() { + let mut t = StreamTable::new(StreamDirection::Guest2Host); + assert!(t.is_empty()); + let a = t.open_writer(4096, 4096).unwrap(); + let b = t.open_reader(4096, 4096).unwrap(); + assert_ne!(a, b); + assert_eq!(t.len(), 2); + assert_eq!(t.lifecycle(a), Ok(StreamLifecycle::Open)); + t.mark_writer_closed(a).unwrap(); + assert_eq!(t.lifecycle(a), Ok(StreamLifecycle::WriterClosed)); + t.close(a).unwrap(); + assert_eq!(t.lifecycle(a), Err(StreamTableError::Unknown(a))); + assert_eq!(t.len(), 1); + } + + #[test] + 
fn table_ids_are_not_reused() { + let mut t = StreamTable::new(StreamDirection::Host2Guest); + let a = t.open_writer(4096, 4096).unwrap(); + t.close(a).unwrap(); + let b = t.open_writer(4096, 4096).unwrap(); + assert_ne!(a, b); + } + + #[test] + fn table_reset_bumps_generation_and_drops_entries() { + let mut t = StreamTable::new(StreamDirection::Guest2Host); + let _a = t.open_writer(4096, 4096).unwrap(); + let _b = t.open_reader(4096, 4096).unwrap(); + assert_eq!(t.generation(), 0); + let g = t.reset(); + assert_eq!(g, 1); + assert!(t.is_empty()); + let c = t.open_writer(4096, 4096).unwrap(); + assert_eq!(c, StreamId(0)); + assert!(t.accepts_generation(1)); + assert!(!t.accepts_generation(0)); + } + + #[test] + fn table_reset_wraps_generation() { + let mut t = StreamTable::new(StreamDirection::Guest2Host); + for _ in 0..16 { + t.reset(); + } + assert_eq!(t.generation(), 0); + } + + #[test] + fn table_id_space_exhaustion() { + let mut t = StreamTable::new(StreamDirection::Guest2Host); + // Drain to the end of the 12-bit space. + t.next_id = STREAM_ID_MAX; + let id = t.open_writer(4096, 4096).unwrap(); + assert_eq!(id.as_u16(), STREAM_ID_MAX); + let err = t.open_writer(4096, 4096).unwrap_err(); + assert_eq!(err, StreamTableError::IdSpaceExhausted); + } + + #[test] + fn observe_and_counters_flow() { + let mut t = StreamTable::new(StreamDirection::Guest2Host); + let id = t.open_writer(4096, 1024).unwrap(); + t.record_sent(id, 500).unwrap(); + let credit = t.writer_credit(id).unwrap(); + assert_eq!(credit.tx_cnt, 500); + assert_eq!(credit.available(), 524); + + t.observe_peer(id, 200, 2048).unwrap(); + let credit = t.writer_credit(id).unwrap(); + assert_eq!(credit.peer_fwd_cnt, 200); + assert_eq!(credit.peer_buf_alloc, 2048); + assert_eq!(credit.available(), 2048 - (500 - 200)); + } + + #[test] + fn unknown_ids_are_rejected() { + let mut t = StreamTable::new(StreamDirection::Guest2Host); + let fake = StreamId(7); + assert_eq!(t.lifecycle(fake), Err(StreamTableError::Unknown(fake))); + assert_eq!(t.record_sent(fake, 1), Err(StreamTableError::Unknown(fake))); + } + + #[test] + fn role_is_tracked() { + let mut t = StreamTable::new(StreamDirection::Guest2Host); + let a = t.open_writer(4096, 4096).unwrap(); + let b = t.open_reader(4096, 4096).unwrap(); + assert_eq!(t.find(a).unwrap().role(), EndRole::Writer); + assert_eq!(t.find(b).unwrap().role(), EndRole::Reader); + } +} diff --git a/src/hyperlight_guest/Cargo.toml b/src/hyperlight_guest/Cargo.toml index d9de514ae..3ab158a7e 100644 --- a/src/hyperlight_guest/Cargo.toml +++ b/src/hyperlight_guest/Cargo.toml @@ -15,6 +15,7 @@ Provides only the essential building blocks for interacting with the host enviro anyhow = { version = "1.0.102", default-features = false } serde_json = { version = "1.0", default-features = false, features = ["alloc"] } hyperlight-common = { workspace = true, default-features = false } +bytemuck = { version = "1.24", features = ["derive"] } flatbuffers = { version= "25.12.19", default-features = false } tracing = { version = "0.1.44", default-features = false, features = ["attributes"] } diff --git a/src/hyperlight_guest/src/error.rs b/src/hyperlight_guest/src/error.rs index 0a33bce79..62ca01bda 100644 --- a/src/hyperlight_guest/src/error.rs +++ b/src/hyperlight_guest/src/error.rs @@ -17,8 +17,10 @@ limitations under the License. 
use alloc::format; use alloc::string::{String, ToString as _}; -pub use hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode; +pub(crate) use hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode; +use hyperlight_common::flatbuffer_wrappers::guest_error::GuestError; use hyperlight_common::func::Error as FuncError; +use hyperlight_common::virtq::VirtqError; use {anyhow, serde_json}; pub type Result<T> = core::result::Result<T, HyperlightGuestError>; @@ -80,6 +82,24 @@ impl From for HyperlightGuestError { } } +impl From<VirtqError> for HyperlightGuestError { + fn from(e: VirtqError) -> Self { + Self { + kind: ErrorCode::GuestError, + message: format!("virtq: {e}"), + } + } +} + +impl From<GuestError> for HyperlightGuestError { + fn from(e: GuestError) -> Self { + Self { + kind: e.code, + message: e.message, + } + } +} + /// Extension trait to add context to `Option` and `Result` types in guest code, /// converting them to `Result`. /// @@ -171,10 +191,10 @@ impl GuestErrorContext for core::result::Result { #[macro_export] macro_rules! bail { ($ec:expr => $($msg:tt)*) => { - return ::core::result::Result::Err($crate::error::HyperlightGuestError::new($ec, ::alloc::format!($($msg)*))); + return ::core::result::Result::Err($crate::error::HyperlightGuestError::new($ec, ::alloc::format!($($msg)*))) }; ($($msg:tt)*) => { - $crate::bail!($crate::error::ErrorCode::GuestError => $($msg)*); + $crate::bail!($crate::error::ErrorCode::GuestError => $($msg)*) }; } diff --git a/src/hyperlight_guest/src/guest_handle/host_comm.rs b/src/hyperlight_guest/src/guest_handle/host_comm.rs index c72de8a3f..10b8e9a7a 100644 --- a/src/hyperlight_guest/src/guest_handle/host_comm.rs +++ b/src/hyperlight_guest/src/guest_handle/host_comm.rs @@ -18,21 +18,13 @@ use alloc::format; use alloc::string::ToString; use alloc::vec::Vec; -use flatbuffers::FlatBufferBuilder; -use hyperlight_common::flatbuffer_wrappers::function_call::{FunctionCall, FunctionCallType}; -use hyperlight_common::flatbuffer_wrappers::function_types::{ - FunctionCallResult, ParameterValue, ReturnType, ReturnValue, -}; use hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode; use hyperlight_common::flatbuffer_wrappers::guest_log_data::GuestLogData; use hyperlight_common::flatbuffer_wrappers::guest_log_level::LogLevel; -use hyperlight_common::flatbuffer_wrappers::util::estimate_flatbuffer_capacity; -use hyperlight_common::outb::OutBAction; use tracing::instrument; use super::handle::GuestHandle; use crate::error::{HyperlightGuestError, Result}; -use crate::exit::out32; impl GuestHandle { /// Get user memory region as bytes. @@ -59,99 +51,6 @@ } } - /// Get a return value from a host function call. - /// This usually requires a host function to be called first using - /// `call_host_function_internal`. - /// - /// When calling `call_host_function`, this function is called - /// internally to get the return value.
- #[instrument(skip_all, level = "Trace")] - pub fn get_host_return_value<T: TryFrom<ReturnValue>>(&self) -> Result<T> { - let inner = self - .try_pop_shared_input_data_into::<FunctionCallResult>() - .expect("Unable to deserialize a return value from host") - .into_inner(); - - match inner { - Ok(ret) => T::try_from(ret).map_err(|_| { - let expected = core::any::type_name::<T>(); - HyperlightGuestError::new( - ErrorCode::UnsupportedParameterType, - format!("Host return value could not be converted to expected {expected}",), - ) - }), - Err(e) => Err(HyperlightGuestError { - kind: e.code, - message: e.message, - }), - } - } - - pub fn get_host_return_raw(&self) -> Result<ReturnValue> { - let inner = self - .try_pop_shared_input_data_into::<FunctionCallResult>() - .expect("Unable to deserialize a return value from host") - .into_inner(); - - match inner { - Ok(ret) => Ok(ret), - Err(e) => Err(HyperlightGuestError { - kind: e.code, - message: e.message, - }), - } - } - - /// Call a host function without reading its return value from shared mem. - /// This is used by both the Rust and C APIs to reduce code duplication. - /// - /// Note: The function return value must be obtained by calling - /// `get_host_return_value`. - #[instrument(skip_all, level = "Trace")] - pub fn call_host_function_without_returning_result( - &self, - function_name: &str, - parameters: Option<Vec<ParameterValue>>, - return_type: ReturnType, - ) -> Result<()> { - let estimated_capacity = - estimate_flatbuffer_capacity(function_name, parameters.as_deref().unwrap_or(&[])); - - let host_function_call = FunctionCall::new( - function_name.to_string(), - parameters, - FunctionCallType::Host, - return_type, - ); - - let mut builder = FlatBufferBuilder::with_capacity(estimated_capacity); - - let host_function_call_buffer = host_function_call.encode(&mut builder); - self.push_shared_output_data(host_function_call_buffer)?; - - unsafe { - out32(OutBAction::CallFunction as u16, 0); - } - - Ok(()) - } - - /// Call a host function with the given parameters and return type. - /// This function serializes the function call and its parameters, - /// sends it to the host, and then retrieves the return value. - /// - /// The return value is deserialized into the specified type `T`. - #[instrument(skip_all, level = "Info")] - pub fn call_host_function<T: TryFrom<ReturnValue>>( - &self, - function_name: &str, - parameters: Option<Vec<ParameterValue>>, - return_type: ReturnType, - ) -> Result<T> { - self.call_host_function_without_returning_result(function_name, parameters, return_type)?; - self.get_host_return_value::<T>() - } - - /// Log a message with the specified log level, source, caller, source file, and line number.
pub fn log_message( &self, @@ -162,7 +61,7 @@ source_file: &str, line: u32, ) { - // Closure to send log message to host + // Closure to send log message to host via G2H virtqueue let _send_to_host = || { let guest_log_data = GuestLogData::new( message.to_string(), @@ -177,12 +76,10 @@ .try_into() .expect("Failed to convert GuestLogData to bytes"); - self.push_shared_output_data(&bytes) - .expect("Unable to push log data to shared output data"); - - unsafe { - out32(OutBAction::Log as u16, 0); - } + crate::virtq::with_context(|ctx| { + ctx.emit_log(&bytes) + .expect("Unable to send log data via virtq"); + }); }; #[cfg(all(feature = "trace_guest", target_arch = "x86_64"))] diff --git a/src/hyperlight_guest/src/guest_handle/io.rs b/src/hyperlight_guest/src/guest_handle/io.rs deleted file mode 100644 index 46c1d68f6..000000000 --- a/src/hyperlight_guest/src/guest_handle/io.rs +++ /dev/null @@ -1,150 +0,0 @@ -/* -Copyright 2025 The Hyperlight Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -use alloc::format; -use alloc::string::ToString; -use core::any::type_name; -use core::slice::from_raw_parts_mut; - -use hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode; -use tracing::instrument; - -use super::handle::GuestHandle; -use crate::error::{HyperlightGuestError, Result}; - -impl GuestHandle { - /// Pops the top element from the shared input data buffer and returns it as a T - #[instrument(skip_all, level = "Trace")] - pub fn try_pop_shared_input_data_into<T>(&self) -> Result<T> - where - T: for<'a> TryFrom<&'a [u8]>, - { - let peb_ptr = self.peb().unwrap(); - let input_stack_size = unsafe { (*peb_ptr).input_stack.size as usize }; - let input_stack_ptr = unsafe { (*peb_ptr).input_stack.ptr as *mut u8 }; - - let idb = unsafe { from_raw_parts_mut(input_stack_ptr, input_stack_size) }; - - if idb.is_empty() { - return Err(HyperlightGuestError::new( - ErrorCode::GuestError, - "Got a 0-size buffer in pop_shared_input_data_into".to_string(), - )); - } - - // get relative offset to next free address - let stack_ptr_rel: u64 = - u64::from_le_bytes(idb[..8].try_into().expect("Shared input buffer too small")); - - if stack_ptr_rel as usize > input_stack_size || stack_ptr_rel < 16 { - return Err(HyperlightGuestError::new( - ErrorCode::GuestError, - format!( - "Invalid stack pointer: {} in pop_shared_input_data_into", - stack_ptr_rel - ), - )); - } - - // go back 8 bytes and read.
This is the offset to the element on top of stack - let last_element_offset_rel = u64::from_le_bytes( - idb[stack_ptr_rel as usize - 8..stack_ptr_rel as usize] - .try_into() - .expect("Invalid stack pointer in pop_shared_input_data_into"), - ); - - let buffer = &idb[last_element_offset_rel as usize..]; - - // convert the buffer to T - let type_t = match T::try_from(buffer) { - Ok(t) => Ok(t), - Err(_e) => { - return Err(HyperlightGuestError::new( - ErrorCode::GuestError, - format!("Unable to convert buffer to {}", type_name::<T>()), - )); - } - }; - - // update the stack pointer to point to the element we just popped off, since that is now free - idb[..8].copy_from_slice(&last_element_offset_rel.to_le_bytes()); - - // zero out popped off buffer - idb[last_element_offset_rel as usize..stack_ptr_rel as usize].fill(0); - - type_t - } - - /// Pushes the given data onto the shared output data buffer. - pub fn push_shared_output_data(&self, data: &[u8]) -> Result<()> { - let peb_ptr = self.peb().unwrap(); - let output_stack_size = unsafe { (*peb_ptr).output_stack.size as usize }; - let output_stack_ptr = unsafe { (*peb_ptr).output_stack.ptr as *mut u8 }; - - let odb = unsafe { from_raw_parts_mut(output_stack_ptr, output_stack_size) }; - - if odb.is_empty() { - return Err(HyperlightGuestError::new( - ErrorCode::GuestError, - "Got a 0-size buffer in push_shared_output_data".to_string(), - )); - } - - // get offset to next free address on the stack - let stack_ptr_rel: u64 = - u64::from_le_bytes(odb[..8].try_into().expect("Shared output buffer too small")); - - // check if the stack pointer is within the bounds of the buffer. - // It can be equal to the size, but never greater - // It can never be less than 8. An empty buffer's stack pointer is 8 - if stack_ptr_rel as usize > output_stack_size || stack_ptr_rel < 8 { - return Err(HyperlightGuestError::new( - ErrorCode::GuestError, - format!( - "Invalid stack pointer: {} in push_shared_output_data", - stack_ptr_rel - ), - )); - } - - // check if there is enough space in the buffer - let size_required = data.len() + 8; // the data plus the pointer pointing to the data - let size_available = output_stack_size - stack_ptr_rel as usize; - if size_required > size_available { - return Err(HyperlightGuestError::new( - ErrorCode::GuestError, - format!( - "Not enough space in shared output buffer.
Required: {}, Available: {}", - size_required, size_available - ), - )); - } - - // write the actual data - odb[stack_ptr_rel as usize..stack_ptr_rel as usize + data.len()].copy_from_slice(data); - - // write the offset to the newly written data, to the top of the stack - let bytes: [u8; 8] = stack_ptr_rel.to_le_bytes(); - odb[stack_ptr_rel as usize + data.len()..stack_ptr_rel as usize + data.len() + 8] - .copy_from_slice(&bytes); - - // update stack pointer to point to next free address - let new_stack_ptr_rel: u64 = (stack_ptr_rel as usize + data.len() + 8) as u64; - odb[0..8].copy_from_slice(&(new_stack_ptr_rel).to_le_bytes()); - - Ok(()) - } -} diff --git a/src/hyperlight_guest/src/lib.rs b/src/hyperlight_guest/src/lib.rs index 19e5ac5f2..ed2656a6d 100644 --- a/src/hyperlight_guest/src/lib.rs +++ b/src/hyperlight_guest/src/lib.rs @@ -26,9 +26,9 @@ pub mod exit; pub mod layout; pub mod prim_alloc; pub mod types; +pub mod virtq; pub mod guest_handle { pub mod handle; pub mod host_comm; - pub mod io; } diff --git a/src/hyperlight_guest/src/virtq/context.rs b/src/hyperlight_guest/src/virtq/context.rs new file mode 100644 index 000000000..08b668865 --- /dev/null +++ b/src/hyperlight_guest/src/virtq/context.rs @@ -0,0 +1,471 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Guest virtqueue context. + +use alloc::vec::Vec; +use core::result; +use core::sync::atomic::AtomicU16; +use core::sync::atomic::Ordering::Relaxed; + +use flatbuffers::FlatBufferBuilder; +use hyperlight_common::flatbuffer_wrappers::function_call::{FunctionCall, FunctionCallType}; +use hyperlight_common::flatbuffer_wrappers::function_types::{ + FunctionCallResult, ParameterValue, ReturnType, ReturnValue, +}; +use hyperlight_common::flatbuffer_wrappers::util::estimate_flatbuffer_capacity; +use hyperlight_common::mem::PAGE_SIZE_USIZE; +use hyperlight_common::outb::OutBAction; +use hyperlight_common::virtq::msg::{MsgKind, VirtqMsgHeader}; +use hyperlight_common::virtq::{ + self, BufferPool, Layout, Notifier, QueueStats, RecyclePool, Token, VirtqProducer, +}; + +use super::GuestMemOps; +use crate::bail; +use crate::error::Result; + +static REQUEST_ID: AtomicU16 = AtomicU16::new(0); + +/// Guest-side notifier that triggers a VM exit via outb. +#[derive(Clone, Copy)] +pub struct GuestNotifier; + +impl Notifier for GuestNotifier { + fn notify(&self, _stats: QueueStats) { + unsafe { crate::exit::out32(OutBAction::VirtqNotify as u16, 0) }; + } +} + +/// Type alias for the guest-side G2H producer. +pub type G2hProducer = VirtqProducer<GuestMemOps, GuestNotifier, BufferPool>; + +/// Type alias for the guest-side H2G producer (uses fixed-size RecyclePool slots). +pub type H2gProducer = VirtqProducer<GuestMemOps, GuestNotifier, RecyclePool>; + +/// Configuration for one queue passed to [`GuestContext::new`]. +pub struct QueueConfig { + /// Ring descriptor layout in shared memory. + pub layout: Layout, + /// Base GVA of the buffer pool region. + pub pool_gva: u64, + /// Number of pages in the buffer pool.
+ pub pool_pages: usize, +} + +/// Virtqueue runtime state for guest-host communication. +pub struct GuestContext { + /// guest-to-host driver + g2h_producer: G2hProducer, + /// host-to-guest driver + h2g_producer: H2gProducer, + /// H2G slot size in bytes (each prefilled writable descriptor). + h2g_slot_size: usize, + /// snapshot generation counter + generation: u64, + /// Number of H2G requests received that still need a G2H response. + pending_replies: u32, + /// Last host-call result, stashed for the C API's two-step convention. + last_host_result: Option<Result<ReturnValue>>, +} + +impl GuestContext { + /// Create a new context with G2H and H2G queues. + pub fn new(g2h: QueueConfig, h2g: QueueConfig, generation: u64) -> Self { + let size = g2h.pool_pages * PAGE_SIZE_USIZE; + let g2h_pool = + BufferPool::new(g2h.pool_gva, size).expect("failed to create G2H buffer pool"); + let g2h_producer = + VirtqProducer::new(g2h.layout, GuestMemOps, GuestNotifier, g2h_pool.clone()); + + let size = h2g.pool_pages * PAGE_SIZE_USIZE; + let h2g_slot_size = PAGE_SIZE_USIZE; + let h2g_pool = RecyclePool::new(h2g.pool_gva, size, h2g_slot_size) + .expect("failed to create H2G recycle pool"); + let h2g_producer = + VirtqProducer::new(h2g.layout, GuestMemOps, GuestNotifier, h2g_pool.clone()); + + let mut ctx = Self { + g2h_producer, + h2g_producer, + h2g_slot_size, + generation, + pending_replies: 0, + last_host_result: None, + }; + + ctx.prefill_h2g().expect("H2G initial prefill failed"); + ctx + } + + /// Call a host function via the G2H virtqueue. + /// + /// For host functions known to return payloads larger than 4096 bytes, use + /// [`call_host_function_with_hint`](Self::call_host_function_with_hint). + pub fn call_host_function<T: TryFrom<ReturnValue>>( + &mut self, + function_name: &str, + parameters: Option<Vec<ParameterValue>>, + return_type: ReturnType, + ) -> Result<T> { + let hint = if matches!(return_type, ReturnType::String | ReturnType::VecBytes) { + BufferPool::upper_slot_size() + } else { + BufferPool::lower_slot_size() + }; + self.call_host_function_with_hint(function_name, parameters, return_type, hint) + } + + /// Call a host function with an explicit response capacity hint. + /// + /// `resp_hint` is the total completion buffer size in bytes + /// (including wire overhead: VirtqMsgHeader + FlatBuffer framing). + /// The BufferPool allocates multiple adjacent slots when the hint + /// exceeds a single slot size, so this is zero-copy for the host. + pub fn call_host_function_with_hint<T: TryFrom<ReturnValue>>( + &mut self, + function_name: &str, + parameters: Option<Vec<ParameterValue>>, + return_type: ReturnType, + resp_hint: usize, + ) -> Result<T> { + let params = parameters.as_deref().unwrap_or_default(); + let estimated_capacity = estimate_flatbuffer_capacity(function_name, params); + + let fc = FunctionCall::new( + function_name.into(), + parameters, + FunctionCallType::Host, + return_type, + ); + + let mut builder = FlatBufferBuilder::with_capacity(estimated_capacity); + let payload = fc.encode(&mut builder); + + let reqid = REQUEST_ID.fetch_add(1, Relaxed); + let msg = VirtqMsgHeader::new(MsgKind::Request, reqid, payload.len() as u32); + let hdr = bytemuck::bytes_of(&msg); + + let entry_len = VirtqMsgHeader::SIZE + payload.len(); + + // Reply guard: readwrite chains use 2 descriptors, leave room for pending replies.
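+ // (Each H2G request that has been received but not yet answered will
+ // need G2H descriptors for its response, so `ensure_reply_capacity`
+ // reserves `pending_replies` descriptors on top of the two requested
+ // here.)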
+ self.ensure_reply_capacity(2)?; + + let token = match self.try_send_readwrite(hdr, payload, entry_len, resp_hint) { + Ok(tok) => tok, + Err(e) if e.is_transient() => { + self.g2h_producer.notify_backpressure(); + + if let Err(err) = self.g2h_producer.reclaim() { + bail!("G2H reclaim: {err}"); + } + + let Ok(tok) = self.try_send_readwrite(hdr, payload, entry_len, resp_hint) else { + bail!("G2H call retry"); + }; + + tok + } + Err(e) => bail!("G2H call: {e}"), + }; + + // Poll completions, skipping earlier entries like log acks + // until we find the completion matching our request token. + let completion = loop { + let Some(cqe) = self.g2h_producer.poll()? else { + bail!("G2H: no completion received"); + }; + if cqe.token == token { + break cqe; + } + }; + + let result_bytes = &completion.data; + if result_bytes.len() < VirtqMsgHeader::SIZE { + bail!("G2H: response too short for header"); + } + + let payload_bytes = &result_bytes[VirtqMsgHeader::SIZE..]; + let Ok(fcr) = FunctionCallResult::try_from(payload_bytes) else { + bail!("G2H: malformed response"); + }; + + let ret = fcr.into_inner()?; + let Ok(ret) = T::try_from(ret) else { + bail!("G2H: host return value type mismatch"); + }; + + Ok(ret) + } + + /// Receive a host-to-guest function call from the H2G queue. + /// + /// Each descriptor carries a [`VirtqMsgHeader`] with `payload_len` for + /// that chunk. If [`MsgFlags::MORE`](hyperlight_common::virtq::msg::MsgFlags::MORE) + /// is set, more descriptors follow. + /// + /// Increments the reply guard counter so that subsequent G2H sends + /// reserve capacity for the response. + pub fn recv_h2g_call(&mut self) -> Result<FunctionCall> { + let Some(first) = self.h2g_producer.poll()? else { + bail!("H2G: no pending call"); + }; + + let data = &first.data; + if data.len() < VirtqMsgHeader::SIZE { + bail!("H2G: completion too short for header"); + } + + let hdr: &VirtqMsgHeader = bytemuck::from_bytes(&data[..VirtqMsgHeader::SIZE]); + + if hdr.msg_kind() != Ok(MsgKind::Request) { + bail!("H2G: unexpected message kind: 0x{:02x}", hdr.kind); + } + + let chunk_len = hdr.payload_len as usize; + + // Track that we owe a response on G2H. + self.pending_replies = self.pending_replies.saturating_add(1); + + if !hdr.has_more() { + // Single-descriptor fast path + let payload = &data[VirtqMsgHeader::SIZE..VirtqMsgHeader::SIZE + chunk_len]; + let fc = FunctionCall::try_from(payload)?; + return Ok(fc); + } + + // Multi-descriptor: accumulate payload until MsgFlags::MORE is cleared + let mut assembled = Vec::with_capacity(chunk_len * 2); + assembled.extend_from_slice(&data[VirtqMsgHeader::SIZE..VirtqMsgHeader::SIZE + chunk_len]); + + loop { + let Some(next) = self.h2g_producer.poll()? else { + bail!("H2G: expected continuation descriptor, none available"); + }; + + let next_data = &next.data; + if next_data.len() < VirtqMsgHeader::SIZE { + bail!("H2G: continuation too short for header"); + } + + let next_hdr: &VirtqMsgHeader = + bytemuck::from_bytes(&next_data[..VirtqMsgHeader::SIZE]); + + let next_chunk = next_hdr.payload_len as usize; + + assembled.extend_from_slice( + &next_data[VirtqMsgHeader::SIZE..VirtqMsgHeader::SIZE + next_chunk], + ); + + if !next_hdr.has_more() { + break; + } + } + + let fc = FunctionCall::try_from(assembled.as_slice())?; + Ok(fc) + } + + /// Send the result of a host-to-guest call back to the host via the + /// G2H queue, then refill H2G descriptor slots until the ring is full. + /// + /// Decrements the reply guard counter after a successful send.
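+ ///
+ /// Typical dispatch-loop shape (a sketch; `handle_call` stands in
+ /// for the guest's function dispatcher):
+ ///
+ /// ```ignore
+ /// let call = ctx.recv_h2g_call()?;
+ /// let result_bytes = handle_call(call)?;
+ /// ctx.send_h2g_result(&result_bytes)?;
+ /// ```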
+ pub fn send_h2g_result(&mut self, payload: &[u8]) -> Result<()> { + self.send_g2h_oneshot(MsgKind::Response, payload)?; + self.pending_replies = self.pending_replies.saturating_sub(1); + self.prefill_h2g() + } + + /// Restore the H2G producer after snapshot restore. + /// + /// Creates a new [`RecyclePool`] at `pool_gva` and calls + /// [`restore_from_ring`] to reconstruct inflight state + /// from the host's prefilled descriptors. + pub fn restore_h2g(&mut self, pool_gva: u64, pool_size: usize) { + let pool = RecyclePool::new(pool_gva, pool_size, self.h2g_slot_size) + .expect("H2G RecyclePool creation failed"); + + self.h2g_producer + .restore_from_ring(pool) + .expect("H2G restore_from_ring failed"); + } + + /// Reset the G2H producer with a fresh pool. + /// + /// Creates a new [`BufferPool`] at `pool_gva` and resets the + /// producer to its initial state. + pub fn reset_g2h(&mut self, pool_gva: u64, pool_size: usize) { + let pool = BufferPool::new(pool_gva, pool_size).expect("G2H BufferPool creation failed"); + self.g2h_producer.reset_with_pool(pool); + self.last_host_result = None; + } + + /// Send a log message via the G2H queue. Fire-and-forget. + pub fn emit_log(&mut self, log_data: &[u8]) -> Result<()> { + self.send_g2h_oneshot(MsgKind::Log, log_data) + } + + /// Get the current generation counter. + pub fn generation(&self) -> u64 { + self.generation + } + + /// Set the generation counter after snapshot restore. + pub fn set_generation(&mut self, generation: u64) { + self.generation = generation; + } + + /// Stash a host function result for later retrieval. + /// + /// Used by the C API's two-step calling convention where + /// `hl_call_host_function` and `hl_get_host_return_value_as_*` + /// are separate calls. + pub fn stash_host_result(&mut self, result: Result<ReturnValue>) { + self.last_host_result = Some(result); + } + + /// Take the stashed host return value. + /// + /// Panics if no value was stashed or if the type conversion fails. + /// If the stashed result was an error, panics with the error message. + pub fn take_host_return<T: TryFrom<ReturnValue>>(&mut self) -> T { + let val = self + .last_host_result + .take() + .expect("No host return value available") + .expect("Host function returned an error"); + + match T::try_from(val) { + Ok(v) => v, + Err(_) => panic!("Host return value type mismatch"), + } + } + + /// Pre-fill the H2G queue with completion-only descriptors so the host + /// can write incoming call payloads into them. + fn prefill_h2g(&mut self) -> Result<()> { + let mut batch = self.h2g_producer.batch(); + + loop { + let entry = match batch.chain().completion(self.h2g_slot_size).build() { + Ok(e) => e, + Err(e) if e.is_transient() => { + batch.finish()?; + return Ok(()); + } + Err(e) => bail!("H2G prefill build: {e}"), + }; + + match batch.submit(entry) { + Ok(_) => {} + Err(e) if e.is_transient() => { + batch.finish()?; + return Ok(()); + } + Err(e) => bail!("H2G prefill submit: {e}"), + } + } + } + + /// Ensure the G2H ring has enough free descriptors to accommodate + /// both the requested send (`need_descs`) and all pending replies. + fn ensure_reply_capacity(&mut self, need_descs: usize) -> Result<()> { + let reserved = self.pending_replies as usize; + loop { + let free = self.g2h_producer.num_free(); + if free >= need_descs + reserved { + return Ok(()); + } + + self.g2h_producer.notify_backpressure(); + let reclaimed = self.g2h_producer.reclaim()?; + if reclaimed == 0 { + // No progress - host hasn't completed any entries yet.
+ // Fall through and let the send path handle backpressure + // via its own retry logic. + return Ok(()); + } + } + } + + /// Send a one-way message on the G2H queue: a read-only chain with + /// no completion buffer. + /// + /// For non-response sends, the reply guard is checked first to + /// ensure enough G2H capacity is reserved for pending replies. + fn send_g2h_oneshot(&mut self, kind: MsgKind, payload: &[u8]) -> Result<()> { + let reqid = REQUEST_ID.fetch_add(1, Relaxed); + let hdr = VirtqMsgHeader::new(kind, reqid, payload.len() as u32); + let hdr_bytes = bytemuck::bytes_of(&hdr); + let entry_len = VirtqMsgHeader::SIZE + payload.len(); + + // Reply guard: non-response sends must leave room for pending replies. + if kind != MsgKind::Response { + self.ensure_reply_capacity(1)?; + } + + // First attempt + match self.try_send_readonly(hdr_bytes, payload, entry_len) { + Ok(_) => return Ok(()), + Err(virtq::VirtqError::Backpressure) => { + // VM exit so host drains and completes G2H entries. + self.g2h_producer.notify_backpressure(); + } + Err(e) => bail!("G2H oneshot: {e}"), + } + + // Reclaim ring/pool resources from completed entries. + if let Err(e) = self.g2h_producer.reclaim() { + bail!("G2H oneshot retry: {e}"); + } + // Retry after backpressure + match self.try_send_readonly(hdr_bytes, payload, entry_len) { + Ok(_) => Ok(()), + Err(e) => bail!("G2H oneshot retry: {e}"), + } + } + + fn try_send_readonly( + &mut self, + header: &[u8], + payload: &[u8], + entry_len: usize, + ) -> result::Result<Token, virtq::VirtqError> { + let mut entry = self.g2h_producer.chain().entry(entry_len).build()?; + + entry.write_all(header)?; + entry.write_all(payload)?; + self.g2h_producer.submit(entry) + } + + fn try_send_readwrite( + &mut self, + header: &[u8], + payload: &[u8], + entry_len: usize, + completion_cap: usize, + ) -> result::Result<Token, virtq::VirtqError> { + let mut entry = self + .g2h_producer + .chain() + .entry(entry_len) + .completion(completion_cap) + .build()?; + + entry.write_all(header)?; + entry.write_all(payload)?; + self.g2h_producer.submit(entry) + } +} diff --git a/src/hyperlight_guest/src/virtq/mem.rs b/src/hyperlight_guest/src/virtq/mem.rs new file mode 100644 index 000000000..16375c868 --- /dev/null +++ b/src/hyperlight_guest/src/virtq/mem.rs @@ -0,0 +1,61 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Guest-side [`MemOps`] implementation for virtqueue access. + +use core::convert::Infallible; +use core::sync::atomic::{AtomicU16, Ordering}; +use core::{ptr, slice}; + +use hyperlight_common::virtq::MemOps; + +/// Guest-side memory accessor for virtqueue operations. Treats virtq +/// addresses as guest virtual addresses that map directly to memory. +#[derive(Clone, Copy, Debug)] +pub struct GuestMemOps; + +// SAFETY: GuestMemOps treats virtqueue addresses as directly mapped guest +// virtual addresses and performs the required acquire/release operations.
+unsafe impl MemOps for GuestMemOps { + type Error = Infallible; + + fn read(&self, addr: u64, dst: &mut [u8]) -> Result<(), Self::Error> { + unsafe { ptr::copy_nonoverlapping(addr as *const u8, dst.as_mut_ptr(), dst.len()) }; + Ok(()) + } + + fn write(&self, addr: u64, src: &[u8]) -> Result<(), Self::Error> { + unsafe { ptr::copy_nonoverlapping(src.as_ptr(), addr as *mut u8, src.len()) }; + Ok(()) + } + + fn load_acquire(&self, addr: u64) -> Result<u16, Self::Error> { + Ok(unsafe { (*(addr as *const AtomicU16)).load(Ordering::Acquire) }) + } + + fn store_release(&self, addr: u64, val: u16) -> Result<(), Self::Error> { + unsafe { (*(addr as *const AtomicU16)).store(val, Ordering::Release) }; + Ok(()) + } + + unsafe fn as_slice(&self, addr: u64, len: usize) -> Result<&[u8], Self::Error> { + Ok(unsafe { slice::from_raw_parts(addr as *const u8, len) }) + } + + unsafe fn as_mut_slice(&self, addr: u64, len: usize) -> Result<&mut [u8], Self::Error> { + Ok(unsafe { slice::from_raw_parts_mut(addr as *mut u8, len) }) + } +} diff --git a/src/hyperlight_guest/src/virtq/mod.rs b/src/hyperlight_guest/src/virtq/mod.rs new file mode 100644 index 000000000..119fa2175 --- /dev/null +++ b/src/hyperlight_guest/src/virtq/mod.rs @@ -0,0 +1,80 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Guest-side virtqueue support. +//! +//! Global context is installed once via [`set_global_context`] and +//! accessed via [`with_context`]. + +pub mod context; +pub mod mem; + +use core::cell::RefCell; +use core::sync::atomic::{AtomicU8, Ordering}; + +use context::GuestContext; +pub use mem::GuestMemOps; + +// Init state machine +const UNINITIALIZED: u8 = 0; +const INITIALIZED: u8 = 1; + +static INIT_STATE: AtomicU8 = AtomicU8::new(UNINITIALIZED); +static GLOBAL_CONTEXT: SyncWrap<RefCell<Option<GuestContext>>> = SyncWrap(RefCell::new(None)); + +// Sync wrapper for the global context. +struct SyncWrap<T>(T); +// SAFETY: The guest is single-threaded. +unsafe impl<T> Sync for SyncWrap<T> {} + +/// Check if the global context has been initialized. +pub fn is_initialized() -> bool { + INIT_STATE.load(Ordering::Acquire) == INITIALIZED +} + +/// Access the global guest context via closure. +/// +/// # Panics +/// +/// Panics if the context has not been initialized or if it is +/// accessed re-entrantly. +pub fn with_context<R>(f: impl FnOnce(&mut GuestContext) -> R) -> R { + assert!( + INIT_STATE.load(Ordering::Acquire) == INITIALIZED, + "guest context not initialized" + ); + let mut borrow = GLOBAL_CONTEXT.0.borrow_mut(); + f(borrow.as_mut().unwrap()) +} + +/// Install the global guest context. Called once during guest init. +/// +/// # Panics +/// +/// Panics if called more than once.
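+///
+/// Usage sketch (`g2h_cfg`/`h2g_cfg` are [`context::QueueConfig`]s built
+/// by the embedding runtime):
+///
+/// ```ignore
+/// set_global_context(GuestContext::new(g2h_cfg, h2g_cfg, generation));
+/// assert!(is_initialized());
+/// ```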
+pub fn set_global_context(ctx: GuestContext) { + if INIT_STATE + .compare_exchange( + UNINITIALIZED, + INITIALIZED, + Ordering::SeqCst, + Ordering::SeqCst, + ) + .is_err() + { + panic!("guest context already initialized"); + } + unsafe { *GLOBAL_CONTEXT.0.as_ptr() = Some(ctx) }; +} diff --git a/src/hyperlight_guest_bin/Cargo.toml b/src/hyperlight_guest_bin/Cargo.toml index 060f67b88..0e4509865 100644 --- a/src/hyperlight_guest_bin/Cargo.toml +++ b/src/hyperlight_guest_bin/Cargo.toml @@ -13,8 +13,10 @@ and third-party code used by our C-API needed to build a native hyperlight-guest """ [features] -default = ["libc", "macros"] +default = ["libc", "printf", "macros", "virtq"] libc = ["dep:hyperlight-libc"] # compile libc from picolibc +printf = [ "libc" ] # compile printf +virtq = [] # use virtqueue for guest-to-host calls trace_guest = ["hyperlight-common/trace_guest", "hyperlight-guest/trace_guest", "hyperlight-guest-tracing/trace"] mem_profile = ["hyperlight-common/mem_profile"] macros = ["dep:hyperlight-guest-macro", "dep:linkme"] diff --git a/src/hyperlight_guest_bin/bindgen_wrapper.h b/src/hyperlight_guest_bin/bindgen_wrapper.h new file mode 100644 index 000000000..404384514 --- /dev/null +++ b/src/hyperlight_guest_bin/bindgen_wrapper.h @@ -0,0 +1,8 @@ +/* Bindgen wrapper for picolibc types used by hyperlight guest */ + +/* Enable POSIX clock definitions that picolibc guards behind __rtems__ */ +#define _POSIX_MONOTONIC_CLOCK 200112L + +#include +#include +#include diff --git a/src/hyperlight_guest_bin/src/guest_function/call.rs b/src/hyperlight_guest_bin/src/guest_function/call.rs index 82874c659..105053932 100644 --- a/src/hyperlight_guest_bin/src/guest_function/call.rs +++ b/src/hyperlight_guest_bin/src/guest_function/call.rs @@ -21,11 +21,11 @@ use flatbuffers::FlatBufferBuilder; use hyperlight_common::flatbuffer_wrappers::function_call::{FunctionCall, FunctionCallType}; use hyperlight_common::flatbuffer_wrappers::function_types::{FunctionCallResult, ParameterType}; use hyperlight_common::flatbuffer_wrappers::guest_error::{ErrorCode, GuestError}; -use hyperlight_guest::bail; use hyperlight_guest::error::{HyperlightGuestError, Result}; +use hyperlight_guest::{bail, virtq}; use tracing::instrument; -use crate::{GUEST_HANDLE, REGISTERED_GUEST_FUNCTIONS}; +use crate::REGISTERED_GUEST_FUNCTIONS; core::arch::global_asm!( ".weak guest_dispatch_function", @@ -98,30 +98,31 @@ pub(crate) fn internal_dispatch_function() { tracing::span!(tracing::Level::INFO, "internal_dispatch_function").entered() }; - let handle = unsafe { GUEST_HANDLE }; + // After snapshot restore, the ring memory is zeroed but the + // producer's cursors are stale. Check once per dispatch entry. 
+ crate::virtq::maybe_reset_virtqueues(); - let function_call = handle - .try_pop_shared_input_data_into::<FunctionCall>() - .expect("Function call deserialization failed"); + let function_call = virtq::with_context(|ctx| { + ctx.recv_h2g_call() + .expect("H2G: expected a host-to-guest call") + }); let res = call_guest_function(function_call); - match res { - Ok(bytes) => { - handle - .push_shared_output_data(bytes.as_slice()) - .expect("Failed to serialize function call result"); - } + let res_bytes = match res { + Ok(bytes) => bytes, Err(err) => { let guest_error = Err(GuestError::new(err.kind, err.message)); let fcr = FunctionCallResult::new(guest_error); let mut builder = FlatBufferBuilder::new(); - let data = fcr.encode(&mut builder); - handle - .push_shared_output_data(data) - .expect("Failed to serialize function call result"); + fcr.encode(&mut builder).to_vec() } - } + }; + + virtq::with_context(|ctx| { + ctx.send_h2g_result(&res_bytes) + .expect("H2G: failed to send result"); + }); // All this tracing logic shall be done right before the call to `hlt` which is done after this // function returns diff --git a/src/hyperlight_guest_bin/src/host_comm.rs b/src/hyperlight_guest_bin/src/host_comm.rs index 301462313..18cc458b9 100644 --- a/src/hyperlight_guest_bin/src/host_comm.rs +++ b/src/hyperlight_guest_bin/src/host_comm.rs @@ -14,8 +14,10 @@ See the License for the specific language governing permissions and limitations under the License. */ -use alloc::string::ToString; +use alloc::string::{String, ToString}; use alloc::vec::Vec; +use core::ffi::{CStr, c_char}; +use core::mem; use hyperlight_common::flatbuffer_wrappers::function_call::FunctionCall; use hyperlight_common::flatbuffer_wrappers::function_types::{ @@ -25,9 +27,15 @@ use hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode; use hyperlight_common::flatbuffer_wrappers::util::get_flatbuffer_result; use hyperlight_common::func::{ParameterTuple, SupportedReturnType}; use hyperlight_guest::error::{HyperlightGuestError, Result}; +use hyperlight_guest::virtq; +use tracing::instrument; + +const BUFFER_SIZE: usize = 1000; +static mut MESSAGE_BUFFER: Vec<u8> = Vec::new(); use crate::GUEST_HANDLE; +#[instrument(skip_all, level = "Info")] pub fn call_host_function<T>( function_name: &str, parameters: Option<Vec<ParameterValue>>, return_type: ReturnType, @@ -36,34 +44,32 @@ where T: TryFrom<ReturnValue>, { - let handle = unsafe { GUEST_HANDLE }; - handle.call_host_function::<T>(function_name, parameters, return_type) -} - -pub fn call_host<T>(function_name: impl AsRef<str>, args: impl ParameterTuple) -> Result<T> -where - T: SupportedReturnType + TryFrom<ReturnValue>, -{ - call_host_function::<T>(function_name.as_ref(), Some(args.into_value()), T::TYPE) + virtq::with_context(|ctx| ctx.call_host_function(function_name, parameters, return_type)) } -pub fn call_host_function_without_returning_result( +/// Call a host function with an explicit response capacity hint. +/// +/// `response_hint` is the total completion buffer size in bytes including wire overhead. +/// Use this when you know the host function returns a large payload (e.g., >4096 bytes).
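+///
+/// A hedged usage sketch ("ReadLargeBlob" is a hypothetical host
+/// function registered by the embedder):
+///
+/// ```ignore
+/// let blob: Vec<u8> = call_host_function_with_hint(
+///     "ReadLargeBlob",
+///     None,
+///     ReturnType::VecBytes,
+///     64 * 1024, // completion buffer size, incl. header/framing overhead
+/// )?;
+/// ```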
+pub fn call_host_function_with_hint<T>( function_name: &str, parameters: Option<Vec<ParameterValue>>, return_type: ReturnType, -) -> Result<()> { - let handle = unsafe { GUEST_HANDLE }; - handle.call_host_function_without_returning_result(function_name, parameters, return_type) -} - -pub fn get_host_return_value_raw() -> Result<ReturnValue> { - let handle = unsafe { GUEST_HANDLE }; - handle.get_host_return_raw() + response_hint: usize, +) -> Result<T> +where + T: TryFrom<ReturnValue>, +{ + virtq::with_context(|ctx| { + ctx.call_host_function_with_hint(function_name, parameters, return_type, response_hint) + }) } -pub fn get_host_return_value<T: TryFrom<ReturnValue>>() -> Result<T> { - let handle = unsafe { GUEST_HANDLE }; - handle.get_host_return_value::<T>() +pub fn call_host<T>(function_name: impl AsRef<str>, args: impl ParameterTuple) -> Result<T> +where + T: SupportedReturnType + TryFrom<ReturnValue>, +{ + call_host_function::<T>(function_name.as_ref(), Some(args.into_value()), T::TYPE) } pub fn read_n_bytes_from_user_memory(num: u64) -> Result<Vec<u8>> { @@ -76,9 +82,8 @@ pub fn read_n_bytes_from_user_memory(num: u64) -> Result<Vec<u8>> { /// This function requires memory to be setup to be used. In particular, the /// existence of the input and output memory regions. pub fn print_output_with_host_print(function_call: FunctionCall) -> Result<Vec<u8>> { - let handle = unsafe { GUEST_HANDLE }; if let ParameterValue::String(message) = function_call.parameters.unwrap().remove(0) { - let res = handle.call_host_function::<i32>( + let res = call_host_function::<i32>( "HostPrint", Some(Vec::from(&[ParameterValue::String(message)])), ReturnType::Int, @@ -92,3 +97,45 @@ pub fn print_output_with_host_print(function_call: FunctionCall) -> Result<Vec<u8>> { + call_host_function::<i32>( + "HostPrint", + Some(Vec::from(&[ParameterValue::String(str)])), + ReturnType::Int, + ) + .expect("Failed to call HostPrint"); + + // Clear the buffer after sending + message_buffer.clear(); + } +}
+
+use core::num::NonZeroU16;
+
+use hyperlight_common::layout::{
+    SCRATCH_TOP_G2H_POOL_PAGES_OFFSET, SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET,
+    SCRATCH_TOP_G2H_RING_GVA_OFFSET, SCRATCH_TOP_H2G_POOL_GVA_OFFSET,
+    SCRATCH_TOP_H2G_POOL_PAGES_OFFSET, SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET,
+    SCRATCH_TOP_H2G_RING_GVA_OFFSET, SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET, scratch_top_ptr,
+};
+use hyperlight_common::mem::PAGE_SIZE_USIZE;
+use hyperlight_common::virtq::Layout as VirtqLayout;
+use hyperlight_guest::prim_alloc::alloc_phys_pages;
+use hyperlight_guest::virtq::context::{GuestContext, QueueConfig};
+
+use crate::paging::phys_to_virt;
+
+/// Initialize virtqueue context.
+pub(crate) fn init_virtqueues() {
+    let g2h_gva = unsafe { *scratch_top_ptr::<u64>(SCRATCH_TOP_G2H_RING_GVA_OFFSET) };
+    let g2h_depth = unsafe { *scratch_top_ptr::<u16>(SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET) };
+    let h2g_gva = unsafe { *scratch_top_ptr::<u64>(SCRATCH_TOP_H2G_RING_GVA_OFFSET) };
+    let h2g_depth = unsafe { *scratch_top_ptr::<u16>(SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET) };
+    let g2h_pages = unsafe { *scratch_top_ptr::<u16>(SCRATCH_TOP_G2H_POOL_PAGES_OFFSET) } as usize;
+    let h2g_pages = unsafe { *scratch_top_ptr::<u16>(SCRATCH_TOP_H2G_POOL_PAGES_OFFSET) } as usize;
+    let generation = unsafe { *scratch_top_ptr::<u64>(SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET) };
+
+    assert!(g2h_depth > 0 && h2g_depth > 0 && g2h_pages > 0 && h2g_pages > 0);
+    assert!(g2h_gva != 0 && h2g_gva != 0);
+
+    // Zero ring memory
+    let g2h_ring_size = VirtqLayout::query_size(g2h_depth as usize);
+    unsafe { core::ptr::write_bytes(g2h_gva as *mut u8, 0, g2h_ring_size) };
+
+    let h2g_ring_size = VirtqLayout::query_size(h2g_depth as usize);
+    unsafe { core::ptr::write_bytes(h2g_gva as *mut u8, 0, h2g_ring_size) };
+
+    // Build ring layouts
+    let nz = NonZeroU16::new(g2h_depth).expect("G2H depth zero");
+    let g2h_layout = unsafe { VirtqLayout::from_base(g2h_gva, nz) }.expect("invalid layout");
+
+    let nz = NonZeroU16::new(h2g_depth).expect("H2G depth zero");
+    let h2g_layout = unsafe { VirtqLayout::from_base(h2g_gva, nz) }.expect("invalid layout");
+
+    // Allocate buffer pools
+    let g2h_pool_gva = alloc_pool(g2h_pages);
+    let h2g_pool_gva = alloc_pool(h2g_pages);
+
+    // Publish H2G pool GVA so the host can prefill after restore
+    unsafe { *scratch_top_ptr::<u64>(SCRATCH_TOP_H2G_POOL_GVA_OFFSET) = h2g_pool_gva };
+
+    let ctx = GuestContext::new(
+        QueueConfig {
+            layout: g2h_layout,
+            pool_gva: g2h_pool_gva,
+            pool_pages: g2h_pages,
+        },
+        QueueConfig {
+            layout: h2g_layout,
+            pool_gva: h2g_pool_gva,
+            pool_pages: h2g_pages,
+        },
+        generation,
+    );
+    hyperlight_guest::virtq::set_global_context(ctx);
+}
+
+/// Reset virtqueue state if a snapshot restore was detected.
+///
+/// Compares the generation counter in scratch-top metadata against
+/// the context's cached value. On mismatch, restores H2G from the
+/// host-prefilled ring and allocates a fresh G2H pool.
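+///
+/// Illustrative call site (this is how the dispatch loop in this crate
+/// uses it; it is a cheap no-op when the generation is unchanged):
+///
+/// ```ignore
+/// // Before receiving the next host-to-guest call:
+/// crate::virtq::maybe_reset_virtqueues();
+/// ```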
+pub(crate) fn maybe_reset_virtqueues() {
+    if !hyperlight_guest::virtq::is_initialized() {
+        return;
+    }
+
+    let curr_gen = unsafe { *scratch_top_ptr::<u64>(SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET) };
+
+    hyperlight_guest::virtq::with_context(|ctx| {
+        if curr_gen == ctx.generation() {
+            return;
+        }
+
+        // Read host-assigned H2G pool location from scratch-top
+        let h2g_pool_gva = unsafe { *scratch_top_ptr::<u64>(SCRATCH_TOP_H2G_POOL_GVA_OFFSET) };
+        let h2g_pages = unsafe { *scratch_top_ptr::<u16>(SCRATCH_TOP_H2G_POOL_PAGES_OFFSET) };
+        let g2h_pages = unsafe { *scratch_top_ptr::<u16>(SCRATCH_TOP_G2H_POOL_PAGES_OFFSET) };
+
+        let h2g_pages = h2g_pages as usize;
+        let g2h_pages = g2h_pages as usize;
+        let g2h_pool_gva = alloc_pool(g2h_pages);
+
+        ctx.restore_h2g(h2g_pool_gva, h2g_pages * PAGE_SIZE_USIZE);
+        ctx.reset_g2h(g2h_pool_gva, g2h_pages * PAGE_SIZE_USIZE);
+        ctx.set_generation(curr_gen);
+    });
+}
+
+/// Allocate and zero `n` physical pages, returning the GVA.
+fn alloc_pool(n: usize) -> u64 {
+    let gpa = unsafe { alloc_phys_pages(n as u64) };
+    let ptr = phys_to_virt(gpa).expect("failed to map pool pages");
+    let size = n * PAGE_SIZE_USIZE;
+    unsafe { core::ptr::write_bytes(ptr, 0, size) };
+    ptr as u64
+}
diff --git a/src/hyperlight_guest_capi/src/dispatch.rs b/src/hyperlight_guest_capi/src/dispatch.rs
index e0a8bc34c..86ee0fcbe 100644
--- a/src/hyperlight_guest_capi/src/dispatch.rs
+++ b/src/hyperlight_guest_capi/src/dispatch.rs
@@ -20,12 +20,14 @@ use alloc::vec::Vec;
 use core::ffi::{CStr, c_char};

 use hyperlight_common::flatbuffer_wrappers::function_call::FunctionCall;
-use hyperlight_common::flatbuffer_wrappers::function_types::{ParameterType, ReturnType};
+use hyperlight_common::flatbuffer_wrappers::function_types::{
+    ParameterType, ReturnType, ReturnValue,
+};
 use hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode;
 use hyperlight_guest::error::{HyperlightGuestError, Result};
+use hyperlight_guest::virtq;
 use hyperlight_guest_bin::guest_function::definition::GuestFunctionDefinition;
 use hyperlight_guest_bin::guest_function::register::GuestFunctionRegister;
-use hyperlight_guest_bin::host_comm::call_host_function_without_returning_result;

 use crate::types::{FfiFunctionCall, FfiVec};

 static mut REGISTERED_C_GUEST_FUNCTIONS: GuestFunctionRegister =
@@ -98,15 +100,22 @@ pub extern "C" fn hl_register_function_definition(
     unsafe { (&mut *(&raw mut REGISTERED_C_GUEST_FUNCTIONS)).register(func_def) };
 }

-/// The caller is responsible for freeing the memory associated with given `FfiFunctionCall`.
+/// Call a host function. The return value can be retrieved with
+/// `hl_get_host_return_value_as_*` immediately after.
 #[unsafe(no_mangle)]
 pub extern "C" fn hl_call_host_function(function_call: &FfiFunctionCall) {
     let parameters = unsafe { function_call.copy_parameters() };
     let func_name = unsafe { function_call.copy_function_name() };
     let return_type = unsafe { function_call.copy_return_type() };

-    // Use the non-generic internal implementation
-    // The C API will then call specific getter functions to fetch the properly typed return value
-    let _ = call_host_function_without_returning_result(&func_name, Some(parameters), return_type)
-        .expect("Failed to call host function");
+    virtq::with_context(|ctx| {
+        let result =
+            ctx.call_host_function::<ReturnValue>(&func_name, Some(parameters), return_type);
+        ctx.stash_host_result(result);
+    });
+}
+
+/// Retrieve the return value stashed by the last `hl_call_host_function`.
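+///
+/// Sketch of the intended pairing (this is what the typed getters in
+/// `flatbuffer.rs` below compile down to):
+///
+/// ```ignore
+/// hl_call_host_function(&ffi_call);
+/// let n: i32 = take_last_host_return();
+/// ```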
+pub(crate) fn take_last_host_return<T: TryFrom<ReturnValue>>() -> T {
+    virtq::with_context(|ctx| ctx.take_host_return::<T>())
 }
diff --git a/src/hyperlight_guest_capi/src/error.rs b/src/hyperlight_guest_capi/src/error.rs
index 03217600e..720911157 100644
--- a/src/hyperlight_guest_capi/src/error.rs
+++ b/src/hyperlight_guest_capi/src/error.rs
@@ -19,7 +19,6 @@ use core::ffi::{CStr, c_char};
 use flatbuffers::FlatBufferBuilder;
 use hyperlight_common::flatbuffer_wrappers::function_types::FunctionCallResult;
 use hyperlight_common::flatbuffer_wrappers::guest_error::{ErrorCode, GuestError};
-use hyperlight_guest_bin::GUEST_HANDLE;

 use crate::alloc::borrow::ToOwned;

@@ -35,12 +34,11 @@ pub extern "C" fn hl_set_error(err: ErrorCode, message: *const c_char) {
     let fcr = FunctionCallResult::new(guest_error);
     let mut builder = FlatBufferBuilder::new();
     let data = fcr.encode(&mut builder);
-    unsafe {
-        #[allow(static_mut_refs)] // we are single threaded
-        GUEST_HANDLE
-            .push_shared_output_data(data)
-            .expect("Failed to set error")
-    }
+
+    hyperlight_guest::virtq::with_context(|ctx| {
+        ctx.send_h2g_result(data)
+            .expect("Failed to send error via virtq");
+    });
 }

 #[unsafe(no_mangle)]
diff --git a/src/hyperlight_guest_capi/src/flatbuffer.rs b/src/hyperlight_guest_capi/src/flatbuffer.rs
index ff12400d6..043431e4c 100644
--- a/src/hyperlight_guest_capi/src/flatbuffer.rs
+++ b/src/hyperlight_guest_capi/src/flatbuffer.rs
@@ -21,8 +21,8 @@ use alloc::vec::Vec;
 use core::ffi::{CStr, c_char};

 use hyperlight_common::flatbuffer_wrappers::util::get_flatbuffer_result;
-use hyperlight_guest_bin::host_comm::get_host_return_value;

+use crate::dispatch::take_last_host_return;
 use crate::types::FfiVec;

 // The reason for the capitalized type in the function names below
@@ -106,44 +106,43 @@ pub extern "C" fn hl_flatbuffer_result_from_Bool(value: bool) -> Box<FfiVec> {

 #[unsafe(no_mangle)]
 pub extern "C" fn hl_get_host_return_value_as_Int() -> i32 {
-    get_host_return_value().expect("Unable to get host return value as int")
+    take_last_host_return()
 }

 #[unsafe(no_mangle)]
 pub extern "C" fn hl_get_host_return_value_as_UInt() -> u32 {
-    get_host_return_value().expect("Unable to get host return value as uint")
+    take_last_host_return()
 }

 // the same for long, ulong
 #[unsafe(no_mangle)]
 pub extern "C" fn hl_get_host_return_value_as_Long() -> i64 {
-    get_host_return_value().expect("Unable to get host return value as long")
+    take_last_host_return()
 }

 #[unsafe(no_mangle)]
 pub extern "C" fn hl_get_host_return_value_as_ULong() -> u64 {
-    get_host_return_value().expect("Unable to get host return value as ulong")
+    take_last_host_return()
 }

 #[unsafe(no_mangle)]
 pub extern "C" fn hl_get_host_return_value_as_Bool() -> bool {
-    get_host_return_value().expect("Unable to get host return value as bool")
+    take_last_host_return()
 }

 #[unsafe(no_mangle)]
 pub extern "C" fn hl_get_host_return_value_as_Float() -> f32 {
-    get_host_return_value().expect("Unable to get host return value as f32")
+    take_last_host_return()
 }

 #[unsafe(no_mangle)]
 pub extern "C" fn hl_get_host_return_value_as_Double() -> f64 {
-    get_host_return_value().expect("Unable to get host return value as f64")
+    take_last_host_return()
 }

 #[unsafe(no_mangle)]
 pub extern "C" fn hl_get_host_return_value_as_String() -> *const c_char {
-    let string_value: String =
-        get_host_return_value().expect("Unable to get host return value as string");
+    let string_value: String = take_last_host_return();

     let c_string = CString::new(string_value).expect("Failed to create CString");
     c_string.into_raw()
@@ -151,8 +150,7 @@
pub extern "C" fn hl_get_host_return_value_as_String() -> *const c_char { #[unsafe(no_mangle)] pub extern "C" fn hl_get_host_return_value_as_VecBytes() -> Box { - let vec_value: Vec = - get_host_return_value().expect("Unable to get host return value as vec bytes"); + let vec_value: Vec = take_last_host_return(); Box::new(unsafe { FfiVec::from_vec(vec_value) }) } diff --git a/src/hyperlight_host/Cargo.toml b/src/hyperlight_host/Cargo.toml index f65780504..985642f9f 100644 --- a/src/hyperlight_host/Cargo.toml +++ b/src/hyperlight_host/Cargo.toml @@ -21,6 +21,7 @@ bench = false # see https://bheisler.github.io/criterion.rs/book/faq.html#cargo- workspace = true [dependencies] +bytemuck = { version = "1.25", features = ["derive"] } gdbstub = { version = "0.7.10", optional = true } gdbstub_arch = { version = "0.3.3", optional = true } goblin = { version = "0.10", default-features = false, features = ["std", "elf32", "elf64", "endian_fd"] } diff --git a/src/hyperlight_host/benches/benchmarks.rs b/src/hyperlight_host/benches/benchmarks.rs index 462e8908d..f6e72df3a 100644 --- a/src/hyperlight_host/benches/benchmarks.rs +++ b/src/hyperlight_host/benches/benchmarks.rs @@ -57,13 +57,13 @@ impl SandboxSize { Self::Medium => { let mut cfg = SandboxConfiguration::default(); cfg.set_heap_size(MEDIUM_HEAP_SIZE); - cfg.set_scratch_size(0x50000); + cfg.set_scratch_size(0x80000); Some(cfg) } Self::Large => { let mut cfg = SandboxConfiguration::default(); cfg.set_heap_size(LARGE_HEAP_SIZE); - cfg.set_scratch_size(0x100000); + cfg.set_scratch_size(0x200000); Some(cfg) } } @@ -379,10 +379,15 @@ fn guest_call_benchmark_large_param(c: &mut Criterion) { let large_vec = vec![0u8; SIZE]; let large_string = String::from_utf8(large_vec.clone()).unwrap(); + let h2g_pool_pages = (2 * SIZE + (1024 * 1024)) / 4096; + let heap_size = SIZE as u64 * 15; + let mut config = SandboxConfiguration::default(); - config.set_input_data_size(2 * SIZE + (1024 * 1024)); // 2 * SIZE + 1 MB, to allow 1MB for the rest of the serialized function call - config.set_heap_size(SIZE as u64 * 15); - config.set_scratch_size(6 * SIZE + 4 * (1024 * 1024)); // Big enough for the IO data regions and enough of the heap to be used + config.set_h2g_pool_pages(h2g_pool_pages); + config.set_h2g_queue_depth(h2g_pool_pages.next_power_of_two()); + config.set_heap_size(heap_size); + // Scratch backs all guest physical pages (heap, page tables, pools). + config.set_scratch_size(heap_size as usize + 4 * 1024 * 1024); let sandbox = UninitializedSandbox::new( GuestBinary::FilePath(simple_guest_as_string().unwrap()), @@ -465,7 +470,7 @@ fn sample_workloads_benchmark(c: &mut Criterion) { fn bench_24k_in_8k_out(b: &mut criterion::Bencher, guest_path: String) { let mut cfg = SandboxConfiguration::default(); - cfg.set_input_data_size(25 * 1024); + cfg.set_h2g_pool_pages(7); // 25 * 1024 / 4096 ~= 7 pages let mut sandbox = UninitializedSandbox::new(GuestBinary::FilePath(guest_path), Some(cfg)) .unwrap() diff --git a/src/hyperlight_host/src/hypervisor/hyperlight_vm/x86_64.rs b/src/hyperlight_host/src/hypervisor/hyperlight_vm/x86_64.rs index f06c94964..0b3593be1 100644 --- a/src/hyperlight_host/src/hypervisor/hyperlight_vm/x86_64.rs +++ b/src/hyperlight_host/src/hypervisor/hyperlight_vm/x86_64.rs @@ -2110,17 +2110,20 @@ mod tests { } /// Creates VM with guest code that: dirtys FPU (if flag==0), does FXSAVE to buffer, sets flag=1. - /// Uses output_data region for FXSAVE buffer (like regular guest output), scratch for stack. 
+ /// Uses a scratch region area for the FXSAVE buffer. fn hyperlight_vm_with_mem_mgr_fxsave() -> FxsaveTestContext { use iced_x86::code_asm::*; // Compute fixed addresses for FXSAVE buffer and flag. - // These are in the output_data region which starts at a known offset. - // We use a default SandboxConfiguration to get the same layout as create_test_vm_context. + // Place the buffer at halfway through scratch: well past + // the rings and page tables at the start, and well below + // the stack and scratch-top metadata at the end. let config: SandboxConfiguration = Default::default(); let layout = SandboxMemoryLayout::new(config, 512, 4096, None).unwrap(); - let fxsave_offset = layout.get_output_data_buffer_scratch_host_offset(); - let fxsave_gva = layout.get_output_data_buffer_gva(); + let scratch_size = config.get_scratch_size(); + let fxsave_offset = (scratch_size / 2) & !0xFFF; // page-aligned + let fxsave_gva = + hyperlight_common::layout::scratch_base_gva(scratch_size) + fxsave_offset as u64; let flag_gva = fxsave_gva + 512; let mut a = CodeAssembler::new(64).unwrap(); diff --git a/src/hyperlight_host/src/mem/layout.rs b/src/hyperlight_host/src/mem/layout.rs index 26615d579..52bc1af70 100644 --- a/src/hyperlight_host/src/mem/layout.rs +++ b/src/hyperlight_host/src/mem/layout.rs @@ -47,17 +47,12 @@ limitations under the License. //! //! There is also a scratch region at the top of physical memory, //! which is mostly laid out as a large undifferentiated blob of -//! memory, although at present the snapshot process specially -//! privileges the statically allocated input and output data regions: +//! memory: //! //! +-------------------------------------------+ (top of physical memory) //! | Exception Stack, Metadata | //! +-------------------------------------------+ (1 page below) //! | Scratch Memory | -//! +-------------------------------------------+ -//! | Output Data | -//! +-------------------------------------------+ -//! | Input Data | //! +-------------------------------------------+ (scratch size) use std::fmt::Debug; @@ -223,8 +218,6 @@ pub(crate) struct SandboxMemoryLayout { /// The following fields are offsets to the actual PEB struct fields. /// They are used when writing the PEB struct itself peb_offset: usize, - peb_input_data_offset: usize, - peb_output_data_offset: usize, peb_init_data_offset: usize, peb_heap_data_offset: usize, #[cfg(feature = "nanvix-unstable")] @@ -267,14 +260,6 @@ impl Debug for SandboxMemoryLayout { .field("PEB Address", &format_args!("{:#x}", self.peb_address)) .field("PEB Offset", &format_args!("{:#x}", self.peb_offset)) .field("Code Size", &format_args!("{:#x}", self.code_size)) - .field( - "Input Data Offset", - &format_args!("{:#x}", self.peb_input_data_offset), - ) - .field( - "Output Data Offset", - &format_args!("{:#x}", self.peb_output_data_offset), - ) .field( "Init Data Offset", &format_args!("{:#x}", self.peb_init_data_offset), @@ -320,9 +305,6 @@ impl SandboxMemoryLayout { /// The base address of the sandbox's memory. pub(crate) const BASE_ADDRESS: usize = 0x1000; - // the offset into a sandbox's input/output buffer where the stack starts - pub(crate) const STACK_POINTER_SIZE_BYTES: u64 = 8; - /// Create a new `SandboxMemoryLayout` with the given /// `SandboxConfiguration`, code size and stack/heap size. 
#[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] @@ -338,8 +320,8 @@ impl SandboxMemoryLayout { return Err(MemoryRequestTooBig(scratch_size, Self::MAX_MEMORY_SIZE)); } let min_scratch_size = hyperlight_common::layout::min_scratch_size( - cfg.get_input_data_size(), - cfg.get_output_data_size(), + cfg.get_g2h_queue_depth(), + cfg.get_h2g_queue_depth(), ); if scratch_size < min_scratch_size { return Err(MemoryRequestTooSmall(scratch_size, min_scratch_size)); @@ -348,8 +330,6 @@ impl SandboxMemoryLayout { let guest_code_offset = 0; // The following offsets are to the fields of the PEB struct itself! let peb_offset = code_size.next_multiple_of(PAGE_SIZE_USIZE); - let peb_input_data_offset = peb_offset + offset_of!(HyperlightPEB, input_stack); - let peb_output_data_offset = peb_offset + offset_of!(HyperlightPEB, output_stack); let peb_init_data_offset = peb_offset + offset_of!(HyperlightPEB, init_data); let peb_heap_data_offset = peb_offset + offset_of!(HyperlightPEB, guest_heap); #[cfg(feature = "nanvix-unstable")] @@ -384,8 +364,6 @@ impl SandboxMemoryLayout { let mut ret = Self { peb_offset, heap_size, - peb_input_data_offset, - peb_output_data_offset, peb_init_data_offset, peb_heap_data_offset, #[cfg(feature = "nanvix-unstable")] @@ -406,13 +384,6 @@ impl SandboxMemoryLayout { Ok(ret) } - /// Get the offset in guest memory to the output data size - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - pub(super) fn get_output_data_size_offset(&self) -> usize { - // The size field is the first field in the `OutputData` struct - self.peb_output_data_offset - } - /// Get the offset in guest memory to the init data size #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub(super) fn get_init_data_size_offset(&self) -> usize { @@ -425,14 +396,6 @@ impl SandboxMemoryLayout { self.scratch_size } - /// Get the offset in guest memory to the output data pointer. - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_output_data_pointer_offset(&self) -> usize { - // This field is immediately after the output data size field, - // which is a `u64`. - self.get_output_data_size_offset() + size_of::() - } - /// Get the offset in guest memory to the init data pointer. #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub(super) fn get_init_data_pointer_offset(&self) -> usize { @@ -441,55 +404,51 @@ impl SandboxMemoryLayout { self.get_init_data_size_offset() + size_of::() } - /// Get the guest virtual address of the start of output data. - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - pub(crate) fn get_output_data_buffer_gva(&self) -> u64 { - hyperlight_common::layout::scratch_base_gva(self.scratch_size) - + self.sandbox_memory_config.get_input_data_size() as u64 + /// Get the offset into the scratch region of the G2H ring. + fn get_g2h_ring_scratch_offset(&self) -> usize { + hyperlight_common::layout::g2h_ring_scratch_offset() } - /// Get the offset into the host scratch buffer of the start of - /// the output data. - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - pub(crate) fn get_output_data_buffer_scratch_host_offset(&self) -> usize { - self.sandbox_memory_config.get_input_data_size() + /// Get the size of the G2H ring in bytes. + #[allow(dead_code)] + fn get_g2h_ring_size(&self) -> usize { + hyperlight_common::virtq::Layout::query_size( + self.sandbox_memory_config.get_g2h_queue_depth(), + ) } - /// Get the offset in guest memory to the input data size. 
- #[instrument(skip_all, parent = Span::current(), level= "Trace")] - pub(super) fn get_input_data_size_offset(&self) -> usize { - // The input data size is the first field in the input stack's `GuestMemoryRegion` struct - self.peb_input_data_offset + /// Get the offset into the scratch region of the H2G ring. + fn get_h2g_ring_scratch_offset(&self) -> usize { + hyperlight_common::layout::h2g_ring_scratch_offset( + self.sandbox_memory_config.get_g2h_queue_depth(), + ) } - /// Get the offset in guest memory to the input data pointer. - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_input_data_pointer_offset(&self) -> usize { - // The input data pointer is immediately after the input - // data size field in the input data `GuestMemoryRegion` struct which is a `u64`. - self.get_input_data_size_offset() + size_of::() + /// Get the size of the H2G ring in bytes. + fn get_h2g_ring_size(&self) -> usize { + hyperlight_common::virtq::Layout::query_size( + self.sandbox_memory_config.get_h2g_queue_depth(), + ) } - /// Get the guest virtual address of the start of input data - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_input_data_buffer_gva(&self) -> u64 { + /// Get the GVA of the G2H ring in guest address space. + pub(crate) fn get_g2h_ring_gva(&self) -> u64 { hyperlight_common::layout::scratch_base_gva(self.scratch_size) + + self.get_g2h_ring_scratch_offset() as u64 } - /// Get the offset into the host scratch buffer of the start of - /// the input data - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - pub(crate) fn get_input_data_buffer_scratch_host_offset(&self) -> usize { - 0 + /// Get the GVA of the H2G ring in guest address space. + pub(crate) fn get_h2g_ring_gva(&self) -> u64 { + hyperlight_common::layout::scratch_base_gva(self.scratch_size) + + self.get_h2g_ring_scratch_offset() as u64 } /// Get the offset from the beginning of the scratch region to the /// location where page tables will be eagerly copied on restore #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub(crate) fn get_pt_base_scratch_offset(&self) -> usize { - (self.sandbox_memory_config.get_input_data_size() - + self.sandbox_memory_config.get_output_data_size()) - .next_multiple_of(hyperlight_common::vmem::PAGE_SIZE) + let after_rings = self.get_h2g_ring_scratch_offset() + self.get_h2g_ring_size(); + after_rings.next_multiple_of(hyperlight_common::vmem::PAGE_SIZE) } /// Get the base GPA to which the page tables will be eagerly @@ -592,8 +551,8 @@ impl SandboxMemoryLayout { #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub(crate) fn set_pt_size(&mut self, size: usize) -> Result<()> { let min_fixed_scratch = hyperlight_common::layout::min_scratch_size( - self.sandbox_memory_config.get_input_data_size(), - self.sandbox_memory_config.get_output_data_size(), + self.sandbox_memory_config.get_g2h_queue_depth(), + self.sandbox_memory_config.get_h2g_queue_depth(), ); let min_scratch = min_fixed_scratch + size; if self.scratch_size < min_scratch { @@ -753,34 +712,6 @@ impl SandboxMemoryLayout { // Start of setting up the PEB. 
The following are in the order of the PEB fields

-        // Set up input buffer pointer
-        write_u64(
-            mem,
-            self.get_input_data_size_offset(),
-            self.sandbox_memory_config
-                .get_input_data_size()
-                .try_into()?,
-        )?;
-        write_u64(
-            mem,
-            self.get_input_data_pointer_offset(),
-            self.get_input_data_buffer_gva(),
-        )?;
-
-        // Set up output buffer pointer
-        write_u64(
-            mem,
-            self.get_output_data_size_offset(),
-            self.sandbox_memory_config
-                .get_output_data_size()
-                .try_into()?,
-        )?;
-        write_u64(
-            mem,
-            self.get_output_data_pointer_offset(),
-            self.get_output_data_buffer_gva(),
-        )?;
-
         // Set up init data pointer
         write_u64(
             mem,
@@ -812,10 +743,9 @@ impl SandboxMemoryLayout {

         // End of setting up the PEB

-        // The input and output data regions do not have their layout
-        // initialised here, because they are in the scratch
-        // region---they are instead set in
-        // [`SandboxMemoryManager::update_scratch_bookkeeping`].
+        // Virtqueue ring layouts are communicated via scratch-top
+        // metadata (queue depths), not the PEB. Both host and guest
+        // compute ring addresses from shared offset functions.

         Ok(())
     }
@@ -893,7 +823,6 @@ mod tests {
         let mut cfg = SandboxConfiguration::default();
         // scratch_size exceeds 16 GiB limit
        cfg.set_scratch_size(17 * 1024 * 1024 * 1024);
-        cfg.set_input_data_size(16 * 1024 * 1024 * 1024);
         let layout = SandboxMemoryLayout::new(cfg, 4096, 4096, None);
         assert!(matches!(layout.unwrap_err(), MemoryRequestTooBig(..)));
     }
diff --git a/src/hyperlight_host/src/mem/mgr.rs b/src/hyperlight_host/src/mem/mgr.rs
index 68f35ff7d..bd5449fff 100644
--- a/src/hyperlight_host/src/mem/mgr.rs
+++ b/src/hyperlight_host/src/mem/mgr.rs
@@ -15,13 +15,12 @@ limitations under the License.
 */
 #[cfg(feature = "nanvix-unstable")]
 use std::mem::offset_of;
+use std::num::NonZeroU16;

-use flatbuffers::FlatBufferBuilder;
-use hyperlight_common::flatbuffer_wrappers::function_call::{
-    FunctionCall, validate_guest_function_call_buffer,
-};
 use hyperlight_common::flatbuffer_wrappers::function_types::FunctionCallResult;
-use hyperlight_common::flatbuffer_wrappers::guest_log_data::GuestLogData;
+use hyperlight_common::mem::PAGE_SIZE_USIZE;
+use hyperlight_common::virtq::msg::{MsgFlags, MsgKind, VirtqMsgHeader};
+use hyperlight_common::virtq::{self, Layout as VirtqLayout};
 use hyperlight_common::vmem::{self, PAGE_TABLE_SIZE};
 #[cfg(all(feature = "crashdump", not(feature = "i686-guest")))]
 use hyperlight_common::vmem::{BasicMapping, MappingKind};
@@ -31,6 +30,7 @@ use super::layout::SandboxMemoryLayout;
 use super::shared_mem::{
     ExclusiveSharedMemory, GuestSharedMemory, HostSharedMemory, ReadonlySharedMemory, SharedMemory,
 };
+use super::virtq_mem::HostMemOps;
 use crate::hypervisor::regs::CommonSpecialRegisters;
 use crate::mem::memory_region::MemoryRegion;
 #[cfg(crashdump)]
@@ -38,6 +38,20 @@ use crate::mem::memory_region::{CrashDumpRegion, MemoryRegionFlags, MemoryRegion
 use crate::sandbox::snapshot::{NextAction, Snapshot};
 use crate::{Result, new_error};

+/// Type alias for the host-side G2H virtqueue consumer.
+pub(crate) type G2hConsumer = virtq::VirtqConsumer<HostMemOps, HostNotifier>;
+/// Type alias for the host-side H2G virtqueue consumer.
+pub(crate) type H2gConsumer = virtq::VirtqConsumer<HostMemOps, HostNotifier>;
+
+/// No-op notifier for host-side consumer.
+/// The host resumes the VM to notify the guest, not via the ring.
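+/// A no-op is sound here: the guest vCPU is not running while the host
+/// drains or fills the rings, so there is no peer to wake; progress is
+/// driven by explicitly re-entering the VM.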
+#[derive(Clone, Copy)]
+pub(crate) struct HostNotifier;
+
+impl virtq::Notifier for HostNotifier {
+    fn notify(&self, _stats: virtq::QueueStats) {}
+}
+
 #[cfg(all(feature = "crashdump", not(feature = "i686-guest")))]
 fn mapping_kind_to_flags(kind: &MappingKind) -> (MemoryRegionFlags, MemoryRegionType) {
     match kind {
@@ -134,7 +148,6 @@ impl ReadonlySharedMemory {
 pub(crate) use unused_hack::SnapshotSharedMemory;

 /// A struct that is responsible for laying out and managing the memory
 /// for a given `Sandbox`.
-#[derive(Clone)]
 pub(crate) struct SandboxMemoryManager {
     /// Shared memory for the Sandbox
     pub(crate) shared_mem: SnapshotSharedMemory,
@@ -154,6 +167,29 @@ pub(crate) struct SandboxMemoryManager {
     /// restored snapshot's own generation number so the guest-visible
     /// counter tracks which snapshot the sandbox is a clone of.
     pub(crate) snapshot_count: u64,
+    /// G2H virtqueue consumer, created after sandbox init.
+    pub(crate) g2h_consumer: Option<G2hConsumer>,
+    /// H2G virtqueue consumer, created after sandbox init.
+    pub(crate) h2g_consumer: Option<H2gConsumer>,
+    /// Saved H2G pool GVA for prefilling after snapshot restore.
+    pub(crate) h2g_pool_gva: Option<u64>,
+}
+
+impl Clone for SandboxMemoryManager {
+    fn clone(&self) -> Self {
+        Self {
+            shared_mem: self.shared_mem.clone(),
+            scratch_mem: self.scratch_mem.clone(),
+            layout: self.layout,
+            entrypoint: self.entrypoint,
+            mapped_rgns: self.mapped_rgns,
+            abort_buffer: self.abort_buffer.clone(),
+            snapshot_count: self.snapshot_count,
+            g2h_consumer: None,
+            h2g_consumer: None,
+            h2g_pool_gva: self.h2g_pool_gva,
+        }
+    }
 }

 /// Buffer for building guest page tables during snapshot creation.
@@ -289,6 +325,9 @@ where
             mapped_rgns: 0,
             abort_buffer: Vec::new(),
             snapshot_count: 0,
+            g2h_consumer: None,
+            h2g_consumer: None,
+            h2g_pool_gva: None,
         }
     }
@@ -359,6 +398,9 @@ impl SandboxMemoryManager {
             mapped_rgns: self.mapped_rgns,
            abort_buffer: self.abort_buffer,
            snapshot_count: self.snapshot_count,
+            g2h_consumer: None,
+            h2g_consumer: None,
+            h2g_pool_gva: None,
         };
         let guest_mgr = SandboxMemoryManager {
             shared_mem: gshm,
@@ -368,8 +410,13 @@
             mapped_rgns: self.mapped_rgns,
             abort_buffer: Vec::new(), // Guest doesn't need abort buffer
             snapshot_count: self.snapshot_count,
+            g2h_consumer: None,
+            h2g_consumer: None,
+            h2g_pool_gva: None,
         };
         host_mgr.update_scratch_bookkeeping()?;
+        host_mgr.init_g2h_consumer()?;
+        host_mgr.init_h2g_consumer()?;
         Ok((host_mgr, guest_mgr))
     }
 }
@@ -436,89 +483,6 @@ impl SandboxMemoryManager {
         Ok(())
     }

-    /// Reads a host function call from memory
-    #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")]
-    pub(crate) fn get_host_function_call(&mut self) -> Result<FunctionCall> {
-        self.scratch_mem.try_pop_buffer_into::<FunctionCall>(
-            self.layout.get_output_data_buffer_scratch_host_offset(),
-            self.layout.sandbox_memory_config.get_output_data_size(),
-        )
-    }
-
-    /// Writes a host function call result to memory
-    #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")]
-    pub(crate) fn write_response_from_host_function_call(
-        &mut self,
-        res: &FunctionCallResult,
-    ) -> Result<()> {
-        let mut builder = FlatBufferBuilder::new();
-        let data = res.encode(&mut builder);
-
-        self.scratch_mem.push_buffer(
-            self.layout.get_input_data_buffer_scratch_host_offset(),
-            self.layout.sandbox_memory_config.get_input_data_size(),
-            data,
-        )
-    }
-
-    /// Writes a guest function call to memory
-    #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")]
-    pub(crate) fn write_guest_function_call(&mut
self, buffer: &[u8]) -> Result<()> { - validate_guest_function_call_buffer(buffer).map_err(|e| { - new_error!( - "Guest function call buffer validation failed: {}", - e.to_string() - ) - })?; - - self.scratch_mem.push_buffer( - self.layout.get_input_data_buffer_scratch_host_offset(), - self.layout.sandbox_memory_config.get_input_data_size(), - buffer, - )?; - Ok(()) - } - - /// Reads a function call result from memory. - /// A function call result can be either an error or a successful return value. - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - pub(crate) fn get_guest_function_call_result(&mut self) -> Result { - self.scratch_mem.try_pop_buffer_into::( - self.layout.get_output_data_buffer_scratch_host_offset(), - self.layout.sandbox_memory_config.get_output_data_size(), - ) - } - - /// Read guest log data from the `SharedMemory` contained within `self` - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - pub(crate) fn read_guest_log_data(&mut self) -> Result { - self.scratch_mem.try_pop_buffer_into::( - self.layout.get_output_data_buffer_scratch_host_offset(), - self.layout.sandbox_memory_config.get_output_data_size(), - ) - } - - pub(crate) fn clear_io_buffers(&mut self) { - // Clear the output data buffer - loop { - let Ok(_) = self.scratch_mem.try_pop_buffer_into::>( - self.layout.get_output_data_buffer_scratch_host_offset(), - self.layout.sandbox_memory_config.get_output_data_size(), - ) else { - break; - }; - } - // Clear the input data buffer - loop { - let Ok(_) = self.scratch_mem.try_pop_buffer_into::>( - self.layout.get_input_data_buffer_scratch_host_offset(), - self.layout.sandbox_memory_config.get_input_data_size(), - ) else { - break; - }; - } - } - /// This function restores a memory snapshot from a given snapshot. pub(crate) fn restore_snapshot( &mut self, @@ -566,6 +530,15 @@ impl SandboxMemoryManager { self.snapshot_count = snapshot.snapshot_generation(); self.update_scratch_bookkeeping()?; + + // Place the H2G pool at first_free so the bump allocator starts right after it. + // Guest reads this GVA from scratch-top during reset(). + let h2g_pool_gva = self.place_h2g_pool_at_first_free()?; + + self.init_g2h_consumer()?; + self.init_h2g_consumer()?; + self.restore_h2g_prefill(h2g_pool_gva)?; + Ok((gsnapshot, gscratch)) } @@ -601,15 +574,31 @@ impl SandboxMemoryManager { self.snapshot_count, )?; - // Initialise the guest input and output data buffers in - // scratch memory. TODO: remove the need for this. - self.scratch_mem.write::( - self.layout.get_input_data_buffer_scratch_host_offset(), - SandboxMemoryLayout::STACK_POINTER_SIZE_BYTES, + // Write virtqueue metadata to scratch-top so the guest can + // discover ring locations without reading the PEB. 
+        self.update_scratch_bookkeeping_item(
+            SCRATCH_TOP_G2H_RING_GVA_OFFSET,
+            self.layout.get_g2h_ring_gva(),
+        )?;
+        self.update_scratch_bookkeeping_item(
+            SCRATCH_TOP_H2G_RING_GVA_OFFSET,
+            self.layout.get_h2g_ring_gva(),
+        )?;
+        self.scratch_mem.write::<u16>(
+            scratch_size - SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET as usize,
+            self.layout.sandbox_memory_config.get_g2h_queue_depth() as u16,
+        )?;
+        self.scratch_mem.write::<u16>(
+            scratch_size - SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET as usize,
+            self.layout.sandbox_memory_config.get_h2g_queue_depth() as u16,
         )?;
-        self.scratch_mem.write::<u64>(
-            self.layout.get_output_data_buffer_scratch_host_offset(),
-            SandboxMemoryLayout::STACK_POINTER_SIZE_BYTES,
+        self.scratch_mem.write::<u16>(
+            scratch_size - SCRATCH_TOP_G2H_POOL_PAGES_OFFSET as usize,
+            self.layout.sandbox_memory_config.get_g2h_pool_pages() as u16,
+        )?;
+        self.scratch_mem.write::<u16>(
+            scratch_size - SCRATCH_TOP_H2G_POOL_PAGES_OFFSET as usize,
+            self.layout.sandbox_memory_config.get_h2g_pool_pages() as u16,
         )?;

         // Copy page tables from `shared_mem` into scratch. PT bytes
@@ -856,6 +845,264 @@ impl SandboxMemoryManager {
             })
         })??
     }
+
+    /// Compute the G2H virtqueue Layout from scratch region addresses.
+    pub(crate) fn g2h_virtq_layout(&self) -> Result<VirtqLayout> {
+        let base = self.layout.get_g2h_ring_gva();
+        let depth = self.layout.sandbox_memory_config.get_g2h_queue_depth() as u16;
+
+        let nz = NonZeroU16::new(depth).ok_or_else(|| new_error!("G2H queue depth is zero"))?;
+
+        unsafe { VirtqLayout::from_base(base, nz) }
+            .map_err(|e| new_error!("Invalid G2H virtq layout: {:?}", e))
+    }
+
+    /// Compute the H2G virtqueue Layout from scratch region addresses.
+    pub(crate) fn h2g_virtq_layout(&self) -> Result<VirtqLayout> {
+        let base = self.layout.get_h2g_ring_gva();
+        let depth = self.layout.sandbox_memory_config.get_h2g_queue_depth() as u16;
+
+        let nz = NonZeroU16::new(depth).ok_or_else(|| new_error!("H2G queue depth is zero"))?;
+
+        unsafe { VirtqLayout::from_base(base, nz) }
+            .map_err(|e| new_error!("Invalid H2G virtq layout: {:?}", e))
+    }
+
+    /// Create a [`HostMemOps`] instance backed by this manager's
+    /// scratch shared memory.
+    pub(crate) fn host_mem_ops(&self) -> HostMemOps {
+        let scratch_base_gva =
+            hyperlight_common::layout::scratch_base_gva(self.scratch_mem.mem_size());
+        HostMemOps::new(&self.scratch_mem, scratch_base_gva)
+    }
+
+    /// Total G2H buffer pool size in bytes.
+    pub(crate) fn g2h_pool_size(&self) -> usize {
+        self.layout.sandbox_memory_config.get_g2h_pool_pages() * PAGE_SIZE_USIZE
+    }
+
+    pub(crate) fn h2g_pool_size(&self) -> usize {
+        self.layout.sandbox_memory_config.get_h2g_pool_pages() * PAGE_SIZE_USIZE
+    }
+
+    /// H2G slot size in bytes. Each prefilled writable descriptor has this capacity.
+    pub(crate) fn h2g_slot_size(&self) -> usize {
+        PAGE_SIZE_USIZE
+    }
+
+    /// Initialize the G2H virtqueue consumer.
+    /// Must be called after scratch bookkeeping is written.
+    pub(crate) fn init_g2h_consumer(&mut self) -> Result<()> {
+        match &mut self.g2h_consumer {
+            Some(consumer) => {
+                consumer.reset();
+            }
+            None => {
+                let layout = self.g2h_virtq_layout()?;
+                let mem_ops = self.host_mem_ops();
+                let consumer = virtq::VirtqConsumer::new(layout, mem_ops, HostNotifier);
+                self.g2h_consumer = Some(consumer);
+            }
+        }
+        Ok(())
+    }
+
+    /// Initialize the H2G virtqueue consumer.
+    ///
+    /// Must be called after scratch bookkeeping is written. Avail suppression is set to Disable
+    /// so guest prefill/refill operations do not trigger VM exits.
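+    ///
+    /// Call-order sketch (as wired up in `build` and `restore_snapshot`
+    /// in this file):
+    ///
+    /// ```ignore
+    /// mgr.update_scratch_bookkeeping()?;
+    /// mgr.init_g2h_consumer()?;
+    /// mgr.init_h2g_consumer()?;
+    /// ```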
+ pub(crate) fn init_h2g_consumer(&mut self) -> Result<()> { + match &mut self.h2g_consumer { + Some(consumer) => { + consumer.reset(); + consumer + .set_avail_suppression(virtq::SuppressionKind::Disable) + .map_err(|e| new_error!("H2G avail suppression: {:?}", e))?; + } + None => { + let layout = self.h2g_virtq_layout()?; + let mem_ops = self.host_mem_ops(); + let mut consumer = virtq::VirtqConsumer::new(layout, mem_ops, HostNotifier); + consumer + .set_avail_suppression(virtq::SuppressionKind::Disable) + .map_err(|e| new_error!("H2G avail suppression: {:?}", e))?; + self.h2g_consumer = Some(consumer); + } + } + Ok(()) + } + + /// Place the H2G pool at `first_free` during snapshot restore. + /// + /// Writes the pool GVA to scratch-top and advances the bump + /// allocator past the pool so COW page-fault resolution cannot + /// alias pool memory. Returns the computed pool GVA for use by + /// [`restore_h2g_prefill`]. + fn place_h2g_pool_at_first_free(&mut self) -> Result { + use hyperlight_common::layout::*; + + let scratch_size = self.scratch_mem.mem_size(); + let first_free = self.layout.get_first_free_scratch_gpa(); + let base_gpa = scratch_base_gpa(scratch_size); + let base_gva = scratch_base_gva(scratch_size); + let h2g_pool_gva = base_gva + (first_free - base_gpa); + let h2g_pages = self.layout.sandbox_memory_config.get_h2g_pool_pages() as u64; + + self.update_scratch_bookkeeping_item(SCRATCH_TOP_H2G_POOL_GVA_OFFSET, h2g_pool_gva)?; + let allocator = first_free + h2g_pages * PAGE_SIZE_USIZE as u64; + self.update_scratch_bookkeeping_item(SCRATCH_TOP_ALLOCATOR_OFFSET, allocator)?; + + Ok(h2g_pool_gva) + } + + /// Prefill the H2G ring with writable descriptors after snapshot restore. + /// + /// Uses a temporary `RingProducer` to write descriptors into the H2G ring + /// so the host consumer can poll them. The guest's `restore_from_ring` + /// will later reconstruct its inflight state from these descriptors. + fn restore_h2g_prefill(&mut self, pool_gva: u64) -> Result<()> { + let layout = self.h2g_virtq_layout()?; + let mem_ops = self.host_mem_ops(); + let h2g_depth = self.layout.sandbox_memory_config.get_h2g_queue_depth(); + + let slot_size = self.h2g_slot_size(); + let pool_size = self.layout.sandbox_memory_config.get_h2g_pool_pages() * PAGE_SIZE_USIZE; + let slot_count = pool_size / slot_size; + + let mut producer = virtq::RingProducer::new(layout, mem_ops); + let prefill_count = core::cmp::min(slot_count, h2g_depth); + + // Write descriptors in forward order. The guest calls + // restore_from_ring which reconstructs used-descriptor addresses + // as base + position * slot_size, so the iteration order must + // match this formula. + for i in 0..prefill_count { + let addr = pool_gva + (i * slot_size) as u64; + producer + .submit_one(addr, slot_size as u32, true) + .map_err(|e| new_error!("H2G prefill submit: {:?}", e))?; + } + + Ok(()) + } + + /// Write a guest function call into the H2G virtqueue. + /// + /// Large payloads that exceed a single slot are split across multiple descriptors. + pub(crate) fn write_guest_function_call_virtq(&mut self, buffer: &[u8]) -> Result<()> { + let h2g_pool_size = self.h2g_pool_size(); + + let consumer = self + .h2g_consumer + .as_mut() + .ok_or_else(|| new_error!("H2G consumer not initialized"))?; + + let mut offset = 0usize; + + loop { + let remaining = buffer.len() - offset; + + let (entry, completion) = consumer + .poll(h2g_pool_size) + .map_err(|e| new_error!("H2G poll: {:?}", e))? 
+                .ok_or_else(|| new_error!("H2G: no prefilled descriptor available"))?;
+
+            drop(entry);
+
+            let virtq::SendCompletion::Writable(mut wc) = completion else {
+                return Err(new_error!(
+                    "H2G: expected writable completion (ring corruption)"
+                ));
+            };
+
+            let data_cap = wc.capacity() - VirtqMsgHeader::SIZE;
+            let chunk_len = remaining.min(data_cap);
+            let has_more = offset + chunk_len < buffer.len();
+
+            let flags = if has_more {
+                MsgFlags::MORE
+            } else {
+                MsgFlags::empty()
+            };
+
+            let hdr = VirtqMsgHeader::with_flags(MsgKind::Request, flags, 0, chunk_len as u32);
+
+            wc.write_all(bytemuck::bytes_of(&hdr))
+                .map_err(|e| new_error!("H2G write header: {:?}", e))?;
+            wc.write_all(&buffer[offset..offset + chunk_len])
+                .map_err(|e| new_error!("H2G write payload: {:?}", e))?;
+
+            consumer
+                .complete(wc.into())
+                .map_err(|e| new_error!("H2G complete: {:?}", e))?;
+
+            offset += chunk_len;
+
+            if !has_more {
+                break;
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Read the H2G result from G2H after the guest halts.
+    ///
+    /// The guest submits the H2G function call result as a `MsgKind::Response`
+    /// entry on the G2H queue; any `MsgKind::Log` entries drained along the
+    /// way are emitted inline.
+    pub(crate) fn read_h2g_result_from_g2h(&mut self) -> Result<FunctionCallResult> {
+        let g2h_pool_size = self.g2h_pool_size();
+
+        let consumer = self
+            .g2h_consumer
+            .as_mut()
+            .ok_or_else(|| new_error!("G2H consumer not initialized"))?;
+
+        // Drain the G2H queue, processing Log entries inline, until we
+        // find the Response that carries the H2G function call result.
+        loop {
+            let maybe_next = consumer
+                .poll(g2h_pool_size)
+                .map_err(|e| new_error!("G2H poll for H2G result: {:?}", e))?;
+
+            let Some((entry, completion)) = maybe_next else {
+                return Err(new_error!("G2H: no H2G result entry after halt"));
+            };
+
+            let entry_data = entry.data();
+            if entry_data.len() < VirtqMsgHeader::SIZE {
+                return Err(new_error!("G2H: result entry too short"));
+            }
+
+            let hdr_size = VirtqMsgHeader::SIZE;
+            let hdr: &VirtqMsgHeader = bytemuck::from_bytes(&entry_data[..hdr_size]);
+            let available = entry_data.len() - hdr_size;
+            let payload_len = (hdr.payload_len as usize).min(available);
+            let payload = &entry_data[hdr_size..hdr_size + payload_len];

+            match hdr.msg_kind() {
+                Ok(MsgKind::Response) => {
+                    let fcr = FunctionCallResult::try_from(payload)
+                        .map_err(|e| new_error!("G2H: malformed FunctionCallResult: {}", e))?;
+                    consumer
+                        .complete(completion)
+                        .map_err(|e| new_error!("G2H complete: {:?}", e))?;
+                    return Ok(fcr);
+                }
+                Ok(MsgKind::Log) => {
+                    crate::sandbox::outb::emit_guest_log(payload);
+                    consumer
+                        .complete(completion)
+                        .map_err(|e| new_error!("G2H complete log: {:?}", e))?;
+                }
+                Ok(other) => {
+                    return Err(new_error!("G2H: expected Response or Log, got {:?}", other));
+                }
+                Err(unknown) => {
+                    return Err(new_error!("G2H: unknown message kind: 0x{:02x}", unknown));
+                }
+            }
+        }
+    }
 }

 #[cfg(test)]
@@ -966,4 +1213,89 @@ mod tests {
         verify_page_tables(name, config);
     }
 }
+
+    /// Verify that the H2G pool placed at `first_free` during restore
+    /// does not overlap with the bump allocator range or the
+    /// scratch-top metadata region.
+    ///
+    /// This guards against the COW-pool GPA overlap bug: if the bump
+    /// allocator could return GPAs inside the pool region, a COW
+    /// page-fault would overwrite pool buffer data with stale shared
+    /// memory content, corrupting virtqueue communication.
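+    ///
+    /// Expected scratch layout after restore (illustrative):
+    ///
+    /// ```text
+    /// scratch_base .. first_free       rings, page tables, etc.
+    /// first_free .. pool_end           H2G buffer pool
+    /// pool_end .. metadata_start       bump-allocator (COW) range
+    /// metadata_start .. scratch_end    exception stack + metadata
+    /// ```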
+ fn verify_pool_allocator_no_collision(name: &str, config: SandboxConfiguration) { + let path = simple_guest_as_string().expect("failed to get simple guest path"); + let snapshot = Snapshot::from_env(GuestBinary::FilePath(path), config) + .unwrap_or_else(|e| panic!("{name}: failed to create snapshot: {e}")); + + let layout = snapshot.layout(); + let scratch_size = layout.get_scratch_size(); + let first_free = layout.get_first_free_scratch_gpa(); + let h2g_pages = layout.sandbox_memory_config.get_h2g_pool_pages(); + let scratch_base = hyperlight_common::layout::scratch_base_gpa(scratch_size); + + let pool_start = first_free; + let pool_end = first_free + (h2g_pages * PAGE_TABLE_SIZE) as u64; + let allocator_start = pool_end; + + // The metadata region lives at the very top of scratch. + // SCRATCH_TOP_EXN_STACK_OFFSET (0x50) is the highest offset. + // Two pages are reserved at the top for exception stack and metadata. + let scratch_end = scratch_base + scratch_size as u64; + let metadata_start = scratch_end - 2 * PAGE_TABLE_SIZE as u64; + + assert!( + pool_start >= scratch_base, + "{name}: pool starts before scratch (pool=0x{pool_start:x}, scratch=0x{scratch_base:x})" + ); + + assert!( + pool_end <= metadata_start, + "{name}: pool overlaps metadata (pool_end=0x{pool_end:x}, metadata=0x{metadata_start:x})" + ); + + assert_eq!( + allocator_start, pool_end, + "{name}: allocator should start immediately after pool" + ); + + assert!( + allocator_start < metadata_start, + "{name}: no room for COW allocations (allocator=0x{allocator_start:x}, metadata=0x{metadata_start:x})" + ); + } + + #[test] + fn test_pool_allocator_no_collision() { + let test_cases: Vec<(&str, SandboxConfiguration)> = vec![ + ("default", SandboxConfiguration::default()), + ("large pools", { + let mut cfg = SandboxConfiguration::default(); + cfg.set_h2g_pool_pages(16); + cfg.set_g2h_pool_pages(16); + cfg + }), + ("minimal scratch", { + let mut cfg = SandboxConfiguration::default(); + cfg.set_scratch_size(0x20000); + cfg + }), + ("large scratch", { + let mut cfg = SandboxConfiguration::default(); + cfg.set_scratch_size(0x100000); + cfg + }), + ("large heap + large pools", { + let mut cfg = SandboxConfiguration::default(); + cfg.set_heap_size(LARGE_HEAP_SIZE); + cfg.set_scratch_size(0x100000); + cfg.set_h2g_pool_pages(32); + cfg.set_g2h_pool_pages(32); + cfg + }), + ]; + + for (name, config) in test_cases { + verify_pool_allocator_no_collision(name, config); + } + } } diff --git a/src/hyperlight_host/src/mem/mod.rs b/src/hyperlight_host/src/mem/mod.rs index 64f5db2fe..4882bc75c 100644 --- a/src/hyperlight_host/src/mem/mod.rs +++ b/src/hyperlight_host/src/mem/mod.rs @@ -38,3 +38,5 @@ pub mod shared_mem; /// Utilities for writing shared memory tests #[cfg(all(test, not(miri)))] // uses proptest which isn't miri-compatible pub(crate) mod shared_mem_tests; +/// Host-side [`hyperlight_common::virtq::MemOps`] for virtqueue access. +pub(crate) mod virtq_mem; diff --git a/src/hyperlight_host/src/mem/shared_mem.rs b/src/hyperlight_host/src/mem/shared_mem.rs index 5f975f605..f8766c347 100644 --- a/src/hyperlight_host/src/mem/shared_mem.rs +++ b/src/hyperlight_host/src/mem/shared_mem.rs @@ -14,7 +14,6 @@ See the License for the specific language governing permissions and limitations under the License. 
 */

-use std::any::type_name;
 use std::ffi::c_void;
 use std::io::Error;
 use std::mem::{align_of, size_of};
@@ -878,57 +877,25 @@ impl SharedMemory for GuestSharedMemory {
     }
 }

-/// An unsafe marker trait for types for which all bit patterns are valid.
-/// This is required in order for it to be safe to read a value of a particular
-/// type out of the sandbox from the HostSharedMemory.
-///
-/// # Safety
-/// This must only be implemented for types for which all bit patterns
-/// are valid. It requires that any (non-undef/poison) value of the
-/// correct size can be transmuted to the type.
-pub unsafe trait AllValid {}
-unsafe impl AllValid for u8 {}
-unsafe impl AllValid for u16 {}
-unsafe impl AllValid for u32 {}
-unsafe impl AllValid for u64 {}
-unsafe impl AllValid for i8 {}
-unsafe impl AllValid for i16 {}
-unsafe impl AllValid for i32 {}
-unsafe impl AllValid for i64 {}
-unsafe impl AllValid for [u8; 16] {}
-
 impl HostSharedMemory {
-    /// Read a value of type T, whose representation is the same
-    /// between the sandbox and the host, and which has no invalid bit
-    /// patterns
-    pub fn read<T: AllValid>(&self, offset: usize) -> Result<T> {
+    /// Read a value of type T from the sandbox at the given offset.
+    ///
+    /// T must implement [`bytemuck::Pod`] which guarantees all bit
+    /// patterns are valid and there is no padding.
+    pub fn read<T: bytemuck::Pod>(&self, offset: usize) -> Result<T> {
         bounds_check!(offset, std::mem::size_of::<T>(), self.mem_size());
-        unsafe {
-            let mut ret: core::mem::MaybeUninit<T> = core::mem::MaybeUninit::uninit();
-            {
-                let slice: &mut [u8] = core::slice::from_raw_parts_mut(
-                    ret.as_mut_ptr() as *mut u8,
-                    std::mem::size_of::<T>(),
-                );
-                self.copy_to_slice(slice, offset)?;
-            }
-            Ok(ret.assume_init())
-        }
+        let mut val = T::zeroed();
+        self.copy_to_slice(bytemuck::bytes_of_mut(&mut val), offset)?;
+        Ok(val)
     }

-    /// Write a value of type T, whose representation is the same
-    /// between the sandbox and the host, and which has no invalid bit
-    /// patterns
-    pub fn write<T: AllValid>(&self, offset: usize, data: T) -> Result<()> {
+    /// Write a value of type T into the sandbox at the given offset.
+    ///
+    /// T must implement [`bytemuck::Pod`] which guarantees all bit
+    /// patterns are valid and there is no padding.
+    pub fn write<T: bytemuck::Pod>(&self, offset: usize, data: T) -> Result<()> {
         bounds_check!(offset, std::mem::size_of::<T>(), self.mem_size());
-        unsafe {
-            let slice: &[u8] = core::slice::from_raw_parts(
-                core::ptr::addr_of!(data) as *const u8,
-                std::mem::size_of::<T>(),
-            );
-            self.copy_from_slice(slice, offset)?;
-        }
-        Ok(())
+        self.copy_from_slice(bytemuck::bytes_of(&data), offset)
     }

     /// Copy the contents of the slice into the sandbox at the
@@ -1080,145 +1047,6 @@ impl HostSharedMemory {
         drop(guard);
         Ok(())
     }
-
-    /// Pushes the given data onto shared memory to the buffer at the given offset.
-    /// NOTE! buffer_start_offset must point to the beginning of the buffer
-    #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")]
-    pub fn push_buffer(
-        &mut self,
-        buffer_start_offset: usize,
-        buffer_size: usize,
-        data: &[u8],
-    ) -> Result<()> {
-        let stack_pointer_rel = self.read::<u64>(buffer_start_offset)? as usize;
-        let buffer_size_u64: u64 = buffer_size.try_into()?;
-
-        if stack_pointer_rel > buffer_size || stack_pointer_rel < 8 {
-            return Err(new_error!(
-                "Unable to push data to buffer: Stack pointer is out of bounds.
Stack pointer: {}, Buffer size: {}", - stack_pointer_rel, - buffer_size_u64 - )); - } - - let size_required = data.len() + 8; - let size_available = buffer_size - stack_pointer_rel; - - if size_required > size_available { - return Err(new_error!( - "Not enough space in buffer to push data. Required: {}, Available: {}", - size_required, - size_available - )); - } - - // get absolute - let stack_pointer_abs = stack_pointer_rel + buffer_start_offset; - - // write the actual data to the top of stack - self.copy_from_slice(data, stack_pointer_abs)?; - - // write the offset to the newly written data, to the top of stack. - // this is used when popping the stack, to know how far back to jump - self.write::(stack_pointer_abs + data.len(), stack_pointer_rel as u64)?; - - // update stack pointer to point to the next free address - self.write::( - buffer_start_offset, - (stack_pointer_rel + data.len() + 8) as u64, - )?; - Ok(()) - } - - /// Pops the given given buffer into a `T` and returns it. - /// NOTE! the data must be a size-prefixed flatbuffer, and - /// buffer_start_offset must point to the beginning of the buffer - pub fn try_pop_buffer_into( - &mut self, - buffer_start_offset: usize, - buffer_size: usize, - ) -> Result - where - T: for<'b> TryFrom<&'b [u8]>, - { - // get the stackpointer - let stack_pointer_rel = self.read::(buffer_start_offset)? as usize; - - if stack_pointer_rel > buffer_size || stack_pointer_rel < 16 { - return Err(new_error!( - "Unable to pop data from buffer: Stack pointer is out of bounds. Stack pointer: {}, Buffer size: {}", - stack_pointer_rel, - buffer_size - )); - } - - // make it absolute - let last_element_offset_abs = stack_pointer_rel + buffer_start_offset; - - // go back 8 bytes to get offset to element on top of stack - let last_element_offset_rel: usize = - self.read::(last_element_offset_abs - 8)? as usize; - - // Validate element offset (guest-writable): must be in [8, stack_pointer_rel - 16] - // to leave room for the 8-byte back-pointer plus at least 8 bytes of element data - // (the minimum for a size-prefixed flatbuffer: 4-byte prefix + 4-byte root offset). - if last_element_offset_rel > stack_pointer_rel.saturating_sub(16) - || last_element_offset_rel < 8 - { - return Err(new_error!( - "Corrupt buffer back-pointer: element offset {} is outside valid range [8, {}].", - last_element_offset_rel, - stack_pointer_rel.saturating_sub(16), - )); - } - - // make it absolute - let last_element_offset_abs = last_element_offset_rel + buffer_start_offset; - - // Max bytes the element can span (excluding the 8-byte back-pointer). - let max_element_size = stack_pointer_rel - last_element_offset_rel - 8; - - // Get the size of the flatbuffer buffer from memory - let fb_buffer_size = { - let raw_prefix = self.read::(last_element_offset_abs)?; - // flatbuffer byte arrays are prefixed by 4 bytes indicating - // the remaining size; add 4 for the prefix itself. 
- let total = raw_prefix.checked_add(4).ok_or_else(|| { - new_error!( - "Corrupt buffer size prefix: value {} overflows when adding 4-byte header.", - raw_prefix - ) - })?; - usize::try_from(total) - }?; - - if fb_buffer_size > max_element_size { - return Err(new_error!( - "Corrupt buffer size prefix: flatbuffer claims {} bytes but the element slot is only {} bytes.", - fb_buffer_size, - max_element_size - )); - } - - let mut result_buffer = vec![0; fb_buffer_size]; - - self.copy_to_slice(&mut result_buffer, last_element_offset_abs)?; - let to_return = T::try_from(result_buffer.as_slice()).map_err(|_e| { - new_error!( - "pop_buffer_into: failed to convert buffer to {}", - type_name::() - ) - })?; - - // update the stack pointer to point to the element we just popped off since that is now free - self.write::(buffer_start_offset, last_element_offset_rel as u64)?; - - // zero out the memory we just popped off - let num_bytes_to_zero = stack_pointer_rel - last_element_offset_rel; - self.fill(0, last_element_offset_abs, num_bytes_to_zero)?; - - Ok(to_return) - } } impl SharedMemory for HostSharedMemory { @@ -1726,192 +1554,6 @@ mod tests { } } - /// Bounds checking for `try_pop_buffer_into` against corrupt guest data. - mod try_pop_buffer_bounds { - use super::*; - - #[derive(Debug, PartialEq)] - struct RawBytes(Vec); - - impl TryFrom<&[u8]> for RawBytes { - type Error = String; - fn try_from(value: &[u8]) -> std::result::Result { - Ok(RawBytes(value.to_vec())) - } - } - - /// Create a buffer with stack pointer initialized to 8 (empty). - fn make_buffer(mem_size: usize) -> super::super::HostSharedMemory { - let eshm = ExclusiveSharedMemory::new(mem_size).unwrap(); - let (hshm, _) = eshm.build(); - hshm.write::(0, 8u64).unwrap(); - hshm - } - - #[test] - fn normal_push_pop_roundtrip() { - let mem_size = 4096; - let mut hshm = make_buffer(mem_size); - - // Size-prefixed flatbuffer-like payload: [size: u32 LE][payload] - let payload = b"hello"; - let mut data = Vec::new(); - data.extend_from_slice(&(payload.len() as u32).to_le_bytes()); - data.extend_from_slice(payload); - - hshm.push_buffer(0, mem_size, &data).unwrap(); - let result: RawBytes = hshm.try_pop_buffer_into(0, mem_size).unwrap(); - assert_eq!(result.0, data); - } - - #[test] - fn malicious_flatbuffer_size_prefix() { - let mem_size = 4096; - let mut hshm = make_buffer(mem_size); - - let payload = b"small"; - let mut data = Vec::new(); - data.extend_from_slice(&(payload.len() as u32).to_le_bytes()); - data.extend_from_slice(payload); - hshm.push_buffer(0, mem_size, &data).unwrap(); - - // Corrupt size prefix at element start (offset 8) to near u32::MAX. - hshm.write::(8, 0xFFFF_FFFBu32).unwrap(); // +4 = 0xFFFF_FFFF - - let result: Result = hshm.try_pop_buffer_into(0, mem_size); - let err_msg = format!("{}", result.unwrap_err()); - assert!( - err_msg.contains("Corrupt buffer size prefix: flatbuffer claims 4294967295 bytes but the element slot is only 9 bytes"), - "Unexpected error message: {}", - err_msg - ); - } - - #[test] - fn malicious_element_offset_too_small() { - let mem_size = 4096; - let mut hshm = make_buffer(mem_size); - - let payload = b"test"; - let mut data = Vec::new(); - data.extend_from_slice(&(payload.len() as u32).to_le_bytes()); - data.extend_from_slice(payload); - hshm.push_buffer(0, mem_size, &data).unwrap(); - - // Corrupt back-pointer (offset 16) to 0 (before valid range). 
- hshm.write::(16, 0u64).unwrap(); - - let result: Result = hshm.try_pop_buffer_into(0, mem_size); - let err_msg = format!("{}", result.unwrap_err()); - assert!( - err_msg.contains( - "Corrupt buffer back-pointer: element offset 0 is outside valid range [8, 8]" - ), - "Unexpected error message: {}", - err_msg - ); - } - - #[test] - fn malicious_element_offset_past_stack_pointer() { - let mem_size = 4096; - let mut hshm = make_buffer(mem_size); - - let payload = b"test"; - let mut data = Vec::new(); - data.extend_from_slice(&(payload.len() as u32).to_le_bytes()); - data.extend_from_slice(payload); - hshm.push_buffer(0, mem_size, &data).unwrap(); - - // Corrupt back-pointer (offset 16) to 9999 (past stack pointer 24). - hshm.write::(16, 9999u64).unwrap(); - - let result: Result = hshm.try_pop_buffer_into(0, mem_size); - let err_msg = format!("{}", result.unwrap_err()); - assert!( - err_msg.contains( - "Corrupt buffer back-pointer: element offset 9999 is outside valid range [8, 8]" - ), - "Unexpected error message: {}", - err_msg - ); - } - - #[test] - fn malicious_flatbuffer_size_off_by_one() { - let mem_size = 4096; - let mut hshm = make_buffer(mem_size); - - let payload = b"abcd"; - let mut data = Vec::new(); - data.extend_from_slice(&(payload.len() as u32).to_le_bytes()); - data.extend_from_slice(payload); - hshm.push_buffer(0, mem_size, &data).unwrap(); - - // Corrupt size prefix: claim 5 bytes (total 9), exceeding the 8-byte slot. - hshm.write::(8, 5u32).unwrap(); // fb_buffer_size = 5 + 4 = 9 - - let result: Result = hshm.try_pop_buffer_into(0, mem_size); - let err_msg = format!("{}", result.unwrap_err()); - assert!( - err_msg.contains("Corrupt buffer size prefix: flatbuffer claims 9 bytes but the element slot is only 8 bytes"), - "Unexpected error message: {}", - err_msg - ); - } - - /// Back-pointer just below stack_pointer causes underflow in - /// `stack_pointer_rel - last_element_offset_rel - 8`. - #[test] - fn back_pointer_near_stack_pointer_underflow() { - let mem_size = 4096; - let mut hshm = make_buffer(mem_size); - - let payload = b"test"; - let mut data = Vec::new(); - data.extend_from_slice(&(payload.len() as u32).to_le_bytes()); - data.extend_from_slice(payload); - hshm.push_buffer(0, mem_size, &data).unwrap(); - - // stack_pointer_rel = 24. Set back-pointer to 23 (> 24 - 16 = 8, so rejected). - hshm.write::(16, 23u64).unwrap(); - - let result: Result = hshm.try_pop_buffer_into(0, mem_size); - let err_msg = format!("{}", result.unwrap_err()); - assert!( - err_msg.contains( - "Corrupt buffer back-pointer: element offset 23 is outside valid range [8, 8]" - ), - "Unexpected error message: {}", - err_msg - ); - } - - /// Size prefix of 0xFFFF_FFFD causes u32 overflow: 0xFFFF_FFFD + 4 wraps. - #[test] - fn size_prefix_u32_overflow() { - let mem_size = 4096; - let mut hshm = make_buffer(mem_size); - - let payload = b"test"; - let mut data = Vec::new(); - data.extend_from_slice(&(payload.len() as u32).to_le_bytes()); - data.extend_from_slice(payload); - hshm.push_buffer(0, mem_size, &data).unwrap(); - - // Write 0xFFFF_FFFD as size prefix: checked_add(4) returns None. 
- hshm.write::(8, 0xFFFF_FFFDu32).unwrap(); - - let result: Result = hshm.try_pop_buffer_into(0, mem_size); - let err_msg = format!("{}", result.unwrap_err()); - assert!( - err_msg.contains("Corrupt buffer size prefix: value 4294967293 overflows when adding 4-byte header"), - "Unexpected error message: {}", - err_msg - ); - } - } - #[cfg(target_os = "linux")] mod guard_page_crash_test { use crate::mem::shared_mem::{ExclusiveSharedMemory, SharedMemory}; diff --git a/src/hyperlight_host/src/mem/virtq_mem.rs b/src/hyperlight_host/src/mem/virtq_mem.rs new file mode 100644 index 000000000..8f01c523b --- /dev/null +++ b/src/hyperlight_host/src/mem/virtq_mem.rs @@ -0,0 +1,126 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Host-side [`MemOps`] implementation for virtqueue access. +//! +//! Translates guest virtual addresses used in virtqueue descriptors +//! to offsets into the scratch [`HostSharedMemory`], reusing its +//! volatile access and locking patterns. + +use core::sync::atomic::{AtomicU16, Ordering}; + +use hyperlight_common::virtq::MemOps; + +use super::shared_mem::{HostSharedMemory, SharedMemory}; + +/// Error type for host memory operations. +#[derive(Debug, thiserror::Error)] +pub enum HostMemError { + #[error("address {addr:#x} out of bounds scratch_size={scratch_size}")] + OutOfBounds { addr: u64, scratch_size: usize }, + #[error("shared memory error: {0}")] + SharedMem(String), + #[error("as_slice/as_mut_slice not supported on host")] + DirectSliceNotSupported, +} + +/// Host-side memory accessor for virtqueue operations. +/// +/// Owns a clone of the scratch [`HostSharedMemory`] and translates +/// guest virtual addresses (in the scratch region) to offsets for the +/// existing volatile read/write methods. +#[derive(Clone)] +pub(crate) struct HostMemOps { + /// Cloned handle to the scratch shared memory + scratch: HostSharedMemory, + /// The guest virtual address that corresponds to scratch offset 0. + scratch_base_gva: u64, +} + +impl HostMemOps { + /// Create a new `HostMemOps` backed by shared memory. + pub(crate) fn new(scratch: &HostSharedMemory, scratch_base_gva: u64) -> Self { + Self { + scratch: scratch.clone(), + scratch_base_gva, + } + } + + /// Translate a guest virtual address to a scratch offset. + fn to_offset(&self, addr: u64) -> Result { + addr.checked_sub(self.scratch_base_gva) + .map(|o| o as usize) + .ok_or(HostMemError::OutOfBounds { + addr, + scratch_size: self.scratch.mem_size(), + }) + } + + /// Get a raw pointer into scratch memory at the given guest address. 
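Both halves of the bounds check in `raw_ptr` below are plain checked arithmetic: the base subtraction must not underflow, and the end of the accessed range must stay inside scratch. A condensed, self-contained sketch (the function name is illustrative, not the crate's API):

```rust
/// Translate a guest virtual address to a scratch offset, then bounds-check
/// the full [offset, offset + len) range. Returns None on underflow,
/// overflow, or out-of-range, mirroring the OutOfBounds error below.
fn translate_and_check(
    addr: u64,
    scratch_base_gva: u64,
    len: usize,
    scratch_size: usize,
) -> Option<usize> {
    let offset = addr.checked_sub(scratch_base_gva)? as usize;
    let end = offset.checked_add(len)?;
    (end <= scratch_size).then_some(offset)
}
```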
+ fn raw_ptr(&self, addr: u64, len: usize) -> Result<*mut u8, HostMemError> { + let offset = self.to_offset(addr)?; + let scratch_size = self.scratch.mem_size(); + + if offset.checked_add(len).is_none_or(|end| end > scratch_size) { + return Err(HostMemError::OutOfBounds { addr, scratch_size }); + } + + Ok(self.scratch.base_ptr().wrapping_add(offset)) + } +} + +// SAFETY: HostMemOps bounds-checks guest addresses against scratch memory before +// accessing them and uses atomic operations for acquire/release accesses. +unsafe impl MemOps for HostMemOps { + type Error = HostMemError; + + fn read(&self, addr: u64, dst: &mut [u8]) -> Result<(), Self::Error> { + let offset = self.to_offset(addr)?; + self.scratch + .copy_to_slice(dst, offset) + .map_err(|e| HostMemError::SharedMem(e.to_string()))?; + Ok(()) + } + + fn write(&self, addr: u64, src: &[u8]) -> Result<(), Self::Error> { + let offset = self.to_offset(addr)?; + self.scratch + .copy_from_slice(src, offset) + .map_err(|e| HostMemError::SharedMem(e.to_string()))?; + Ok(()) + } + + fn load_acquire(&self, addr: u64) -> Result { + let ptr = self.raw_ptr(addr, core::mem::size_of::())?; + let atomic = unsafe { &*(ptr as *const AtomicU16) }; + Ok(atomic.load(Ordering::Acquire)) + } + + fn store_release(&self, addr: u64, val: u16) -> Result<(), Self::Error> { + let ptr = self.raw_ptr(addr, core::mem::size_of::())?; + let atomic = unsafe { &*(ptr as *const AtomicU16) }; + atomic.store(val, Ordering::Release); + Ok(()) + } + + unsafe fn as_slice(&self, _addr: u64, _len: usize) -> Result<&[u8], Self::Error> { + Err(HostMemError::DirectSliceNotSupported) + } + + unsafe fn as_mut_slice(&self, _addr: u64, _len: usize) -> Result<&mut [u8], Self::Error> { + Err(HostMemError::DirectSliceNotSupported) + } +} diff --git a/src/hyperlight_host/src/sandbox/config.rs b/src/hyperlight_host/src/sandbox/config.rs index f12387a0b..da068b2cd 100644 --- a/src/hyperlight_host/src/sandbox/config.rs +++ b/src/hyperlight_host/src/sandbox/config.rs @@ -14,9 +14,9 @@ See the License for the specific language governing permissions and limitations under the License. */ -use std::cmp::max; use std::time::Duration; +use hyperlight_common::mem::PAGE_SIZE_USIZE; #[cfg(target_os = "linux")] use libc::c_int; use tracing::{Span, instrument}; @@ -44,12 +44,6 @@ pub struct SandboxConfiguration { /// Guest gdb debug port #[cfg(gdb)] guest_debug_info: Option, - /// The size of the memory buffer that is made available for input to the - /// Guest Binary - input_data_size: usize, - /// The size of the memory buffer that is made available for input to the - /// Guest Binary - output_data_size: usize, /// The heap size to use in the guest sandbox. If set to 0, the heap /// size will be determined from the PE file header /// @@ -74,17 +68,29 @@ pub struct SandboxConfiguration { interrupt_vcpu_sigrtmin_offset: u8, /// How much writable memory to offer the guest scratch_size: usize, + /// Number of descriptors for the guest-to-host virtqueue. Must be a power of 2. + /// Default: 64 sized to 2x H2G depth for deadlock prevention. + g2h_queue_depth: usize, + /// Number of descriptors for the host-to-guest virtqueue. Must be a power of 2. + /// Default: 32 + h2g_queue_depth: usize, + /// Number of physical pages for the G2H (guest-to-host) buffer pool. + /// When None, falls back to deprecated `output_data_size` or default. + g2h_pool_pages: Option, + /// Number of physical pages for the H2G (host-to-guest) buffer pool. + /// When None, falls back to deprecated `input_data_size` or default. 
+ h2g_pool_pages: Option, + /// Deprecated: use `g2h_pool_pages` instead. + /// When set (non-zero), translates to `g2h_pool_pages` if pool pages + /// are not explicitly configured. + output_data_size: usize, + /// Deprecated: use `h2g_pool_pages` instead. + /// When set (non-zero), translates to `h2g_pool_pages` if pool pages + /// are not explicitly configured. + input_data_size: usize, } impl SandboxConfiguration { - /// The default size of input data - pub const DEFAULT_INPUT_SIZE: usize = 0x4000; - /// The minimum size of input data - pub const MIN_INPUT_SIZE: usize = 0x2000; - /// The default size of output data - pub const DEFAULT_OUTPUT_SIZE: usize = 0x4000; - /// The minimum size of output data - pub const MIN_OUTPUT_SIZE: usize = 0x2000; /// The default interrupt retry delay pub const DEFAULT_INTERRUPT_RETRY_DELAY: Duration = Duration::from_micros(500); /// The default signal offset from `SIGRTMIN` used to determine the signal number for interrupting @@ -93,13 +99,19 @@ impl SandboxConfiguration { pub const DEFAULT_HEAP_SIZE: u64 = 131072; /// The default size of the scratch region pub const DEFAULT_SCRATCH_SIZE: usize = 0x48000; + /// The default G2H virtqueue depth (number of descriptors, must be power of 2) + pub const DEFAULT_G2H_QUEUE_DEPTH: usize = 64; + /// The default H2G virtqueue depth (number of descriptors, must be power of 2) + pub const DEFAULT_H2G_QUEUE_DEPTH: usize = 32; + /// The default number of G2H buffer pool pages + pub const DEFAULT_G2H_POOL_PAGES: usize = 8; + /// The default number of H2G buffer pool pages + pub const DEFAULT_H2G_POOL_PAGES: usize = 4; #[allow(clippy::too_many_arguments)] /// Create a new configuration for a sandbox with the given sizes. #[instrument(skip_all, parent = Span::current(), level= "Trace")] fn new( - input_data_size: usize, - output_data_size: usize, heap_size_override: Option, scratch_size: usize, interrupt_retry_delay: Duration, @@ -108,12 +120,16 @@ impl SandboxConfiguration { #[cfg(crashdump)] guest_core_dump: bool, ) -> Self { Self { - input_data_size: max(input_data_size, Self::MIN_INPUT_SIZE), - output_data_size: max(output_data_size, Self::MIN_OUTPUT_SIZE), heap_size_override: heap_size_override.unwrap_or(0), scratch_size, interrupt_retry_delay, interrupt_vcpu_sigrtmin_offset, + g2h_queue_depth: Self::DEFAULT_G2H_QUEUE_DEPTH, + h2g_queue_depth: Self::DEFAULT_H2G_QUEUE_DEPTH, + g2h_pool_pages: None, + h2g_pool_pages: None, + output_data_size: 0, + input_data_size: 0, #[cfg(gdb)] guest_debug_info, #[cfg(crashdump)] @@ -121,20 +137,6 @@ impl SandboxConfiguration { } } - /// Set the size of the memory buffer that is made available for input to the guest - /// the minimum value is MIN_INPUT_SIZE - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - pub fn set_input_data_size(&mut self, input_data_size: usize) { - self.input_data_size = max(input_data_size, Self::MIN_INPUT_SIZE); - } - - /// Set the size of the memory buffer that is made available for output from the guest - /// the minimum value is MIN_OUTPUT_SIZE - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - pub fn set_output_data_size(&mut self, output_data_size: usize) { - self.output_data_size = max(output_data_size, Self::MIN_OUTPUT_SIZE); - } - /// Set the heap size to use in the guest sandbox. 
If set to 0, the heap size will be determined from the PE file header #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub fn set_heap_size(&mut self, heap_size: u64) { @@ -195,18 +197,96 @@ impl SandboxConfiguration { } #[instrument(skip_all, parent = Span::current(), level= "Trace")] - pub(crate) fn get_input_data_size(&self) -> usize { - self.input_data_size + pub(crate) fn get_scratch_size(&self) -> usize { + self.scratch_size } + /// Get the G2H virtqueue depth (number of descriptors). #[instrument(skip_all, parent = Span::current(), level= "Trace")] - pub(crate) fn get_output_data_size(&self) -> usize { - self.output_data_size + pub fn get_g2h_queue_depth(&self) -> usize { + self.g2h_queue_depth } + /// Get the H2G virtqueue depth (number of descriptors). #[instrument(skip_all, parent = Span::current(), level= "Trace")] - pub(crate) fn get_scratch_size(&self) -> usize { - self.scratch_size + pub fn get_h2g_queue_depth(&self) -> usize { + self.h2g_queue_depth + } + + /// Set the G2H virtqueue depth (number of descriptors, must be power of 2). + #[instrument(skip_all, parent = Span::current(), level= "Trace")] + pub fn set_g2h_queue_depth(&mut self, depth: usize) { + self.g2h_queue_depth = depth; + } + + /// Set the H2G virtqueue depth (number of descriptors, must be power of 2). + #[instrument(skip_all, parent = Span::current(), level= "Trace")] + pub fn set_h2g_queue_depth(&mut self, depth: usize) { + self.h2g_queue_depth = depth; + } + + /// Get the number of G2H buffer pool pages. + /// + /// Priority: explicit `g2h_pool_pages` > derived from deprecated + /// `output_data_size` > default. + #[instrument(skip_all, parent = Span::current(), level= "Trace")] + pub fn get_g2h_pool_pages(&self) -> usize { + self.g2h_pool_pages.unwrap_or_else(|| { + if self.output_data_size > 0 { + self.output_data_size + .div_ceil(PAGE_SIZE_USIZE) + .max(Self::DEFAULT_G2H_POOL_PAGES) + } else { + Self::DEFAULT_G2H_POOL_PAGES + } + }) + } + + /// Get the number of H2G buffer pool pages. + /// + /// Priority: explicit `h2g_pool_pages` > derived from deprecated + /// `input_data_size` > default. + #[instrument(skip_all, parent = Span::current(), level= "Trace")] + pub fn get_h2g_pool_pages(&self) -> usize { + self.h2g_pool_pages.unwrap_or_else(|| { + if self.input_data_size > 0 { + self.input_data_size + .div_ceil(PAGE_SIZE_USIZE) + .max(Self::DEFAULT_H2G_POOL_PAGES) + } else { + Self::DEFAULT_H2G_POOL_PAGES + } + }) + } + + /// Set the number of G2H buffer pool pages. + #[instrument(skip_all, parent = Span::current(), level= "Trace")] + pub fn set_g2h_pool_pages(&mut self, pages: usize) { + self.g2h_pool_pages = Some(pages); + } + + /// Set the number of H2G buffer pool pages. + #[instrument(skip_all, parent = Span::current(), level= "Trace")] + pub fn set_h2g_pool_pages(&mut self, pages: usize) { + self.h2g_pool_pages = Some(pages); + } + + /// Deprecated: use [`set_g2h_pool_pages`](Self::set_g2h_pool_pages). + /// + /// Sets the output data size. If `g2h_pool_pages` is not explicitly + /// set, this value is translated to pool pages. + #[deprecated(note = "use set_g2h_pool_pages instead")] + pub fn set_output_data_size(&mut self, size: usize) { + self.output_data_size = size; + } + + /// Deprecated: use [`set_h2g_pool_pages`](Self::set_h2g_pool_pages). + /// + /// Sets the input data size. If `h2g_pool_pages` is not explicitly + /// set, this value is translated to pool pages. 
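The fallback arithmetic in the getters above is easy to verify by hand. A worked example, assuming `PAGE_SIZE_USIZE` is 4096 (consistent with the "64 KB / 4096 = 16 pages" comments in the fuzz targets):

```rust
const PAGE_SIZE: usize = 4096; // assumption: PAGE_SIZE_USIZE
const DEFAULT_G2H_POOL_PAGES: usize = 8;

/// Mirrors get_g2h_pool_pages' deprecated-size fallback: round the byte
/// size up to whole pages, then clamp to the default floor.
fn g2h_pages_from_deprecated_size(output_data_size: usize) -> usize {
    output_data_size
        .div_ceil(PAGE_SIZE)
        .max(DEFAULT_G2H_POOL_PAGES)
}

fn main() {
    assert_eq!(g2h_pages_from_deprecated_size(64 * 1024), 16); // 16 pages
    assert_eq!(g2h_pages_from_deprecated_size(4 * 1024), 8); // clamped up to the default
}
```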
+ #[deprecated(note = "use set_h2g_pool_pages instead")] + pub fn set_input_data_size(&mut self, size: usize) { + self.input_data_size = size; } /// Set the size of the scratch regiong @@ -245,8 +325,6 @@ impl Default for SandboxConfiguration { #[instrument(skip_all, parent = Span::current(), level= "Trace")] fn default() -> Self { Self::new( - Self::DEFAULT_INPUT_SIZE, - Self::DEFAULT_OUTPUT_SIZE, None, Self::DEFAULT_SCRATCH_SIZE, Self::DEFAULT_INTERRUPT_RETRY_DELAY, @@ -266,12 +344,8 @@ mod tests { #[test] fn overrides() { const HEAP_SIZE_OVERRIDE: u64 = 0x50000; - const INPUT_DATA_SIZE_OVERRIDE: usize = 0x4000; - const OUTPUT_DATA_SIZE_OVERRIDE: usize = 0x4001; const SCRATCH_SIZE_OVERRIDE: usize = 0x60000; - let mut cfg = SandboxConfiguration::new( - INPUT_DATA_SIZE_OVERRIDE, - OUTPUT_DATA_SIZE_OVERRIDE, + let cfg = SandboxConfiguration::new( Some(HEAP_SIZE_OVERRIDE), SCRATCH_SIZE_OVERRIDE, SandboxConfiguration::DEFAULT_INTERRUPT_RETRY_DELAY, @@ -286,38 +360,6 @@ mod tests { let scratch_size = cfg.get_scratch_size(); assert_eq!(HEAP_SIZE_OVERRIDE, heap_size); assert_eq!(SCRATCH_SIZE_OVERRIDE, scratch_size); - - cfg.heap_size_override = 2048; - cfg.scratch_size = 0x40000; - assert_eq!(2048, cfg.heap_size_override); - assert_eq!(0x40000, cfg.scratch_size); - assert_eq!(INPUT_DATA_SIZE_OVERRIDE, cfg.input_data_size); - assert_eq!(OUTPUT_DATA_SIZE_OVERRIDE, cfg.output_data_size); - } - - #[test] - fn min_sizes() { - let mut cfg = SandboxConfiguration::new( - SandboxConfiguration::MIN_INPUT_SIZE - 1, - SandboxConfiguration::MIN_OUTPUT_SIZE - 1, - None, - SandboxConfiguration::DEFAULT_SCRATCH_SIZE, - SandboxConfiguration::DEFAULT_INTERRUPT_RETRY_DELAY, - SandboxConfiguration::INTERRUPT_VCPU_SIGRTMIN_OFFSET, - #[cfg(gdb)] - None, - #[cfg(crashdump)] - true, - ); - assert_eq!(SandboxConfiguration::MIN_INPUT_SIZE, cfg.input_data_size); - assert_eq!(SandboxConfiguration::MIN_OUTPUT_SIZE, cfg.output_data_size); - assert_eq!(0, cfg.heap_size_override); - - cfg.set_input_data_size(SandboxConfiguration::MIN_INPUT_SIZE - 1); - cfg.set_output_data_size(SandboxConfiguration::MIN_OUTPUT_SIZE - 1); - - assert_eq!(SandboxConfiguration::MIN_INPUT_SIZE, cfg.input_data_size); - assert_eq!(SandboxConfiguration::MIN_OUTPUT_SIZE, cfg.output_data_size); } mod proptests { @@ -328,21 +370,6 @@ mod tests { use crate::sandbox::config::DebugInfo; proptest! 
{ - #[test] - fn input_data_size(size in SandboxConfiguration::MIN_INPUT_SIZE..=SandboxConfiguration::MIN_INPUT_SIZE * 10) { - let mut cfg = SandboxConfiguration::default(); - cfg.set_input_data_size(size); - prop_assert_eq!(size, cfg.get_input_data_size()); - } - - #[test] - fn output_data_size(size in SandboxConfiguration::MIN_OUTPUT_SIZE..=SandboxConfiguration::MIN_OUTPUT_SIZE * 10) { - let mut cfg = SandboxConfiguration::default(); - cfg.set_output_data_size(size); - prop_assert_eq!(size, cfg.get_output_data_size()); - } - - #[test] fn heap_size_override(size in 0x1000..=0x10000u64) { let mut cfg = SandboxConfiguration::default(); diff --git a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs index 241622cab..bdcd43729 100644 --- a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs +++ b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs @@ -737,7 +737,7 @@ impl MultiUseSandbox { let mut builder = FlatBufferBuilder::with_capacity(estimated_capacity); let buffer = fc.encode(&mut builder); - self.mem_mgr.write_guest_function_call(buffer)?; + self.mem_mgr.write_guest_function_call_virtq(buffer)?; let dispatch_res = self.vm.dispatch_call_from_host( &mut self.mem_mgr, @@ -754,7 +754,7 @@ impl MultiUseSandbox { return Err(error); } - let guest_result = self.mem_mgr.get_guest_function_call_result()?.into_inner(); + let guest_result = self.mem_mgr.read_h2g_result_from_g2h()?.into_inner(); match guest_result { Ok(val) => Ok(val), @@ -782,8 +782,6 @@ impl MultiUseSandbox { // - any serialized host function call are zeroed out by us (the host) during deserialization, see `get_host_function_call` // - any serialized host function result is zeroed out by the guest during deserialization, see `get_host_return_value` if let Err(e) = &res { - self.mem_mgr.clear_io_buffers(); - // Determine if we should poison the sandbox. self.poisoned |= e.is_poison_error(); } @@ -1082,12 +1080,10 @@ mod tests { .unwrap(); } - /// Make sure input/output buffers are properly reset after guest call (with host call) + /// Make sure pool buffers are properly reset after guest call (with host call) #[test] fn io_buffer_reset() { - let mut cfg = SandboxConfiguration::default(); - cfg.set_input_data_size(4096); - cfg.set_output_data_size(4096); + let cfg = SandboxConfiguration::default(); let path = simple_guest_as_string().unwrap(); let mut sandbox = UninitializedSandbox::new(GuestBinary::FilePath(path), Some(cfg)).unwrap(); @@ -1131,21 +1127,21 @@ mod tests { assert_eq!(res, 0); } - // Tests to ensure that many (1000) function calls can be made in a call context with a small stack (24K) and heap(20K). + // Tests to ensure that many (1000) function calls can be made in a call context with a small stack (24K) and heap(32K). // This test effectively ensures that the stack is being properly reset after each call and we are not leaking memory in the Guest. #[test] fn test_with_small_stack_and_heap() { let mut cfg = SandboxConfiguration::default(); - cfg.set_heap_size(20 * 1024); + cfg.set_heap_size(32 * 1024); // min_scratch_size already includes 1 page (4k on most // platforms) of guest stack, so add 20k more to get 24k // total, and then add some more for the eagerly-copied page - // tables on amd64 + // tables on amd64 and virtq pool pages. 
let min_scratch = hyperlight_common::layout::min_scratch_size( - cfg.get_input_data_size(), - cfg.get_output_data_size(), + cfg.get_g2h_queue_depth(), + cfg.get_h2g_queue_depth(), ); - cfg.set_scratch_size(min_scratch + 0x10000 + 0x10000); + cfg.set_scratch_size(min_scratch + 0x10000 + 0x18000); let mut sbox1: MultiUseSandbox = { let path = simple_guest_as_string().unwrap(); @@ -1198,6 +1194,220 @@ mod tests { assert_eq!(res, 0); } + /// Many snapshot restore cycles with state-modifying guest calls. + #[test] + fn restore_stress_no_pool_corruption() { + let mut sbox: MultiUseSandbox = { + let path = simple_guest_as_string().unwrap(); + let u_sbox = UninitializedSandbox::new(GuestBinary::FilePath(path), None).unwrap(); + u_sbox.evolve() + } + .unwrap(); + + let snapshot = sbox.snapshot().unwrap(); + + for _ in 0..50 { + sbox.restore(snapshot.clone()).unwrap(); + let _ = sbox.call::("AddToStatic", 1i32).unwrap(); + + let res: i32 = sbox.call("GetStatic", ()).unwrap(); + assert_eq!(res, 1); + + let res: i32 = sbox.call("AddToStatic", 2i32).unwrap(); + assert_eq!(res, 3); + } + } + + /// Stress test: snapshot/restore with G2H queue pressure. + #[test] + fn restore_stress_with_host_calls() { + let mut sbox: MultiUseSandbox = { + let path = simple_guest_as_string().unwrap(); + let u_sbox = UninitializedSandbox::new(GuestBinary::FilePath(path), None).unwrap(); + u_sbox.evolve() + } + .unwrap(); + + let snapshot = sbox.snapshot().unwrap(); + + for i in 0..50 { + sbox.restore(snapshot.clone()).unwrap(); + + // Fire-and-forget log oneshots - multiple G2H entries queued + // without waiting for responses + sbox.call::<()>("LogMessageN", 5_i32).unwrap(); + + // G2H round-trip with returned data after logs filled the queue + let echo: String = sbox.call("Echo", "ping".to_string()).unwrap(); + assert_eq!(echo, "ping"); + + // Multiple calls without restore to exercise queue reuse + let res: i32 = sbox.call("AddToStatic", 1i32).unwrap(); + assert_eq!(res, 1); + + let echo2: String = sbox.call("Echo", format!("echo {i}")).unwrap(); + assert_eq!(echo2, format!("echo {i}")); + } + } + + /// Back-to-back restores without any guest call in between. + /// The generation bumps twice but the guest only sees the latest value. + #[test] + fn restore_back_to_back() { + let mut sbox: MultiUseSandbox = { + let path = simple_guest_as_string().unwrap(); + let u_sbox = UninitializedSandbox::new(GuestBinary::FilePath(path), None).unwrap(); + u_sbox.evolve() + } + .unwrap(); + + let _ = sbox.call::("AddToStatic", 42i32).unwrap(); + let snapshot = sbox.snapshot().unwrap(); + + // Two restores in a row, no guest calls between them + sbox.restore(snapshot.clone()).unwrap(); + sbox.restore(snapshot.clone()).unwrap(); + + // Guest should see the snapshot state (static = 42) + let res: i32 = sbox.call("GetStatic", ()).unwrap(); + assert_eq!(res, 42); + + // Another round: three restores, then call + sbox.restore(snapshot.clone()).unwrap(); + sbox.restore(snapshot.clone()).unwrap(); + sbox.restore(snapshot.clone()).unwrap(); + + let res: i32 = sbox.call("AddToStatic", 1i32).unwrap(); + assert_eq!(res, 43); + } + + /// Restore after flooding the G2H queue with log oneshots. 
+ #[test] + fn restore_after_g2h_pressure() { + let mut sbox: MultiUseSandbox = { + let path = simple_guest_as_string().unwrap(); + let u_sbox = UninitializedSandbox::new(GuestBinary::FilePath(path), None).unwrap(); + u_sbox.evolve() + } + .unwrap(); + + let snapshot = sbox.snapshot().unwrap(); + + for _ in 0..20 { + // Flood G2H with many log oneshots to pressure the queue/pool + sbox.call::<()>("LogMessageN", 30_i32).unwrap(); + + // Restore after heavy G2H usage + sbox.restore(snapshot.clone()).unwrap(); + + // Verify queue works cleanly after restore + sbox.call::<()>("LogMessageN", 5_i32).unwrap(); + let echo: String = sbox.call("Echo", "ok".to_string()).unwrap(); + assert_eq!(echo, "ok"); + } + } + + /// Many calls cycling through all descriptor IDs, then restore. + /// Ensures restore handles post-wraparound ring state. + #[test] + fn restore_after_id_wraparound() { + let mut sbox: MultiUseSandbox = { + let path = simple_guest_as_string().unwrap(); + let u_sbox = UninitializedSandbox::new(GuestBinary::FilePath(path), None).unwrap(); + u_sbox.evolve() + } + .unwrap(); + + let snapshot = sbox.snapshot().unwrap(); + + for i in 0..200 { + let res: i32 = sbox.call("AddToStatic", 1i32).unwrap(); + assert_eq!(res, i + 1); + } + + // Restore after IDs have wrapped around many times + sbox.restore(snapshot.clone()).unwrap(); + + let res: i32 = sbox.call("GetStatic", ()).unwrap(); + assert_eq!(res, 0); + + // Do another round of wraparound + restore + for _ in 0..200 { + let _ = sbox.call::("AddToStatic", 1i32).unwrap(); + } + sbox.restore(snapshot.clone()).unwrap(); + + let echo: String = sbox.call("Echo", "after wraparound".to_string()).unwrap(); + assert_eq!(echo, "after wraparound"); + } + + /// Restore after a guest exception recovers the sandbox. + /// The virtqueue must be fully functional after restore despite + /// the guest having been in a broken state. + #[test] + fn restore_after_guest_error() { + let mut sbox: MultiUseSandbox = { + let path = simple_guest_as_string().unwrap(); + let u_sbox = UninitializedSandbox::new(GuestBinary::FilePath(path), None).unwrap(); + u_sbox.evolve() + } + .unwrap(); + + let snapshot = sbox.snapshot().unwrap(); + + // Normal call first + let res: i32 = sbox.call("AddToStatic", 5i32).unwrap(); + assert_eq!(res, 5); + + // Trigger an exception - guest is now in a broken state + let err = sbox.call::<()>("TriggerException", ()); + assert!(err.is_err()); + + // Restore should recover fully + sbox.restore(snapshot.clone()).unwrap(); + + // Verify everything works after recovery + let res: i32 = sbox.call("GetStatic", ()).unwrap(); + assert_eq!(res, 0); + + let echo: String = sbox.call("Echo", "recovered".to_string()).unwrap(); + assert_eq!(echo, "recovered"); + + sbox.call::<()>("LogMessageN", 5_i32).unwrap(); + let res: i32 = sbox.call("AddToStatic", 1i32).unwrap(); + assert_eq!(res, 1); + } + + /// Snapshot immediately after evolve, restore before any calls. + /// Baseline test: the virtqueue has never been used. 
+ #[test] + fn restore_fresh_snapshot() { + let mut sbox: MultiUseSandbox = { + let path = simple_guest_as_string().unwrap(); + let u_sbox = UninitializedSandbox::new(GuestBinary::FilePath(path), None).unwrap(); + u_sbox.evolve() + } + .unwrap(); + + // Snapshot immediately - no guest calls yet + let snapshot = sbox.snapshot().unwrap(); + + sbox.restore(snapshot.clone()).unwrap(); + + // First-ever guest call after restore + let res: i32 = sbox.call("GetStatic", ()).unwrap(); + assert_eq!(res, 0); + + let echo: String = sbox.call("Echo", "first".to_string()).unwrap(); + assert_eq!(echo, "first"); + + // Restore again and verify + sbox.restore(snapshot.clone()).unwrap(); + sbox.call::<()>("LogMessageN", 10_i32).unwrap(); + let res: i32 = sbox.call("AddToStatic", 7i32).unwrap(); + assert_eq!(res, 7); + } + #[test] fn test_trigger_exception_on_guest() { let usbox = UninitializedSandbox::new( @@ -1546,7 +1756,7 @@ mod tests { for (name, heap_size) in test_cases { let mut cfg = SandboxConfiguration::default(); cfg.set_heap_size(heap_size); - cfg.set_scratch_size(0x100000); + cfg.set_scratch_size(heap_size as usize + 0x100000); let path = simple_guest_as_string().unwrap(); let sbox = UninitializedSandbox::new(GuestBinary::FilePath(path), Some(cfg)) diff --git a/src/hyperlight_host/src/sandbox/outb.rs b/src/hyperlight_host/src/sandbox/outb.rs index 9704a1fe3..f6d7cc4a8 100644 --- a/src/hyperlight_host/src/sandbox/outb.rs +++ b/src/hyperlight_host/src/sandbox/outb.rs @@ -16,10 +16,13 @@ limitations under the License. use std::sync::{Arc, Mutex}; +use hyperlight_common::flatbuffer_wrappers::function_call::FunctionCall; use hyperlight_common::flatbuffer_wrappers::function_types::{FunctionCallResult, ParameterValue}; use hyperlight_common::flatbuffer_wrappers::guest_error::{ErrorCode, GuestError}; use hyperlight_common::flatbuffer_wrappers::guest_log_data::GuestLogData; use hyperlight_common::outb::{Exception, OutBAction}; +use hyperlight_common::virtq::msg::{MsgKind, VirtqMsgHeader}; +use hyperlight_common::virtq::{self}; use log::{Level, Record}; use tracing::{Span, instrument}; use tracing_log::format_trace; @@ -61,46 +64,50 @@ pub enum HandleOutbError { MemProfile(String), } -#[instrument(err(Debug), skip_all, parent = Span::current(), level="Trace")] -pub(super) fn outb_log( - mgr: &mut SandboxMemoryManager, -) -> Result<(), HandleOutbError> { - // This code will create either a logging record or a tracing record for the GuestLogData depending on if the host has set up a tracing subscriber. - // In theory as we have enabled the log feature in the Cargo.toml for tracing this should happen - // automatically (based on if there is tracing subscriber present) but only works if the event created using macros. (see https://github.com/tokio-rs/tracing/blob/master/tracing/src/macros.rs#L2421 ) - // The reason that we don't want to use the tracing macros is that we want to be able to explicitly - // set the file and line number for the log record which is not possible with macros. - // This is because the file and line number come from the guest not the call site. - - let log_data: GuestLogData = mgr - .read_guest_log_data() - .map_err(|e| HandleOutbError::ReadLogData(e.to_string()))?; +/// Emit a guest log record from a virtqueue payload. +/// +/// Deserializes [`GuestLogData`] from the raw bytes and emits either +/// a tracing event or a log record. 
+pub(crate) fn emit_guest_log(payload: &[u8]) { + let Ok(log_data) = GuestLogData::try_from(payload) else { + return; + }; + + // This code will create either a logging record or a tracing record + // for the GuestLogData depending on if the host has set up a tracing + // subscriber. + // In theory as we have enabled the log feature in the Cargo.toml for + // tracing this should happen automatically (based on if there is a + // tracing subscriber present) but only works if the event is created + // using macros. + // (see https://github.com/tokio-rs/tracing/blob/master/tracing/src/macros.rs#L2421) + // The reason that we don't want to use the tracing macros is that we + // want to be able to explicitly set the file and line number for the + // log record which is not possible with macros. + // This is because the file and line number come from the guest not + // the call site. let record_level: Level = (&log_data.level).into(); - // Work out if we need to log or trace - // this API is marked as follows but it is the easiest way to work out if we should trace or log - - // Private API for internal use by tracing's macros. - // - // This function is *not* considered part of `tracing`'s public API, and has no - // stability guarantees. If you use it, and it breaks or disappears entirely, - // don't say we didn't warn you. - + // Work out if we need to log or trace. + // This API is marked as internal but it is the easiest way to work + // out if we should trace or log. let should_trace = tracing_core::dispatcher::has_been_set(); let source_file = Some(log_data.source_file.as_str()); let line = Some(log_data.line); let source = Some(log_data.source.as_str()); - // See https://github.com/rust-lang/rust/issues/42253 for the reason this has to be done this way + // See https://github.com/rust-lang/rust/issues/42253 for the reason + // this has to be done this way. if should_trace { - // Create a tracing event for the GuestLogData - // Ideally we would create tracing metadata based on the Guest Log Data - // but tracing derives the metadata at compile time + // Create a tracing event for the GuestLogData. + // Ideally we would create tracing metadata based on the Guest + // Log Data but tracing derives the metadata at compile time. // see https://github.com/tokio-rs/tracing/issues/2419 - // so we leave it up to the subscriber to figure out that there are logging fields present with this data - format_trace( + // So we leave it up to the subscriber to figure out that there + // are logging fields present with this data. + let _ = format_trace( &Record::builder() .args(format_args!("{}", log_data.message)) .level(record_level) @@ -109,8 +116,7 @@ pub(super) fn outb_log( .line(line) .module_path(source) .build(), - ) - .map_err(|e| HandleOutbError::TraceFormat(e.to_string()))?; + ); } else { // Create a log record for the GuestLogData log::logger().log( @@ -124,8 +130,6 @@ pub(super) fn outb_log( .build(), ); } - - Ok(()) } const ABORT_TERMINATOR: u8 = 0xFF; @@ -180,6 +184,128 @@ fn outb_abort( Ok(()) } +/// Handle a guest-to-host function call received via the G2H virtqueue. +/// +/// Log entries that arrive before the Request are processed inline. 
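The drain loop below never trusts the guest-supplied `payload_len`: it is clamped to the bytes the entry actually holds before slicing. A sketch of that framing step, assuming the `[VirtqMsgHeader][payload]` wire layout used in this file:

```rust
/// Split a G2H entry into (header bytes, clamped payload), assuming the
/// wire layout [VirtqMsgHeader][payload] with a fixed-size POD header.
fn split_entry(
    entry_data: &[u8],
    header_size: usize,
    claimed_payload_len: u32,
) -> Result<(&[u8], &[u8]), String> {
    if entry_data.len() < header_size {
        return Err("G2H entry too short".into());
    }
    let (header, rest) = entry_data.split_at(header_size);
    // Clamp the guest-claimed length to the bytes actually present.
    let len = (claimed_payload_len as usize).min(rest.len());
    Ok((header, &rest[..len]))
}
```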
+fn outb_virtq_call( + mem_mgr: &mut SandboxMemoryManager, + host_funcs: &Arc>, +) -> Result<(), HandleOutbError> { + let g2h_pool_size = mem_mgr.g2h_pool_size(); + + let consumer = mem_mgr.g2h_consumer.as_mut().ok_or_else(|| { + HandleOutbError::ReadHostFunctionCall("G2H consumer not initialized".into()) + })?; + + // Drain entries, processing Log messages, until we find a Request. + let (entry, completion) = loop { + let Ok(maybe_next) = consumer.poll(g2h_pool_size) else { + return Err(HandleOutbError::ReadHostFunctionCall( + "G2H poll failed".into(), + )); + }; + + let Some((entry, completion)) = maybe_next else { + // No G2H entry - backpressure-only notify or prefill notify. + return Ok(()); + }; + + let hdr_size = VirtqMsgHeader::SIZE; + let entry_data = entry.data(); + + if entry_data.len() < hdr_size { + return Err(HandleOutbError::ReadHostFunctionCall( + "G2H entry too short".into(), + )); + } + + let hdr: VirtqMsgHeader = *bytemuck::from_bytes(&entry_data[..hdr_size]); + + match hdr.msg_kind() { + Ok(MsgKind::Log) => { + let available = entry_data.len() - hdr_size; + let log_len = (hdr.payload_len as usize).min(available); + let payload = &entry_data[hdr_size..hdr_size + log_len]; + + emit_guest_log(payload); + + consumer.complete(completion).map_err(|e| { + HandleOutbError::ReadHostFunctionCall(format!("G2H complete log: {e}")) + })?; + + continue; + } + Ok(MsgKind::Request) => break (entry, completion), + Ok(other) => { + return Err(HandleOutbError::ReadHostFunctionCall(format!( + "G2H: expected Request via outb, got {:?}", + other + ))); + } + Err(unknown) => { + return Err(HandleOutbError::ReadHostFunctionCall(format!( + "G2H: unknown message kind: 0x{unknown:02x}" + ))); + } + } + }; + + // Validate completion buffer before calling the host function + let virtq::SendCompletion::Writable(mut wc) = completion else { + return Err(HandleOutbError::WriteHostFunctionResponse( + "G2H: expected writable completion, got ack (ring corruption)".into(), + )); + }; + + let entry_data = entry.data(); + let hdr: VirtqMsgHeader = *bytemuck::from_bytes(&entry_data[..VirtqMsgHeader::SIZE]); + let available = entry_data.len() - VirtqMsgHeader::SIZE; + let payload_len = (hdr.payload_len as usize).min(available); + let payload = &entry_data[VirtqMsgHeader::SIZE..VirtqMsgHeader::SIZE + payload_len]; + + let call = FunctionCall::try_from(payload) + .map_err(|e| HandleOutbError::ReadHostFunctionCall(e.to_string()))?; + + let name = call.function_name.clone(); + let args: Vec = call.parameters.unwrap_or(vec![]); + + let registry = host_funcs + .try_lock() + .map_err(|e| HandleOutbError::LockFailed(file!(), line!(), e.to_string()))?; + + let res = registry + .call_host_function(&name, args) + .map_err(|e| GuestError::new(ErrorCode::HostFunctionError, e.to_string())); + + let func_result = FunctionCallResult::new(res); + let mut builder = flatbuffers::FlatBufferBuilder::new(); + let mut result_payload = func_result.encode(&mut builder).to_vec(); + + let total = VirtqMsgHeader::SIZE + result_payload.len(); + if total > wc.capacity() { + let too_large = GuestError::new( + ErrorCode::HostFunctionError, + "response too large for completion buffer".into(), + ); + let fallback = FunctionCallResult::new(Err(too_large)); + let mut fb = flatbuffers::FlatBufferBuilder::new(); + result_payload = fallback.encode(&mut fb).to_vec(); + } + + let resp_header = VirtqMsgHeader::new(MsgKind::Response, 0, result_payload.len() as u32); + let resp_header_bytes = bytemuck::bytes_of(&resp_header); + + 
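+ // Frame the reply as [VirtqMsgHeader][FunctionCallResult payload] in the
+ // writable completion, then hand the completion back to the consumer so
+ // the guest can observe the response.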
wc.write_all(resp_header_bytes) + .map_err(|e| HandleOutbError::WriteHostFunctionResponse(format!("{e}")))?; + wc.write_all(&result_payload) + .map_err(|e| HandleOutbError::WriteHostFunctionResponse(format!("{e}")))?; + consumer + .complete(wc.into()) + .map_err(|e| HandleOutbError::WriteHostFunctionResponse(format!("{e}")))?; + + Ok(()) +} + /// Handles OutB operations from the guest. #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] pub(crate) fn handle_outb( @@ -194,25 +320,9 @@ pub(crate) fn handle_outb( .try_into() .map_err(|e: anyhow::Error| HandleOutbError::InvalidPort(e.to_string()))? { - OutBAction::Log => outb_log(mem_mgr), - OutBAction::CallFunction => { - let call = mem_mgr - .get_host_function_call() - .map_err(|e| HandleOutbError::ReadHostFunctionCall(e.to_string()))?; - let name = call.function_name.clone(); - let args: Vec = call.parameters.unwrap_or(vec![]); - let res = host_funcs - .try_lock() - .map_err(|e| HandleOutbError::LockFailed(file!(), line!(), e.to_string()))? - .call_host_function(&name, args) - .map_err(|e| GuestError::new(ErrorCode::HostFunctionError, e.to_string())); - - let func_result = FunctionCallResult::new(res); - - mem_mgr - .write_response_from_host_function_call(&func_result) - .map_err(|e| HandleOutbError::WriteHostFunctionResponse(e.to_string()))?; - + OutBAction::Log | OutBAction::CallFunction => { + // Legacy paths removed - these actions should no longer be + // emitted by the guest. Ignore gracefully. Ok(()) } OutBAction::Abort => outb_abort(mem_mgr, data), @@ -227,6 +337,7 @@ pub(crate) fn handle_outb( eprint!("{}", ch); Ok(()) } + OutBAction::VirtqNotify => outb_virtq_call(mem_mgr, host_funcs), #[cfg(feature = "trace_guest")] OutBAction::TraceBatch => Ok(()), #[cfg(feature = "mem_profile")] @@ -235,251 +346,3 @@ pub(crate) fn handle_outb( OutBAction::TraceMemoryFree => trace_info.handle_trace_mem_free(regs, mem_mgr), } } -#[cfg(test)] -mod tests { - use hyperlight_common::flatbuffer_wrappers::guest_log_level::LogLevel; - use hyperlight_testing::logger::{LOGGER, Logger}; - use hyperlight_testing::simple_guest_as_string; - use log::Level; - use tracing_core::callsite::rebuild_interest_cache; - - use super::outb_log; - use crate::GuestBinary; - use crate::mem::mgr::SandboxMemoryManager; - use crate::sandbox::SandboxConfiguration; - use crate::sandbox::outb::GuestLogData; - use crate::testing::log_values::test_value_as_str; - - fn new_guest_log_data(level: LogLevel) -> GuestLogData { - GuestLogData::new( - "test log".to_string(), - "test source".to_string(), - level, - "test caller".to_string(), - "test source file".to_string(), - 123, - ) - } - - #[test] - #[ignore] - fn test_log_outb_log() { - Logger::initialize_test_logger(); - LOGGER.set_max_level(log::LevelFilter::Off); - - let sandbox_cfg = SandboxConfiguration::default(); - - let new_mgr = || { - let bin = GuestBinary::FilePath(simple_guest_as_string().unwrap()); - let snapshot = crate::sandbox::snapshot::Snapshot::from_env(bin, sandbox_cfg).unwrap(); - let mgr = SandboxMemoryManager::from_snapshot(&snapshot).unwrap(); - let (hmgr, _) = mgr.build().unwrap(); - hmgr - }; - { - // We set a logger but there is no guest log data - // in memory, so expect a log operation to fail - let mut mgr = new_mgr(); - assert!(outb_log(&mut mgr).is_err()); - } - { - // Write a log message so outb_log will succeed. 
- // Since the logger level is set off, expect logs to be no-ops - let mut mgr = new_mgr(); - let log_msg = new_guest_log_data(LogLevel::Information); - - let guest_log_data_buffer: Vec = log_msg.try_into().unwrap(); - let offset = mgr.layout.get_output_data_buffer_scratch_host_offset(); - mgr.scratch_mem - .push_buffer( - offset, - sandbox_cfg.get_output_data_size(), - &guest_log_data_buffer, - ) - .unwrap(); - - let res = outb_log(&mut mgr); - assert!(res.is_ok()); - assert_eq!(0, LOGGER.num_log_calls()); - LOGGER.clear_log_calls(); - } - { - // now, test logging - LOGGER.set_max_level(log::LevelFilter::Trace); - let mut mgr = new_mgr(); - LOGGER.clear_log_calls(); - - // set up the logger and set the log level to the maximum - // possible (Trace) to ensure we're able to test all - // the possible branches of the match in outb_log - - let levels = vec![ - LogLevel::Trace, - LogLevel::Debug, - LogLevel::Information, - LogLevel::Warning, - LogLevel::Error, - LogLevel::Critical, - LogLevel::None, - ]; - for level in levels { - let layout = mgr.layout; - let log_data = new_guest_log_data(level); - - let guest_log_data_buffer: Vec = log_data.clone().try_into().unwrap(); - mgr.scratch_mem - .push_buffer( - layout.get_output_data_buffer_scratch_host_offset(), - sandbox_cfg.get_output_data_size(), - guest_log_data_buffer.as_slice(), - ) - .unwrap(); - - outb_log(&mut mgr).unwrap(); - - LOGGER.test_log_records(|log_calls| { - let expected_level: Level = (&level).into(); - - assert!( - log_calls - .iter() - .filter(|log_call| { - log_call.level == expected_level - && log_call.line == Some(log_data.line) - && log_call.args == log_data.message - && log_call.module_path == Some(log_data.source.clone()) - && log_call.file == Some(log_data.source_file.clone()) - }) - .count() - == 1, - "log call did not occur for level {:?}", - level.clone() - ); - }); - } - } - } - - // Tests that outb_log emits traces when a trace subscriber is set - // this test is ignored because it is incompatible with other tests , specifically those which require a logger for tracing - // marking this test as ignored means that running `cargo test` will not run this test but will allow a developer who runs that command - // from their workstation to be successful without needed to know about test interdependencies - // this test will be run explicitly as a part of the CI pipeline - #[ignore] - #[test] - fn test_trace_outb_log() { - Logger::initialize_log_tracer(); - rebuild_interest_cache(); - let subscriber = - hyperlight_testing::tracing_subscriber::TracingSubscriber::new(tracing::Level::TRACE); - let sandbox_cfg = SandboxConfiguration::default(); - tracing::subscriber::with_default(subscriber.clone(), || { - let new_mgr = || { - let bin = GuestBinary::FilePath(simple_guest_as_string().unwrap()); - let snapshot = - crate::sandbox::snapshot::Snapshot::from_env(bin, sandbox_cfg).unwrap(); - let mgr = SandboxMemoryManager::from_snapshot(&snapshot).unwrap(); - let (hmgr, _) = mgr.build().unwrap(); - hmgr - }; - - // as a span does not exist one will be automatically created - // after that there will be an event for each log message - // we are interested only in the events for the log messages that we created - - let levels = vec![ - LogLevel::Trace, - LogLevel::Debug, - LogLevel::Information, - LogLevel::Warning, - LogLevel::Error, - LogLevel::Critical, - LogLevel::None, - ]; - for level in levels { - let mut mgr = new_mgr(); - let layout = mgr.layout; - let log_data: GuestLogData = new_guest_log_data(level); - 
subscriber.clear(); - - let guest_log_data_buffer: Vec = log_data.try_into().unwrap(); - mgr.scratch_mem - .push_buffer( - layout.get_output_data_buffer_scratch_host_offset(), - sandbox_cfg.get_output_data_size(), - guest_log_data_buffer.as_slice(), - ) - .unwrap(); - subscriber.clear(); - outb_log(&mut mgr).unwrap(); - - subscriber.test_trace_records(|spans, events| { - let expected_level = match level { - LogLevel::Trace => "TRACE", - LogLevel::Debug => "DEBUG", - LogLevel::Information => "INFO", - LogLevel::Warning => "WARN", - LogLevel::Error => "ERROR", - LogLevel::Critical => "ERROR", - LogLevel::None => "TRACE", - }; - - // We cannot get the parent span using the `current_span()` method as by the time we get to this point that span has been exited so there is no current span - // We need to make sure that the span that we created is in the spans map instead - // We expect to have created 21 spans at this point. We are only interested in the first one that was created when calling outb_log. - - assert!( - spans.len() == 21, - "expected 21 spans, found {}", - spans.len() - ); - - let span_value = spans - .get(&1) - .unwrap() - .as_object() - .unwrap() - .get("span") - .unwrap() - .get("attributes") - .unwrap() - .as_object() - .unwrap() - .get("metadata") - .unwrap() - .as_object() - .unwrap(); - - //test_value_as_str(span_value, "level", "INFO"); - test_value_as_str(span_value, "module_path", "hyperlight_host::sandbox::outb"); - let expected_file = if cfg!(windows) { - "src\\hyperlight_host\\src\\sandbox\\outb.rs" - } else { - "src/hyperlight_host/src/sandbox/outb.rs" - }; - test_value_as_str(span_value, "file", expected_file); - test_value_as_str(span_value, "target", "hyperlight_host::sandbox::outb"); - - let mut count_matching_events = 0; - - for json_value in events { - let event_values = json_value.as_object().unwrap().get("event").unwrap(); - let metadata_values_map = - event_values.get("metadata").unwrap().as_object().unwrap(); - let event_values_map = event_values.as_object().unwrap(); - test_value_as_str(metadata_values_map, "level", expected_level); - test_value_as_str(event_values_map, "log.file", "test source file"); - test_value_as_str(event_values_map, "log.module_path", "test source"); - test_value_as_str(event_values_map, "log.target", "hyperlight_guest"); - count_matching_events += 1; - } - assert!( - count_matching_events == 1, - "trace log call did not occur for level {:?}", - level.clone() - ); - }); - } - }); - } -} diff --git a/src/hyperlight_host/src/sandbox/uninitialized.rs b/src/hyperlight_host/src/sandbox/uninitialized.rs index 23c01be28..f2a4fcfcd 100644 --- a/src/hyperlight_host/src/sandbox/uninitialized.rs +++ b/src/hyperlight_host/src/sandbox/uninitialized.rs @@ -636,8 +636,6 @@ mod tests { // Non default memory configuration let cfg = { let mut cfg = SandboxConfiguration::default(); - cfg.set_input_data_size(0x1000); - cfg.set_output_data_size(0x1000); cfg.set_heap_size(0x1000); Some(cfg) }; @@ -1390,11 +1388,11 @@ mod tests { let _evolved: MultiUseSandbox = sandbox.evolve().expect("Failed to evolve sandbox"); } - // Test 4: Create snapshot with custom input/output buffer sizes + // Test 4: Create snapshot with custom pool page sizes { let mut cfg = SandboxConfiguration::default(); - cfg.set_input_data_size(64 * 1024); // 64KB input - cfg.set_output_data_size(64 * 1024); // 64KB output + cfg.set_h2g_pool_pages(16); // 16 pages + cfg.set_g2h_pool_pages(16); // 16 pages let env = GuestEnvironment::new(GuestBinary::FilePath(binary_path.clone()), None); @@ 
-1418,9 +1416,7 @@ mod tests { { let mut cfg = SandboxConfiguration::default(); cfg.set_heap_size(32 * 1024 * 1024); // 32MB heap - cfg.set_scratch_size(256 * 1024 * 2); // 512KB scratch (256KB will be input/output) - cfg.set_input_data_size(128 * 1024); // 128KB input - cfg.set_output_data_size(128 * 1024); // 128KB output + cfg.set_scratch_size(256 * 1024 * 2); // 512KB scratch let env = GuestEnvironment::new(GuestBinary::FilePath(binary_path.clone()), None); diff --git a/src/hyperlight_host/src/testing/log_values.rs b/src/hyperlight_host/src/testing/log_values.rs deleted file mode 100644 index 47f40ae0a..000000000 --- a/src/hyperlight_host/src/testing/log_values.rs +++ /dev/null @@ -1,62 +0,0 @@ -/* -Copyright 2025 The Hyperlight Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -use serde_json::{Map, Value}; - -use crate::{Result, new_error}; - -/// Call `check_value_as_str` and panic if it returned an `Err`. Otherwise, -/// do nothing. -#[track_caller] -pub(crate) fn test_value_as_str(values: &Map, key: &str, expected_value: &str) { - if let Err(e) = check_value_as_str(values, key, expected_value) { - panic!("{e:?}"); - } -} - -/// Check to see if the value in `values` for key `key` matches -/// `expected_value`. If so, return `Ok(())`. Otherwise, return an `Err` -/// indicating the mismatch. -pub(crate) fn check_value_as_str( - values: &Map, - key: &str, - expected_value: &str, -) -> Result<()> { - let value = try_to_string(values, key)?; - if expected_value != value { - return Err(new_error!( - "expected value {} != value {}", - expected_value, - value - )); - } - Ok(()) -} - -/// Fetch the value in `values` with key `key` and, if it existed, convert -/// it to a string. If all those steps succeeded, return an `Ok` with the -/// string value inside. Otherwise, return an `Err`. -fn try_to_string<'a>(values: &'a Map, key: &'a str) -> Result<&'a str> { - if let Some(value) = values.get(key) { - if let Some(value_str) = value.as_str() { - Ok(value_str) - } else { - Err(new_error!("value with key {} was not a string", key)) - } - } else { - Err(new_error!("value for key {} was not found", key)) - } -} diff --git a/src/hyperlight_host/src/testing/mod.rs b/src/hyperlight_host/src/testing/mod.rs index 26776b405..503fda1ee 100644 --- a/src/hyperlight_host/src/testing/mod.rs +++ b/src/hyperlight_host/src/testing/mod.rs @@ -13,4 +13,3 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -pub(crate) mod log_values; diff --git a/src/hyperlight_host/tests/common/mod.rs b/src/hyperlight_host/tests/common/mod.rs index d58e60aa6..8b2f6de9f 100644 --- a/src/hyperlight_host/tests/common/mod.rs +++ b/src/hyperlight_host/tests/common/mod.rs @@ -80,6 +80,16 @@ where f(sandbox); } +/// Runs a test with a Rust guest UninitializedSandbox using custom configuration. 
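+///
+/// Usage sketch (assumption: mirrors how the virtq tests below call it):
+///
+/// ```ignore
+/// let mut cfg = SandboxConfiguration::default();
+/// cfg.set_g2h_pool_pages(2);
+/// with_rust_uninit_sandbox_cfg(cfg, |sbox| {
+///     let mut sandbox = sbox.evolve().unwrap();
+///     let echo: String = sandbox.call("Echo", "hi".to_string()).unwrap();
+///     assert_eq!(echo, "hi");
+/// });
+/// ```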
+pub fn with_rust_uninit_sandbox_cfg(cfg: SandboxConfiguration, f: F) +where + F: FnOnce(UninitializedSandbox), +{ + let sandbox = + UninitializedSandbox::new(GuestBinary::FilePath(rust_guest_path()), Some(cfg)).unwrap(); + f(sandbox); +} + // ============================================================================= // C guest helpers // ============================================================================= diff --git a/src/hyperlight_host/tests/integration_test.rs b/src/hyperlight_host/tests/integration_test.rs index cc7b7587d..091ce699a 100644 --- a/src/hyperlight_host/tests/integration_test.rs +++ b/src/hyperlight_host/tests/integration_test.rs @@ -30,6 +30,7 @@ pub mod common; // pub to disable dead_code warning use crate::common::{ new_rust_sandbox, new_rust_uninit_sandbox, with_all_sandboxes, with_c_sandbox, with_c_uninit_sandbox, with_rust_sandbox, with_rust_sandbox_cfg, with_rust_uninit_sandbox, + with_rust_uninit_sandbox_cfg, }; // A host function cannot be interrupted, but we can at least make sure after requesting to interrupt a host call, @@ -535,7 +536,7 @@ fn guest_malloc_abort() { }); // allocate a vector (on heap) that is bigger than the heap - let heap_size = 0x4000; + let heap_size = 0x8000; let size_to_allocate = 0x10000; assert!( size_to_allocate > heap_size, @@ -544,6 +545,10 @@ fn guest_malloc_abort() { let mut cfg = SandboxConfiguration::default(); cfg.set_heap_size(heap_size); + cfg.set_g2h_queue_depth(2); + cfg.set_h2g_queue_depth(2); + cfg.set_g2h_pool_pages(3); + cfg.set_h2g_pool_pages(1); with_rust_sandbox_cfg(cfg, |mut sbox2| { let err = sbox2 .call::( @@ -578,48 +583,16 @@ fn guest_outb_with_invalid_port_poisons_sandbox() { }); } -#[test] -fn corrupt_output_size_prefix_rejected() { - with_rust_sandbox(|mut sbox| { - let res = sbox.call::("CorruptOutputSizePrefix", ()); - assert!( - res.is_err(), - "Expected error when guest corrupts size prefix, got: {:?}", - res, - ); - let err_msg = format!("{:?}", res.unwrap_err()); - assert!( - err_msg.contains("Corrupt buffer size prefix: flatbuffer claims 4294967295 bytes but the element slot is only 8 bytes"), - "Unexpected error message: {err_msg}" - ); - }); -} - -#[test] -fn corrupt_output_back_pointer_rejected() { - with_rust_sandbox(|mut sbox| { - let res = sbox.call::("CorruptOutputBackPointer", ()); - assert!( - res.is_err(), - "Expected error when guest corrupts back-pointer, got: {:?}", - res, - ); - let err_msg = format!("{:?}", res.unwrap_err()); - assert!( - err_msg.contains( - "Corrupt buffer back-pointer: element offset 57005 is outside valid range [8, 8]" - ), - "Unexpected error message: {err_msg}" - ); - }); -} - #[test] fn guest_panic_no_alloc() { - let heap_size = 0x4000; + let heap_size = 0x8000; let mut cfg = SandboxConfiguration::default(); cfg.set_heap_size(heap_size); + cfg.set_g2h_queue_depth(2); + cfg.set_h2g_queue_depth(2); + cfg.set_g2h_pool_pages(3); + cfg.set_h2g_pool_pages(1); with_rust_sandbox_cfg(cfg, |mut sbox| { let res = sbox .call::( @@ -743,10 +716,10 @@ fn log_message() { // follows: // - logs from trace level tracing spans created as logs because of the tracing `log` feature // - 4 from evolve call (generic_init + hyperlight_main) - // - 8 from guest call + // - 4 from guest call (call_host_function + read_n_bytes_from_user_memory) // and are multiplied because we make 6 calls to `log_test_messages` // NOTE: These numbers need to be updated if log messages or spans are added/removed - let num_fixed_trace_log = 12 * 6; + let num_fixed_trace_log = 8 * 6; // Calculate 
fixed info logs // - 4 logs per iteration from infrastructure at Info level (internal_dispatch_function) @@ -828,6 +801,167 @@ fn log_test_messages(levelfilter: Option) { } } +// The following tests depend on a global SimpleLogger or TracingSubscriber and +// cannot run in parallel with other tests. They are marked #[ignore] and run +// sequentially via `just test-isolated`. + +#[test] +#[ignore] +fn virtq_log_delivery() { + use hyperlight_testing::simplelogger::{LOGGER, SimpleLogger}; + + SimpleLogger::initialize_test_logger(); + LOGGER.clear_log_calls(); + + with_rust_uninit_sandbox(|mut sbox| { + sbox.set_max_guest_log_level(tracing_core::LevelFilter::TRACE); + let mut sandbox = sbox.evolve().unwrap(); + + sandbox + .call::<()>("LogMessage", ("virtq log test message".to_string(), 3_i32)) + .unwrap(); + + let count = LOGGER.num_log_calls(); + let mut found = false; + for i in 0..count { + if let Some(call) = LOGGER.get_log_call(i) + && call.target == "hyperlight_guest" + && call.args.contains("virtq log test") + { + found = true; + break; + } + } + assert!(found, "expected 'virtq log test' message from guest"); + LOGGER.clear_log_calls(); + }); +} + +#[test] +#[ignore] +fn virtq_log_backpressure() { + use hyperlight_testing::simplelogger::{LOGGER, SimpleLogger}; + + SimpleLogger::initialize_test_logger(); + LOGGER.clear_log_calls(); + + let mut cfg = SandboxConfiguration::default(); + cfg.set_g2h_pool_pages(2); + + with_rust_uninit_sandbox_cfg(cfg, |mut sbox| { + sbox.set_max_guest_log_level(tracing_core::LevelFilter::INFO); + let mut sandbox = sbox.evolve().unwrap(); + + sandbox.call::<()>("LogMessageN", 50_i32).unwrap(); + + let res: i32 = sandbox + .call("ThisIsNotARealFunctionButTheNameIsImportant", ()) + .unwrap(); + assert_eq!(res, 99); + + let guest_count = (0..LOGGER.num_log_calls()) + .filter_map(|i| LOGGER.get_log_call(i)) + .filter(|c| c.target == "hyperlight_guest" && c.args.contains("log entry")) + .count(); + assert_eq!(guest_count, 50, "expected 50 guest logs, got {guest_count}"); + LOGGER.clear_log_calls(); + }); +} + +#[test] +#[ignore] +fn virtq_log_backpressure_repeated() { + let mut cfg = SandboxConfiguration::default(); + cfg.set_g2h_pool_pages(2); + + with_rust_sandbox_cfg(cfg, |mut sandbox| { + for _ in 0..5 { + sandbox.call::<()>("LogMessageN", 30_i32).unwrap(); + } + }); +} + +#[test] +#[ignore] +fn virtq_backpressure_small_ring() { + use hyperlight_testing::simplelogger::{LOGGER, SimpleLogger}; + + SimpleLogger::initialize_test_logger(); + LOGGER.clear_log_calls(); + + let mut cfg = SandboxConfiguration::default(); + cfg.set_g2h_queue_depth(4); + + with_rust_uninit_sandbox_cfg(cfg, |mut sbox| { + sbox.set_max_guest_log_level(tracing_core::LevelFilter::INFO); + let mut sandbox = sbox.evolve().unwrap(); + + sandbox.call::<()>("LogMessageN", 20_i32).unwrap(); + + let guest_count = (0..LOGGER.num_log_calls()) + .filter_map(|i| LOGGER.get_log_call(i)) + .filter(|c| c.target == "hyperlight_guest" && c.args.contains("log entry")) + .count(); + assert_eq!(guest_count, 20, "expected 20 guest logs, got {guest_count}"); + LOGGER.clear_log_calls(); + }); +} + +#[test] +#[ignore] +fn virtq_log_tracing_delivery() { + use hyperlight_testing::tracing_subscriber::TracingSubscriber; + + let subscriber = TracingSubscriber::new(tracing::Level::TRACE); + + tracing::subscriber::with_default(subscriber.clone(), || { + with_rust_uninit_sandbox(|mut sbox| { + sbox.set_max_guest_log_level(tracing_core::LevelFilter::INFO); + let mut sandbox = sbox.evolve().unwrap(); + + 
subscriber.clear(); + + sandbox + .call::<()>("LogMessage", ("tracing delivery test".to_string(), 3_i32)) + .unwrap(); + + let events = subscriber.get_events(); + assert!( + !events.is_empty(), + "expected tracing events after guest log call, got none" + ); + }); + }); +} + +#[test] +#[ignore] +fn virtq_log_tracing_levels() { + use hyperlight_testing::tracing_subscriber::TracingSubscriber; + + let subscriber = TracingSubscriber::new(tracing::Level::TRACE); + + tracing::subscriber::with_default(subscriber.clone(), || { + with_rust_uninit_sandbox(|mut sbox| { + sbox.set_max_guest_log_level(tracing_core::LevelFilter::TRACE); + let mut sandbox = sbox.evolve().unwrap(); + + for level in [1_i32, 2, 3, 4, 5] { + subscriber.clear(); + let msg = format!("level-test-{}", level); + sandbox.call::<()>("LogMessage", (msg, level)).unwrap(); + + let events = subscriber.get_events(); + assert!( + !events.is_empty(), + "expected tracing events for guest log level {}", + level + ); + } + }); + }); +} + /// Tests whether host is able to return Bool as return type /// or not #[test] @@ -1679,7 +1813,9 @@ fn exception_handler_installation_and_validation() { /// This validates that the exception handling path does not require heap allocations. #[test] fn fill_heap_and_cause_exception() { - with_rust_sandbox(|mut sandbox| { + let mut cfg = SandboxConfiguration::default(); + cfg.set_scratch_size(0x60000); + with_rust_sandbox_cfg(cfg, |mut sandbox| { let result = sandbox.call::<()>("FillHeapAndCauseException", ()); // The call should fail with an exception error since there's no handler installed diff --git a/src/hyperlight_host/tests/sandbox_host_tests.rs b/src/hyperlight_host/tests/sandbox_host_tests.rs index e0daf969b..e48a0f38c 100644 --- a/src/hyperlight_host/tests/sandbox_host_tests.rs +++ b/src/hyperlight_host/tests/sandbox_host_tests.rs @@ -26,7 +26,7 @@ use hyperlight_testing::simple_guest_as_string; pub mod common; // pub to disable dead_code warning use crate::common::{ with_all_sandboxes, with_all_sandboxes_cfg, with_all_sandboxes_with_writer, - with_all_uninit_sandboxes, + with_all_uninit_sandboxes, with_rust_sandbox_cfg, with_rust_uninit_sandbox_cfg, }; #[test] @@ -212,9 +212,7 @@ fn incorrect_parameter_num() { #[test] fn small_scratch_sandbox() { let mut cfg = SandboxConfiguration::default(); - cfg.set_scratch_size(0x48000); - cfg.set_input_data_size(0x24000); - cfg.set_output_data_size(0x24000); + cfg.set_scratch_size(0x1000); let a = UninitializedSandbox::new( GuestBinary::FilePath(simple_guest_as_string().unwrap()), Some(cfg), @@ -375,3 +373,278 @@ fn host_function_error() { } }); } + +#[test] +fn virtq_log_with_callback() { + // Verify that log messages interleaved with host callbacks work + with_all_uninit_sandboxes(|mut sandbox| { + let (tx, _rx) = channel(); + sandbox + .register("HostMethod1", move |msg: String| { + let len = msg.len(); + tx.send(msg).unwrap(); + len as i32 + }) + .unwrap(); + let mut sandbox = sandbox.evolve().unwrap(); + + // Echo triggers guest-side logging infrastructure, then returns. + // This validates that log ReadOnly entries interleaved with + // function call ReadWrite entries don't corrupt the G2H queue. + let res: String = sandbox.call("Echo", "test".to_string()).unwrap(); + assert_eq!(res, "test"); + }); +} + +#[test] +fn virtq_backpressure_log_then_callback() { + // Logs fill the G2H ring, then a host callback needs ring space. + // call_host_function handles backpressure by notify + reclaim + retry. 
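The recovery strategy named in the comment above can be sketched as a retry loop. This is illustrative only; the guest-side API is not part of this diff and the names below are hypothetical:

```rust
/// Illustrative only: on a full ring, notify the host so it drains entries,
/// reclaim completed descriptors, then retry the submit.
enum SubmitError {
    RingFull,
    Fatal(String),
}

fn submit_with_backpressure(
    mut try_submit: impl FnMut() -> Result<(), SubmitError>,
    mut notify_host_and_reclaim: impl FnMut(),
) -> Result<(), SubmitError> {
    loop {
        match try_submit() {
            Ok(()) => return Ok(()),
            Err(SubmitError::RingFull) => notify_host_and_reclaim(),
            Err(e) => return Err(e),
        }
    }
}
```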
+
+#[test]
+fn virtq_backpressure_log_then_callback() {
+    // Logs fill the G2H ring, then a host callback needs ring space.
+    // call_host_function handles backpressure by notify + reclaim + retry.
+    let mut cfg = SandboxConfiguration::default();
+    cfg.set_g2h_queue_depth(4);
+    cfg.set_g2h_pool_pages(2);
+
+    with_rust_uninit_sandbox_cfg(cfg, |mut sbox| {
+        sbox.set_max_guest_log_level(tracing_core::LevelFilter::INFO);
+        sbox.register_print(|msg: String| msg.len() as i32).unwrap();
+        let mut sandbox = sbox.evolve().unwrap();
+
+        // PrintOutput logs and calls HostPrint callback.
+        // With depth=4 the logs may fill the ring, requiring
+        // call_host_function to handle backpressure before
+        // submitting the callback entry.
+        let res: i32 = sandbox.call("PrintOutput", "bp-test".to_string()).unwrap();
+        assert_eq!(res, 7);
+    });
+}
+
+#[test]
+fn virtq_backpressure_no_data_loss() {
+    // After backpressure recovery, verify multiple function calls
+    // return correct results (completion data wasn't lost by reclaim).
+    let mut cfg = SandboxConfiguration::default();
+    cfg.set_g2h_pool_pages(2);
+    cfg.set_g2h_queue_depth(4);
+
+    with_rust_uninit_sandbox_cfg(cfg, |mut sbox| {
+        sbox.set_max_guest_log_level(tracing_core::LevelFilter::INFO);
+        let mut sandbox = sbox.evolve().unwrap();
+
+        // Trigger backpressure with logs
+        sandbox.call::<()>("LogMessageN", 20_i32).unwrap();
+
+        // Now verify multiple function calls with return values
+        let res: String = sandbox.call("Echo", "first".to_string()).unwrap();
+        assert_eq!(res, "first");
+
+        let res: String = sandbox.call("Echo", "second".to_string()).unwrap();
+        assert_eq!(res, "second");
+
+        let res: f64 = sandbox.call("EchoDouble", 1.234_f64).unwrap();
+        assert!((res - 1.234).abs() < f64::EPSILON);
+    });
+}
+
+#[test]
+fn virtq_invalid_guest_function_returns_error() {
+    // Calling a non-existent guest function should return a proper
+    // GuestError, not corrupt data or a hang. This validates that
+    // the virtq error path (MsgKind::Response with GuestError payload)
+    // works end-to-end.
+    with_rust_sandbox_cfg(SandboxConfiguration::default(), |mut sandbox| {
+        let res = sandbox.call::<()>("ThisFunctionDoesNotExist", ());
+        assert!(res.is_err(), "expected error for non-existent function");
+        let err = res.unwrap_err();
+        assert!(
+            matches!(
+                err,
+                HyperlightError::GuestError(
+                    hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode::GuestFunctionNotFound,
+                    _
+                )
+            ),
+            "expected GuestFunctionNotFound, got {:?}",
+            err
+        );
+    });
+}
+
+#[test]
+fn virtq_large_payload_roundtrip() {
+    // Verify that larger payloads survive the virtq roundtrip without corruption.
+    with_rust_sandbox_cfg(SandboxConfiguration::default(), |mut sandbox| {
+        // 1KB string
+        let large_msg: String = "X".repeat(1024);
+        let res: String = sandbox.call("Echo", large_msg.clone()).unwrap();
+        assert_eq!(res, large_msg);
+
+        // 1KB byte array
+        let large_bytes = vec![0xABu8; 1024];
+        let res: Vec<u8> = sandbox
+            .call("SetByteArrayToZero", large_bytes.clone())
+            .unwrap();
+        assert_eq!(res.len(), 1024);
+        assert!(res.iter().all(|&b| b == 0));
+    });
+}
+
+#[test]
+fn virtq_multi_descriptor_h2g_two_slots() {
+    // Payload exceeds a single H2G slot (4096 - header), requiring 2
+    // descriptors (see the slot-math sketch after this test).
+    let mut cfg = SandboxConfiguration::default();
+    cfg.set_h2g_pool_pages(4);
+    with_rust_sandbox_cfg(cfg, |mut sandbox| {
+        let large_msg: String = "A".repeat(4200);
+        let res: String = sandbox.call("Echo", large_msg.clone()).unwrap();
+        assert_eq!(res, large_msg);
+    });
+}
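The descriptor counts these tests rely on follow from simple ceiling division. A minimal sketch, assuming each H2G slot is one 4096-byte page with an 8-byte per-slot header (figures taken from the boundary-test comment below) and ignoring FlatBuffer encoding overhead:

const SLOT: usize = 4096;
const HEADER: usize = 8;

/// Slots needed for `payload` raw bytes: ceil(payload / (SLOT - HEADER)).
fn slots_needed(payload: usize) -> usize {
    payload.div_ceil(SLOT - HEADER)
}

#[test]
fn h2g_slot_math() {
    assert_eq!(slots_needed(3900), 1); // fits one descriptor
    assert_eq!(slots_needed(4200), 2); // spills into a second
    assert_eq!(slots_needed(8200), 3); // spans three of the four pool slots
}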
+
+#[test]
+fn virtq_multi_descriptor_h2g_max_slots() {
+    // Payload spanning all available H2G pool slots.
+    let mut cfg = SandboxConfiguration::default();
+    cfg.set_h2g_pool_pages(4);
+    with_rust_sandbox_cfg(cfg, |mut sandbox| {
+        let large_msg: String = "B".repeat(8200);
+        let res: String = sandbox.call("Echo", large_msg.clone()).unwrap();
+        assert_eq!(res, large_msg);
+    });
+}
+
+#[test]
+fn virtq_multi_descriptor_h2g_byte_array() {
+    // Multi-descriptor with byte array arguments to test binary payloads.
+    let mut cfg = SandboxConfiguration::default();
+    cfg.set_h2g_pool_pages(8);
+    with_rust_sandbox_cfg(cfg, |mut sandbox| {
+        let large_bytes: Vec<u8> = (0..5000).map(|i| (i % 256) as u8).collect();
+        let res: Vec<u8> = sandbox
+            .call("SetByteArrayToZero", large_bytes.clone())
+            .unwrap();
+        assert_eq!(res.len(), 5000);
+        assert!(res.iter().all(|&b| b == 0));
+    });
+}
+
+#[test]
+fn virtq_multi_descriptor_h2g_boundary() {
+    // Payload exactly at single-slot capacity boundary.
+    // Header is 8 bytes, so a single slot fits exactly 4088 bytes of payload.
+    // The FlatBuffer encoding adds overhead, so we test near the boundary
+    // to verify no off-by-one errors.
+    let mut cfg = SandboxConfiguration::default();
+    cfg.set_h2g_pool_pages(4);
+    with_rust_sandbox_cfg(cfg, |mut sandbox| {
+        // This should fit in one descriptor (small overhead)
+        let msg_under: String = "C".repeat(3900);
+        let res: String = sandbox.call("Echo", msg_under.clone()).unwrap();
+        assert_eq!(res, msg_under);
+
+        // This should just barely spill into a second descriptor
+        let msg_over: String = "D".repeat(4100);
+        let res: String = sandbox.call("Echo", msg_over.clone()).unwrap();
+        assert_eq!(res, msg_over);
+    });
+}
+
+#[test]
+fn virtq_multi_descriptor_h2g_repeated_calls() {
+    // Multiple large calls in sequence to verify H2G refill works correctly
+    // after multi-descriptor consumption.
+    let mut cfg = SandboxConfiguration::default();
+    cfg.set_h2g_pool_pages(8);
+    with_rust_sandbox_cfg(cfg, |mut sandbox| {
+        for i in 0..5 {
+            let ch = char::from(b'A' + i as u8);
+            let msg: String = std::iter::repeat_n(ch, 4500).collect();
+            let res: String = sandbox.call("Echo", msg.clone()).unwrap();
+            assert_eq!(res, msg, "mismatch on call {i}");
+        }
+    });
+}
+
+/// Helper to create a sandbox with a "GetLargeResponse" host function
+/// that returns `size` bytes filled with 0xAB.
+fn sandbox_with_large_response(cfg: SandboxConfiguration) -> MultiUseSandbox {
+    let mut sandbox = UninitializedSandbox::new(
+        GuestBinary::FilePath(simple_guest_as_string().unwrap()),
+        Some(cfg),
+    )
+    .unwrap();
+    sandbox
+        .register("GetLargeResponse", |size: i32| -> Result<Vec<u8>> {
+            Ok(vec![0xABu8; size as usize])
+        })
+        .unwrap();
+    sandbox.evolve().unwrap()
+}
+
+#[test]
+fn virtq_large_g2h_response_with_hint() {
+    // Host function returns >4096 bytes. Guest uses a sized completion
+    // hint so the pool allocates multiple adjacent slots.
+    let mut cfg = SandboxConfiguration::default();
+    cfg.set_g2h_pool_pages(16);
+    let mut sandbox = sandbox_with_large_response(cfg);
+
+    // 8000 bytes of response payload. With FlatBuffer + header overhead,
+    // the wire size is ~8100 bytes. A hint of 3 * 4096 = 12288 is enough
+    // (see the sizing sketch after this test).
+    let hint = 3 * 4096i32;
+    let res: Vec<u8> = sandbox
+        .call("CallGetLargeResponseWithHint", (8000i32, hint))
+        .unwrap();
+    assert_eq!(res.len(), 8000);
+    assert!(res.iter().all(|&b| b == 0xAB));
+}
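One plausible way a caller could size the hint used above: round the expected wire size up to whole 4096-byte slots. The 256-byte overhead allowance below is an assumption (the test comment suggests ~100 bytes of FlatBuffer + header overhead), not a guarantee of the encoding:

const SLOT: usize = 4096;

/// Page-rounded completion hint for an expected response payload.
fn completion_hint(expected_payload: usize, overhead_allowance: usize) -> usize {
    (expected_payload + overhead_allowance).div_ceil(SLOT) * SLOT
}

#[test]
fn hint_sizing() {
    // 8000-byte response with a generous 256-byte allowance rounds up to
    // 3 slots = 12288, matching the `3 * 4096` hint used in
    // virtq_large_g2h_response_with_hint above.
    assert_eq!(completion_hint(8000, 256), 3 * SLOT);
}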
+
+#[test]
+fn virtq_large_g2h_response_too_large_without_hint() {
+    // Without a hint, the default 4096-byte completion buffer is used.
+    // A response >4096 bytes should trigger the host's "response too
+    // large" fallback error instead of a transport crash.
+    let mut cfg = SandboxConfiguration::default();
+    cfg.set_g2h_pool_pages(16);
+    let mut sandbox = sandbox_with_large_response(cfg);
+
+    let res = sandbox.call::<Vec<u8>>("CallGetLargeResponseDefault", 8000i32);
+    assert!(
+        res.is_err(),
+        "expected error for oversized response without hint"
+    );
+}
+
+#[test]
+fn virtq_large_g2h_response_boundary() {
+    // A response that fits in one page (with overhead) should work
+    // without needing a hint.
+    let mut cfg = SandboxConfiguration::default();
+    cfg.set_g2h_pool_pages(16);
+    let mut sandbox = sandbox_with_large_response(cfg);
+
+    // Small response that fits in the default 4096-byte buffer
+    let res: Vec<u8> = sandbox
+        .call("CallGetLargeResponseDefault", 1000i32)
+        .unwrap();
+    assert_eq!(res.len(), 1000);
+    assert!(res.iter().all(|&b| b == 0xAB));
+}
+
+#[test]
+fn virtq_large_g2h_response_after_log_backpressure() {
+    // Logs fill the G2H pool, then a large host response (with hint)
+    // needs multi-slot allocation. The backpressure path must drain
+    // completed log entries to free pool slots for the large completion.
+    let mut cfg = SandboxConfiguration::default();
+    cfg.set_g2h_pool_pages(16);
+    let mut sandbox = sandbox_with_large_response(cfg);
+
+    // Emit 20 log entries to consume pool slots, then request an 8 KB
+    // response with a 12 KB hint (3 upper-slab slots).
+    let hint = 3 * 4096i32;
+    let res: Vec<u8> = sandbox
+        .call("LogThenLargeResponse", (20i32, 8000i32, hint))
+        .unwrap();
+    assert_eq!(res.len(), 8000);
+    assert!(res.iter().all(|&b| b == 0xAB));
+}
diff --git a/src/hyperlight_libc/third_party/mimalloc b/src/hyperlight_libc/third_party/mimalloc
new file mode 160000
index 000000000..09a27098a
--- /dev/null
+++ b/src/hyperlight_libc/third_party/mimalloc
@@ -0,0 +1 @@
+Subproject commit 09a27098aa6e9286518bd9c74e6ffa7199c3f04e
diff --git a/src/tests/rust_guests/dummyguest/Cargo.lock b/src/tests/rust_guests/dummyguest/Cargo.lock
index 9cb6885bb..060a340fa 100644
--- a/src/tests/rust_guests/dummyguest/Cargo.lock
+++ b/src/tests/rust_guests/dummyguest/Cargo.lock
@@ -72,11 +72,17 @@ dependencies = [
  "syn",
 ]
 
+[[package]]
+name = "bytes"
+version = "1.11.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33"
+
 [[package]]
 name = "cc"
-version = "1.2.60"
+version = "1.2.62"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "43c5703da9466b66a946814e1adf53ea2c90f10063b86290cc9eb67ce3478a20"
+checksum = "a1dce859f0832a7d088c4f1119888ab94ef4b5d6795d1ce05afb7fe159d79f98"
 dependencies = [
  "find-msvc-tools",
  "shlex",
@@ -134,6 +140,12 @@ version = "0.1.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582"
 
+[[package]]
+name = "fixedbitset"
+version = "0.5.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99"
+
 [[package]]
 name = "flatbuffers"
 version = "25.12.19"
@@ -152,9 +164,9 @@ checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280"
 
 [[package]]
 name = "hashbrown"
-version = "0.17.0"
+version = "0.17.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4f467dd6dccf739c208452f8014c75c18bb8301b050ad1cfb27153803edb0f51"
+checksum = "ed5909b6e89a2db4456e54cd5f673791d7eca6732202bbf2a9cc504fe2f9b84a"
 
 [[package]]
 name = "hyperlight-common"
@@ -163,6 +175,8 @@ dependencies = [
  "anyhow",
  "bitflags",
  "bytemuck",
+ "bytes",
"bytes", + "fixedbitset", "flatbuffers", "log", "smallvec", @@ -176,6 +190,7 @@ name = "hyperlight-guest" version = "0.15.0" dependencies = [ "anyhow", + "bytemuck", "flatbuffers", "hyperlight-common", "hyperlight-guest-tracing", @@ -593,9 +608,9 @@ checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" [[package]] name = "winnow" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09dac053f1cd375980747450bfc7250c264eaae0583872e845c0c7cd578872b5" +checksum = "2ee1708bef14716a11bae175f579062d4554d95be2c6829f518df847b7b3fdd0" dependencies = [ "memchr", ] diff --git a/src/tests/rust_guests/simpleguest/Cargo.lock b/src/tests/rust_guests/simpleguest/Cargo.lock index 23475a643..7327a0332 100644 --- a/src/tests/rust_guests/simpleguest/Cargo.lock +++ b/src/tests/rust_guests/simpleguest/Cargo.lock @@ -72,11 +72,17 @@ dependencies = [ "syn", ] +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + [[package]] name = "cc" -version = "1.2.60" +version = "1.2.62" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43c5703da9466b66a946814e1adf53ea2c90f10063b86290cc9eb67ce3478a20" +checksum = "a1dce859f0832a7d088c4f1119888ab94ef4b5d6795d1ce05afb7fe159d79f98" dependencies = [ "find-msvc-tools", "shlex", @@ -126,6 +132,12 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" +[[package]] +name = "fixedbitset" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" + [[package]] name = "flatbuffers" version = "25.12.19" @@ -144,9 +156,9 @@ checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" [[package]] name = "hashbrown" -version = "0.17.0" +version = "0.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f467dd6dccf739c208452f8014c75c18bb8301b050ad1cfb27153803edb0f51" +checksum = "ed5909b6e89a2db4456e54cd5f673791d7eca6732202bbf2a9cc504fe2f9b84a" [[package]] name = "hyperlight-common" @@ -155,6 +167,8 @@ dependencies = [ "anyhow", "bitflags", "bytemuck", + "bytes", + "fixedbitset", "flatbuffers", "log", "smallvec", @@ -168,6 +182,7 @@ name = "hyperlight-guest" version = "0.15.0" dependencies = [ "anyhow", + "bytemuck", "flatbuffers", "hyperlight-common", "hyperlight-guest-tracing", @@ -598,9 +613,9 @@ checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" [[package]] name = "winnow" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09dac053f1cd375980747450bfc7250c264eaae0583872e845c0c7cd578872b5" +checksum = "2ee1708bef14716a11bae175f579062d4554d95be2c6829f518df847b7b3fdd0" dependencies = [ "memchr", ] diff --git a/src/tests/rust_guests/simpleguest/src/main.rs b/src/tests/rust_guests/simpleguest/src/main.rs index b6844a716..acb672f1c 100644 --- a/src/tests/rust_guests/simpleguest/src/main.rs +++ b/src/tests/rust_guests/simpleguest/src/main.rs @@ -49,11 +49,11 @@ use hyperlight_guest_bin::exception::arch::{Context, ExceptionInfo}; use hyperlight_guest_bin::guest_function::definition::{GuestFunc, GuestFunctionDefinition}; use hyperlight_guest_bin::guest_function::register::register_function; use 
 use hyperlight_guest_bin::host_comm::{
-    call_host_function, call_host_function_without_returning_result, get_host_return_value_raw,
-    print_output_with_host_print, read_n_bytes_from_user_memory,
+    call_host_function, call_host_function_with_hint, print_output_with_host_print,
+    read_n_bytes_from_user_memory,
 };
 use hyperlight_guest_bin::memory::malloc;
-use hyperlight_guest_bin::{GUEST_HANDLE, guest_function, guest_logger, host_function};
+use hyperlight_guest_bin::{guest_function, guest_logger, host_function};
 use log::{LevelFilter, error};
 use tracing::{Span, instrument};
@@ -380,6 +380,49 @@ fn echo(value: String) -> String {
     value
 }
 
+/// Calls a host function "GetLargeResponse" with an explicit response
+/// capacity hint. The `hint` parameter is the total completion buffer
+/// size in bytes.
+#[guest_function("CallGetLargeResponseWithHint")]
+fn call_get_large_response_with_hint(size: i32, hint: i32) -> Vec<u8> {
+    call_host_function_with_hint::<Vec<u8>>(
+        "GetLargeResponse",
+        Some(vec![ParameterValue::Int(size)]),
+        ReturnType::VecBytes,
+        hint as usize,
+    )
+    .expect("GetLargeResponse call failed")
+}
+
+/// Calls a host function "GetLargeResponse" WITHOUT a hint, using the
+/// default 4096-byte completion buffer.
+#[guest_function("CallGetLargeResponseDefault")]
+fn call_get_large_response_default(size: i32) -> Vec<u8> {
+    call_host_function::<Vec<u8>>(
+        "GetLargeResponse",
+        Some(vec![ParameterValue::Int(size)]),
+        ReturnType::VecBytes,
+    )
+    .expect("GetLargeResponse call failed")
+}
+
+/// Emits `log_count` log entries to fill the G2H queue, then calls
+/// "GetLargeResponse" with a sized hint. Tests that backpressure
+/// draining of logs frees pool slots for the large completion.
+#[guest_function("LogThenLargeResponse")]
+fn log_then_large_response(log_count: i32, size: i32, hint: i32) -> Vec<u8> {
+    for i in 0..log_count {
+        log::info!("backpressure log {}", i);
+    }
+    call_host_function_with_hint::<Vec<u8>>(
+        "GetLargeResponse",
+        Some(vec![ParameterValue::Int(size)]),
+        ReturnType::VecBytes,
+        hint as usize,
+    )
+    .expect("GetLargeResponse after logs failed")
+}
+
 #[guest_function("GetSizePrefixedBuffer")]
 fn get_size_prefixed_buffer(data: Vec<u8>) -> Vec<u8> {
     data
@@ -479,6 +522,13 @@ fn log_message(message: String, level: i32) {
     }
 }
 
+#[guest_function("LogMessageN")]
+fn log_message_n(count: i32) {
+    for i in 0..count {
+        log::info!("log entry {}", i);
+    }
+}
+
 #[guest_function("TriggerException")]
 fn trigger_exception() {
     // trigger an undefined instruction exception
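A guest could in principle choose between the two helpers above mechanically. This sketch is a hypothetical rule of thumb, assuming the plain call's default completion buffer is one 4096-byte page (as the host tests suggest) and allowing ~128 bytes of FlatBuffer/header overhead; neither figure is an API contract, and the tests above deliberately use a more generous 3-page hint:

/// Hypothetical helper choice for an expected response payload length.
fn completion_hint_for(expected_len: usize) -> Option<usize> {
    const PAGE: usize = 4096;
    const OVERHEAD: usize = 128; // assumed encoding allowance
    let wire = expected_len + OVERHEAD;
    if wire <= PAGE {
        None // call_host_function: the default one-page buffer suffices
    } else {
        // pass to call_host_function_with_hint, rounded up to whole pages
        Some(wire.div_ceil(PAGE) * PAGE)
    }
}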
@@ -975,53 +1025,6 @@ fn fuzz_guest_trace(max_depth: u32, msg: String) -> u32 {
     fuzz_traced_function(0, max_depth, &msg)
 }
 
-#[guest_function("CorruptOutputSizePrefix")]
-fn corrupt_output_size_prefix() -> i32 {
-    unsafe {
-        let peb_ptr = core::ptr::addr_of!(GUEST_HANDLE).read().peb().unwrap();
-        let output_stack_ptr = (*peb_ptr).output_stack.ptr as *mut u8;
-
-        // Write a fake stack entry with a ~4 GB size prefix (0xFFFF_FFFB + 4).
-        let buf = core::slice::from_raw_parts_mut(output_stack_ptr, 24);
-        buf[0..8].copy_from_slice(&24_u64.to_le_bytes());
-        buf[8..12].copy_from_slice(&0xFFFF_FFFBu32.to_le_bytes());
-        buf[12..16].copy_from_slice(&[0u8; 4]);
-        buf[16..24].copy_from_slice(&8_u64.to_le_bytes());
-
-        core::arch::asm!(
-            "out dx, eax",
-            "cli",
-            "hlt",
-            in("dx") hyperlight_common::outb::VmAction::Halt as u16,
-            in("eax") 0u32,
-            options(noreturn),
-        );
-    }
-}
-
-#[guest_function("CorruptOutputBackPointer")]
-fn corrupt_output_back_pointer() -> i32 {
-    unsafe {
-        let peb_ptr = core::ptr::addr_of!(GUEST_HANDLE).read().peb().unwrap();
-        let output_stack_ptr = (*peb_ptr).output_stack.ptr as *mut u8;
-
-        // Write a fake stack entry with back-pointer 0xDEAD (past stack pointer 24).
-        let buf = core::slice::from_raw_parts_mut(output_stack_ptr, 24);
-        buf[0..8].copy_from_slice(&24_u64.to_le_bytes());
-        buf[8..16].copy_from_slice(&[0u8; 8]);
-        buf[16..24].copy_from_slice(&0xDEAD_u64.to_le_bytes());
-
-        core::arch::asm!(
-            "out dx, eax",
-            "cli",
-            "hlt",
-            in("dx") hyperlight_common::outb::VmAction::Halt as u16,
-            in("eax") 0u32,
-            options(noreturn),
-        );
-    }
-}
-
 // Interprets the given guest function call as a host function call and dispatches it to the host.
 fn fuzz_host_function(func: FunctionCall) -> Result<Vec<u8>> {
     let mut params = func.parameters.unwrap();
@@ -1038,32 +1041,23 @@ fn fuzz_host_function(func: FunctionCall) -> Result<Vec<u8>> {
         }
     };
 
-    // Because we do not know at compile time the actual return type of the host function to be called
-    // we cannot use the `call_host_function` generic function.
-    // We need to use the `call_host_function_without_returning_result` function that does not retrieve the return
-    // value
-    call_host_function_without_returning_result(
-        &host_func_name,
-        Some(params),
-        func.expected_return_type,
-    )
-    .expect("failed to call host function");
-
-    let host_return = get_host_return_value_raw();
-    match host_return {
-        Ok(return_value) => match return_value {
-            ReturnValue::Int(i) => Ok(get_flatbuffer_result(i)),
-            ReturnValue::UInt(i) => Ok(get_flatbuffer_result(i)),
-            ReturnValue::Long(i) => Ok(get_flatbuffer_result(i)),
-            ReturnValue::ULong(i) => Ok(get_flatbuffer_result(i)),
-            ReturnValue::Float(i) => Ok(get_flatbuffer_result(i)),
-            ReturnValue::Double(i) => Ok(get_flatbuffer_result(i)),
-            ReturnValue::String(str) => Ok(get_flatbuffer_result(str.as_str())),
-            ReturnValue::Bool(bool) => Ok(get_flatbuffer_result(bool)),
-            ReturnValue::Void(()) => Ok(get_flatbuffer_result(())),
-            ReturnValue::VecBytes(byte) => Ok(get_flatbuffer_result(byte.as_slice())),
-        },
-        Err(e) => Err(e),
+    // Call the host function with dynamic return type. Since we don't
+    // know T at compile time, use ReturnValue as the return type and
+    // match on the result.
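+    // (Editor's note, an assumption about the API rather than a documented
+    // guarantee: this pattern works only if the host-comm layer can decode
+    // the wire response into the ReturnValue enum itself, so the generic
+    // parameter of call_host_function resolves to ReturnValue here.)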
+    let return_value: ReturnValue =
+        call_host_function(&host_func_name, Some(params), func.expected_return_type)?;
+
+    match return_value {
+        ReturnValue::Int(i) => Ok(get_flatbuffer_result(i)),
+        ReturnValue::UInt(i) => Ok(get_flatbuffer_result(i)),
+        ReturnValue::Long(i) => Ok(get_flatbuffer_result(i)),
+        ReturnValue::ULong(i) => Ok(get_flatbuffer_result(i)),
+        ReturnValue::Float(i) => Ok(get_flatbuffer_result(i)),
+        ReturnValue::Double(i) => Ok(get_flatbuffer_result(i)),
+        ReturnValue::String(str) => Ok(get_flatbuffer_result(str.as_str())),
+        ReturnValue::Bool(bool) => Ok(get_flatbuffer_result(bool)),
+        ReturnValue::Void(()) => Ok(get_flatbuffer_result(())),
+        ReturnValue::VecBytes(byte) => Ok(get_flatbuffer_result(byte.as_slice())),
+    }
 }
diff --git a/src/tests/rust_guests/witguest/Cargo.lock b/src/tests/rust_guests/witguest/Cargo.lock
index f9a6ffa6c..7e7516cd0 100644
--- a/src/tests/rust_guests/witguest/Cargo.lock
+++ b/src/tests/rust_guests/witguest/Cargo.lock
@@ -122,11 +122,17 @@ dependencies = [
  "syn",
 ]
 
+[[package]]
+name = "bytes"
+version = "1.11.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33"
+
 [[package]]
 name = "cc"
-version = "1.2.61"
+version = "1.2.62"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d16d90359e986641506914ba71350897565610e87ce0ad9e6f28569db3dd5c6d"
+checksum = "a1dce859f0832a7d088c4f1119888ab94ef4b5d6795d1ce05afb7fe159d79f98"
 dependencies = [
  "find-msvc-tools",
  "shlex",
@@ -205,6 +211,12 @@ version = "0.1.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582"
 
+[[package]]
+name = "fixedbitset"
+version = "0.5.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99"
+
 [[package]]
 name = "flatbuffers"
 version = "25.12.19"
@@ -229,9 +241,9 @@ checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280"
 
 [[package]]
 name = "hashbrown"
-version = "0.17.0"
+version = "0.17.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4f467dd6dccf739c208452f8014c75c18bb8301b050ad1cfb27153803edb0f51"
+checksum = "ed5909b6e89a2db4456e54cd5f673791d7eca6732202bbf2a9cc504fe2f9b84a"
 dependencies = [
  "foldhash",
  "serde",
@@ -245,6 +257,8 @@ dependencies = [
  "anyhow",
  "bitflags",
  "bytemuck",
+ "bytes",
+ "fixedbitset",
  "flatbuffers",
  "log",
  "smallvec",
@@ -285,6 +299,7 @@ name = "hyperlight-guest"
 version = "0.15.0"
 dependencies = [
  "anyhow",
+ "bytemuck",
  "flatbuffers",
  "hyperlight-common",
  "hyperlight-guest-tracing",