Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 99 additions & 21 deletions src/devices/src/virtio/vsock/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ pub struct VsockPacket {
hdr: *mut u8,
buf: Option<*mut u8>,
buf_size: usize,
owned_buf: Option<Vec<u8>>,
}

fn get_host_address<T: GuestMemory>(
Expand Down Expand Up @@ -224,6 +225,7 @@ impl VsockPacket {
.map_err(VsockError::GuestMemoryMmap)?,
buf: None,
buf_size: 0,
owned_buf: None,
};

// No point looking for a data/buffer descriptor, if the packet is zero-lengthed.
Expand All @@ -237,25 +239,92 @@ impl VsockPacket {
return Err(VsockError::InvalidPktLen(pkt.len()));
}

// If the packet header showed a non-zero length, there should be a data descriptor here.
let buf_desc = head.next_descriptor().ok_or(VsockError::BufDescMissing)?;
let head_data_size = head.len as usize - VSOCK_PKT_HDR_SIZE;

// Single combined descriptor: header + data with no next descriptor.
if !head.has_next() {
if head_data_size == 0 {
return Err(VsockError::BufDescMissing);
}
let buf_addr = head
.addr
.checked_add(VSOCK_PKT_HDR_SIZE as u64)
.ok_or(VsockError::GuestMemoryBounds)?;
pkt.buf_size = head_data_size;
pkt.buf = Some(
get_host_address(head.mem, buf_addr, pkt.buf_size)
.map_err(VsockError::GuestMemoryMmap)?,
);
if pkt.buf_size < pkt.len() as usize {
return Err(VsockError::BufDescTooSmall);
}
return Ok(pkt);
}

// TX data should be read-only.
let buf_desc = head.next_descriptor().ok_or(VsockError::BufDescMissing)?;
if buf_desc.is_write_only() {
return Err(VsockError::UnreadableDescriptor);
}

// The data buffer should be large enough to fit the size of the data, as described by
// the header descriptor.
if buf_desc.len < pkt.len() {
// Classic two-descriptor case: header in first, all data in second. Zero-copy.
if head_data_size == 0 && !buf_desc.has_next() {
if buf_desc.len < pkt.len() {
return Err(VsockError::BufDescTooSmall);
}
pkt.buf_size = buf_desc.len as usize;
pkt.buf = Some(
get_host_address(buf_desc.mem, buf_desc.addr, pkt.buf_size)
.map_err(VsockError::GuestMemoryMmap)?,
);
return Ok(pkt);
}

// Multiple data regions: inline data after header and/or multiple data descriptors.
// Copy into a contiguous owned buffer.
let mut owned_buf: Vec<u8> = Vec::with_capacity(pkt.len() as usize);

if head_data_size > 0 {
let buf_addr = head
.addr
.checked_add(VSOCK_PKT_HDR_SIZE as u64)
.ok_or(VsockError::GuestMemoryBounds)?;
let src = get_host_address(head.mem, buf_addr, head_data_size)
.map_err(VsockError::GuestMemoryMmap)?;
owned_buf.extend_from_slice(unsafe {
std::slice::from_raw_parts(src as *const u8, head_data_size)
});
}

// First data descriptor (already validated as readable above).
if buf_desc.len > 0 {
let src = get_host_address(buf_desc.mem, buf_desc.addr, buf_desc.len as usize)
.map_err(VsockError::GuestMemoryMmap)?;
owned_buf.extend_from_slice(unsafe {
std::slice::from_raw_parts(src as *const u8, buf_desc.len as usize)
});
}

let mut next = buf_desc.next_descriptor();
while let Some(desc) = next {
if desc.is_write_only() {
return Err(VsockError::UnreadableDescriptor);
}
if desc.len > 0 {
let src = get_host_address(desc.mem, desc.addr, desc.len as usize)
.map_err(VsockError::GuestMemoryMmap)?;
owned_buf.extend_from_slice(unsafe {
std::slice::from_raw_parts(src as *const u8, desc.len as usize)
});
}
next = desc.next_descriptor();
}

if (owned_buf.len() as u32) < pkt.len() {
return Err(VsockError::BufDescTooSmall);
}

pkt.buf_size = buf_desc.len as usize;
pkt.buf = Some(
get_host_address(buf_desc.mem, buf_desc.addr, pkt.buf_size)
.map_err(VsockError::GuestMemoryMmap)?,
);
pkt.buf_size = owned_buf.len();
pkt.owned_buf = Some(owned_buf);

Ok(pkt)
}
Expand All @@ -281,6 +350,7 @@ impl VsockPacket {
.map_err(VsockError::GuestMemoryMmap)?,
buf: None,
buf_size: 0,
owned_buf: None,
};

// Starting from Linux 6.2 the virtio-vsock driver can use a single descriptor for both
Expand Down Expand Up @@ -331,11 +401,15 @@ impl VsockPacket {
/// (and often is) larger than the length of the packet data. The packet data length
/// is stored in the packet header, and accessible via `VsockPacket::len()`.
pub fn buf(&self) -> Option<&[u8]> {
self.buf.map(|ptr| {
// This is safe since bound checks have already been performed when creating the packet
// from the virtq descriptor.
unsafe { std::slice::from_raw_parts(ptr as *const u8, self.buf_size) }
})
if let Some(ref owned) = self.owned_buf {
Some(owned.as_slice())
} else {
self.buf.map(|ptr| {
// This is safe since bound checks have already been performed when creating the
// packet from the virtq descriptor.
unsafe { std::slice::from_raw_parts(ptr as *const u8, self.buf_size) }
})
}
}

/// Provides in-place, byte-slice, mutable access to the vsock packet data buffer.
Expand All @@ -346,11 +420,15 @@ impl VsockPacket {
/// (and often is) larger than the length of the packet data. The packet data length
/// is stored in the packet header, and accessible via `VsockPacket::len()`.
pub fn buf_mut(&mut self) -> Option<&mut [u8]> {
self.buf.map(|ptr| {
// This is safe since bound checks have already been performed when creating the packet
// from the virtq descriptor.
unsafe { std::slice::from_raw_parts_mut(ptr, self.buf_size) }
})
if let Some(ref mut owned) = self.owned_buf {
Some(owned.as_mut_slice())
} else {
self.buf.map(|ptr| {
// This is safe since bound checks have already been performed when creating the
// packet from the virtq descriptor.
unsafe { std::slice::from_raw_parts_mut(ptr, self.buf_size) }
})
Comment thread
slp marked this conversation as resolved.
}
}

pub fn src_cid(&self) -> u64 {
Expand Down
Loading