# Patch summary (commentary only: everything before the first "diff --git"
# header is not part of any hunk).
#
# biofuse/encoder_host.py
#   StreamHandle.read now probes self._encoder.try_cached_read(off, size)
#   right after the first closed-handle check and returns the result when it
#   is not None, before self._lock is acquired. A None result falls through
#   to the existing locked path unchanged.
#   NOTE(review): the fast path relies on try_cached_read being thread-safe
#   and not advancing the encoder's iterator, as the in-code comment states.
#   That contract is not visible in this patch; presumably it comes from the
#   vcztools revision pinned below -- confirm before merging.
#
# tests/test_encoder_host.py
#   _FakeEncoder gains try_cached_read_calls, cached_payload and a
#   try_cached_read method that records each call and returns either None
#   (cached_payload unset) or cached_payload[off : off + size].
#   TestStreamHandleCachedFastPath adds four tests: cache hit, cache miss
#   falling back to read, four concurrent cache hits, and read after aclose
#   raising EIO without touching the encoder.
#
# uv.lock
#   vcztools pin moves from 0.1.3.dev353 to 0.1.3.dev356 (new git revision).
diff --git a/biofuse/encoder_host.py b/biofuse/encoder_host.py
index e9f5f13..13feb4e 100644
--- a/biofuse/encoder_host.py
+++ b/biofuse/encoder_host.py
@@ -83,6 +83,14 @@ def __init__(
     async def read(self, off: int, size: int) -> bytes:
         if self._closed:
             raise OSError(errno.EIO, "stream handle is closed")
+        # Fast path: serve from the encoder's in-memory cache without
+        # taking the lock or dispatching to a worker thread.
+        # ``try_cached_read`` is documented thread-safe and never
+        # advances the iterator, so it is safe to call concurrently
+        # with an in-flight slow-path ``encoder.read`` on another task.
+        cached = self._encoder.try_cached_read(off, size)
+        if cached is not None:
+            return cached
         async with self._lock:
             if self._closed:
                 raise OSError(errno.EIO, "stream handle is closed")
diff --git a/tests/test_encoder_host.py b/tests/test_encoder_host.py
index 2e4f117..35fd41f 100644
--- a/tests/test_encoder_host.py
+++ b/tests/test_encoder_host.py
@@ -74,8 +74,13 @@ def __init__(self) -> None:
         self.release_read = threading.Event()
         self.entered_read = threading.Event()
         self.read_calls: list[tuple[int, int]] = []
+        self.try_cached_read_calls: list[tuple[int, int]] = []
         self.close_calls = 0
         self._payload = b"X"
+        # When set, ``try_cached_read`` returns a slice of this buffer
+        # so the StreamHandle fast path can be exercised. Default ``None``
+        # means cache miss — every call falls back to ``read``.
+        self.cached_payload: bytes | None = None
 
     def set_payload(self, body: bytes) -> None:
         self._payload = body
@@ -86,6 +91,12 @@ def read(self, off: int, size: int) -> bytes:
         self.release_read.wait()
         return self._payload
 
+    def try_cached_read(self, off: int, size: int) -> bytes | None:
+        self.try_cached_read_calls.append((off, size))
+        if self.cached_payload is None:
+            return None
+        return self.cached_payload[off : off + size]
+
     def close(self) -> None:
         self.close_calls += 1
 
@@ -140,6 +151,63 @@ async def test_read_after_close_raises_eio(self):
         assert excinfo.value.errno == errno.EIO
 
 
+class TestStreamHandleCachedFastPath:
+    async def test_cache_hit_returns_bytes_without_thread_dispatch(self):
+        encoder = _FakeEncoder()
+        encoder.cached_payload = b"hello"
+        # release_read deliberately not set: if the fast path doesn't
+        # short-circuit, the slow-path worker thread will block forever
+        # and the test will hang.
+        handle = encoder_host.StreamHandle(encoder)
+        try:
+            got = await handle.read(0, 5)
+            assert got == b"hello"
+            assert encoder.try_cached_read_calls == [(0, 5)]
+            assert encoder.read_calls == []
+            assert not encoder.entered_read.is_set()
+        finally:
+            encoder.release_read.set()
+            await handle.aclose()
+
+    async def test_cache_miss_falls_back_to_slow_path(self):
+        encoder = _FakeEncoder()
+        encoder.set_payload(b"world")
+        encoder.release_read.set()
+        handle = encoder_host.StreamHandle(encoder)
+        try:
+            got = await handle.read(0, 5)
+            assert got == b"world"
+            assert encoder.try_cached_read_calls == [(0, 5)]
+            assert encoder.read_calls == [(0, 5)]
+        finally:
+            await handle.aclose()
+
+    async def test_cache_hit_does_not_enter_encoder_read(self):
+        encoder = _FakeEncoder()
+        encoder.cached_payload = b"ABCDEFGH"
+        handle = encoder_host.StreamHandle(encoder)
+        try:
+            async with trio.open_nursery() as nursery:
+                for off in range(4):
+                    nursery.start_soon(handle.read, off, 2)
+            assert len(encoder.try_cached_read_calls) == 4
+            assert encoder.read_calls == []
+        finally:
+            encoder.release_read.set()
+            await handle.aclose()
+
+    async def test_cache_hit_after_close_raises_eio_without_calling_encoder(self):
+        encoder = _FakeEncoder()
+        encoder.cached_payload = b"never seen"
+        encoder.release_read.set()
+        handle = encoder_host.StreamHandle(encoder)
+        await handle.aclose()
+        with pytest.raises(OSError, match="stream handle is closed") as excinfo:
+            await handle.read(0, 1)
+        assert excinfo.value.errno == errno.EIO
+        assert encoder.try_cached_read_calls == []
+
+
 class TestStreamHandleSerialisation:
     async def test_concurrent_reads_serialise_on_one_handle(self):
         """Two concurrent ``handle.read`` calls must not enter the
diff --git a/uv.lock b/uv.lock
index d769d41..3b2d3c5 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1737,8 +1737,8 @@ all = [
 
 [[package]]
 name = "vcztools"
-version = "0.1.3.dev353"
-source = { git = "https://github.com/sgkit-dev/vcztools.git?rev=main#725a92378115025d5d6e77bd4b79da3708f9d087" }
+version = "0.1.3.dev356"
+source = { git = "https://github.com/sgkit-dev/vcztools.git?rev=main#11ead2fe565c81dc7b8d024dae9ce373a45be4ba" }
 dependencies = [
     { name = "click" },
     { name = "humanfriendly" },