Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions vortex-ipc/src/messages/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,21 @@ enum State {

/// Outcome of a single [`MessageDecoder::read_next`] call.
#[derive(Debug)]
pub enum PollRead {
/// A complete message was decoded.
Some(DecoderMessage),
/// The decoder needs more data to make progress.
///
/// The inner value is the **total** number of bytes the buffer should contain, _not_ the
/// incremental number of bytes still needed. Callers should:
///
/// 1. Resize the buffer to this length.
/// 2. Fill the buffer completely (handling partial reads as needed).
/// 3. Only then call [`MessageDecoder::read_next`] again.
///
/// The decoder checks [`bytes::Buf::remaining`] to determine available data, which for
/// [`bytes::BytesMut`] returns the buffer length regardless of how many bytes were actually
/// written. Calling `read_next` before the buffer is fully populated will cause the decoder
/// to read garbage data.
NeedMore(usize),
}

Expand Down
108 changes: 83 additions & 25 deletions vortex-ipc/src/messages/reader_async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use futures::Stream;
use pin_project_lite::pin_project;
use vortex_array::session::ArrayRegistry;
use vortex_error::VortexResult;
use vortex_error::vortex_err;

use crate::messages::DecoderMessage;
use crate::messages::MessageDecoder;
Expand All @@ -24,7 +25,7 @@ pin_project! {
read: R,
buffer: BytesMut,
decoder: MessageDecoder,
bytes_read: usize,
state: ReadState,
}
}

Expand All @@ -34,40 +35,97 @@ impl<R> AsyncMessageReader<R> {
read,
buffer: BytesMut::new(),
decoder: MessageDecoder::new(registry),
bytes_read: 0,
state: ReadState::default(),
}
}
}

/// Tracks where the reader is within a (possibly multi-poll) read operation.
#[derive(Default)]
enum ReadState {
    /// No fill is in flight; the next step is to ask the decoder what it needs.
    #[default]
    AwaitingDecoder,
    /// The buffer is being populated from the underlying reader.
    ///
    /// A single `poll_read` may deliver fewer bytes than requested (a partial read), which is
    /// common on network connections. This variant therefore survives across `poll_next` calls
    /// and is only left, back to [`Self::AwaitingDecoder`], once the buffer is entirely filled.
    Filling {
        /// Running count of bytes already written into the buffer during this fill.
        total_bytes_read: usize,
    },
}

/// What happened when the reader was polled to populate the buffer.
enum FillResult {
    /// Every byte of the buffer has now been written.
    Filled,
    /// Some bytes arrived, but the buffer is not yet full (a partial read).
    Pending,
    /// The reader reported end-of-stream before any byte of this fill was read.
    Eof,
}

/// Performs one `poll_read` into the not-yet-written tail of `buffer`.
///
/// `total_bytes_read` is the number of bytes at the front of `buffer` that earlier calls have
/// already populated; it is advanced by however many bytes this call reads.
///
/// Resolves to:
///
/// * [`FillResult::Filled`] once `total_bytes_read` reaches `buffer.len()`.
/// * [`FillResult::Pending`] when bytes were read but the buffer is still short.
/// * [`FillResult::Eof`] when the reader returns zero bytes and nothing had been read yet.
/// * An error when the reader fails, or returns zero bytes after a partial read.
///
/// Returns `Poll::Pending` when the underlying reader is not ready.
fn poll_fill_buffer<R: AsyncRead>(
    read: Pin<&mut R>,
    buffer: &mut [u8],
    total_bytes_read: &mut usize,
    cx: &mut Context<'_>,
) -> Poll<VortexResult<FillResult>> {
    let newly_read = ready!(read.poll_read(cx, &mut buffer[*total_bytes_read..]))?;

    // A zero-byte read is the reader's way of signalling EOF.
    if newly_read == 0 {
        // Nothing read so far for this fill: report a clean end of stream.
        if *total_bytes_read == 0 {
            return Poll::Ready(Ok(FillResult::Eof));
        }
        // The stream ended part-way through the requested bytes.
        return Poll::Ready(Err(vortex_err!(
            "unexpected EOF during partial read: read {total_bytes_read} of {} expected bytes",
            buffer.len()
        )));
    }

    *total_bytes_read += newly_read;

    if *total_bytes_read != buffer.len() {
        debug_assert!(*total_bytes_read < buffer.len());
        return Poll::Ready(Ok(FillResult::Pending));
    }

    Poll::Ready(Ok(FillResult::Filled))
}

impl<R: AsyncRead> Stream for AsyncMessageReader<R> {
type Item = VortexResult<DecoderMessage>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.project();
loop {
match this.decoder.read_next(this.buffer)? {
PollRead::Some(msg) => return Poll::Ready(Some(Ok(msg))),
PollRead::NeedMore(nbytes) => {
this.buffer.resize(nbytes, 0x00);

match ready!(
this.read
.as_mut()
.poll_read(cx, &mut this.buffer.as_mut()[*this.bytes_read..])
) {
Ok(0) => {
// End of file
return Poll::Ready(None);
}
Ok(nbytes) => {
*this.bytes_read += nbytes;
// If we've finished the read operation, then we continue the loop
// and the decoder should present us with a new response.
if *this.bytes_read == nbytes {
*this.bytes_read = 0;
}
}
Err(e) => return Poll::Ready(Some(Err(e.into()))),
match this.state {
ReadState::AwaitingDecoder => match this.decoder.read_next(this.buffer)? {
PollRead::Some(msg) => return Poll::Ready(Some(Ok(msg))),
PollRead::NeedMore(new_len) => {
this.buffer.resize(new_len, 0x00);
*this.state = ReadState::Filling {
total_bytes_read: 0,
};
}
},
ReadState::Filling { total_bytes_read } => {
match ready!(poll_fill_buffer(
this.read.as_mut(),
this.buffer,
total_bytes_read,
cx
)) {
Err(e) => return Poll::Ready(Some(Err(e))),
Ok(FillResult::Eof) => return Poll::Ready(None),
Ok(FillResult::Filled) => *this.state = ReadState::AwaitingDecoder,
Ok(FillResult::Pending) => {}
}
}
}
Expand Down
77 changes: 77 additions & 0 deletions vortex-ipc/src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,11 @@ impl Stream for ArrayStreamIPCBytes {

#[cfg(test)]
mod test {
use std::io;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;

use futures::io::Cursor;
use vortex_array::IntoArray as _;
use vortex_array::ToCanonical;
Expand Down Expand Up @@ -232,4 +237,76 @@ mod test {
result.as_slice::<i32>()
);
}

/// Test-only reader adapter that caps how many bytes each read may return, imitating the
/// short reads seen on network connections.
struct ChunkedReader<R> {
    /// The wrapped reader that actually supplies the bytes.
    inner: R,
    /// Upper bound on the number of bytes handed to `inner` per read.
    chunk_size: usize,
}

impl<R: AsyncRead + Unpin> AsyncRead for ChunkedReader<R> {
    /// Forwards the read to `inner`, exposing at most `chunk_size` bytes of `buf`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        // `R: Unpin` makes the whole wrapper `Unpin`, so a plain `&mut` is available.
        let this = self.get_mut();
        let limit = buf.len().min(this.chunk_size);
        let window = &mut buf[..limit];
        Pin::new(&mut this.inner).poll_read(cx, window)
    }
}

/// Round-trips an `i32` array through IPC while the reader hands out at most 3 bytes per read.
#[tokio::test]
async fn test_async_stream_chunked() {
    let session = ArraySession::default();
    let input = buffer![1i32, 2, 3, 4, 5, 6, 7, 8, 9, 10].into_array();
    let encoded = input
        .to_array_stream()
        .into_ipc()
        .collect_to_buffer()
        .await
        .unwrap();

    // Force partial reads: every `poll_read` sees a window of no more than 3 bytes.
    let throttled = ChunkedReader {
        inner: Cursor::new(encoded),
        chunk_size: 3,
    };

    let ipc_reader = AsyncIPCReader::try_new(throttled, session.registry().clone())
        .await
        .unwrap();

    let decoded = ipc_reader.read_all().await.unwrap().to_primitive();
    assert_eq!(
        &[1i32, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        decoded.as_slice::<i32>()
    );
}

/// Stress-tests partial-read handling by letting the reader return only one byte at a time.
#[tokio::test]
async fn test_async_stream_single_byte_chunks() {
    let session = ArraySession::default();
    let input = buffer![42i64, -1, 0, i64::MAX, i64::MIN].into_array();
    let encoded = input
        .to_array_stream()
        .into_ipc()
        .collect_to_buffer()
        .await
        .unwrap();

    // The smallest possible non-empty read window: one byte per `poll_read`.
    let throttled = ChunkedReader {
        inner: Cursor::new(encoded),
        chunk_size: 1,
    };

    let ipc_reader = AsyncIPCReader::try_new(throttled, session.registry().clone())
        .await
        .unwrap();

    let decoded = ipc_reader.read_all().await.unwrap().to_primitive();
    assert_eq!(
        &[42i64, -1, 0, i64::MAX, i64::MIN],
        decoded.as_slice::<i64>()
    );
}
}
Loading