From acf1c6328124ce75e4a9f41422bf8334e76b83de Mon Sep 17 00:00:00 2001 From: Guillermo Perez Date: Mon, 16 Mar 2026 10:18:13 +0100 Subject: [PATCH] multi-get --- meta_memcache_socket.pyi | 18 +++ src/memcache_socket.rs | 70 +++++++++- tests/test_memcache_socket.py | 235 ++++++++++++++++++++++++++++++++++ 3 files changed, 322 insertions(+), 1 deletion(-) diff --git a/meta_memcache_socket.pyi b/meta_memcache_socket.pyi index 95b29a0..f5f962d 100644 --- a/meta_memcache_socket.pyi +++ b/meta_memcache_socket.pyi @@ -349,3 +349,21 @@ class MemcacheSocket: key: Union[str, bytes], request_flags: Optional[RequestFlags] = None, ) -> Union[Value, Success, Miss, NotStored, Conflict]: ... + + # Batch operations + def meta_multiget( + self, + keys: list[Union[str, bytes]], + request_flags: Optional[RequestFlags] = None, + ) -> list[Union[Value, Success, Miss, NotStored, Conflict]]: + """ + Send multiple meta get commands and return all responses in one batch. + + Builds all commands into one buffer, sends in a single operation, then + receives all responses in a tight Rust loop. GIL is released during + all socket I/O. Returns a list of responses in the same order as keys. + + :param keys: List of keys to get + :param request_flags: The flags to use for all keys + """ + ... diff --git a/src/memcache_socket.rs b/src/memcache_socket.rs index d399119..f382b33 100644 --- a/src/memcache_socket.rs +++ b/src/memcache_socket.rs @@ -5,7 +5,7 @@ use log::warn; use pyo3::BoundObject; use pyo3::exceptions::{PyConnectionError, PyTimeoutError, PyValueError}; use pyo3::prelude::*; -use pyo3::types::PyBytes; +use pyo3::types::{PyBytes, PyList}; use crate::constants::*; use crate::encode_key::extract_key; @@ -807,4 +807,72 @@ impl MemcacheSocket { CmdResult::Response((header, value_data)) => self.make_response(py, header, value_data), } } + + // ----------------------------------------------------------------------- + // Batch: meta_multiget (pipelined multi-key get in one Rust call) + // ----------------------------------------------------------------------- + + /// Send multiple meta get commands and return all responses in one batch. + /// + /// Builds all commands into one buffer, sends in a single operation, then + /// receives all responses in a tight Rust loop. GIL is released during + /// all socket I/O. Returns a list of responses in the same order as keys. + #[pyo3(signature = (keys, request_flags=None))] + pub fn meta_multiget( + &mut self, + py: Python<'_>, + keys: &Bound<'_, PyList>, + request_flags: Option<&RequestFlags>, + ) -> PyResult>> { + let num_keys = keys.len(); + if num_keys == 0 { + return Ok(Vec::new()); + } + + // Build all mg commands into one buffer (GIL held for key extraction) + let mut cmd_buf: Vec = Vec::with_capacity(num_keys * 64); + for i in 0..num_keys { + let key_obj = keys.get_item(i)?; + let cmd = self.build_cmd(b"mg", &key_obj, None, request_flags)?; + cmd_buf.extend_from_slice(&cmd.buf); + } + + // Send all commands, then receive all responses (GIL released). + // Value data is copied to owned Vecs so buffer resets during + // subsequent reads don't invalidate earlier responses. + let io = &mut self.io; + let raw_responses: Vec<(ParsedHeader, Option>)> = py + .detach(|| { + send_all(io.fd, &cmd_buf, io.timeout_ms)?; + let mut responses = Vec::with_capacity(num_keys); + for _ in 0..num_keys { + let header = io.get_header()?; + let value = if header.response_type == Some(RESPONSE_VALUE) { + let size = header.size.unwrap_or(0) as usize; + let vd = io.ensure_value(size)?; + Some(match vd { + ValueData::InBuffer(start) => { + // Convert to vector, as the buffer will be + // reused and data will be overwritten + io.buf[start..start + size].to_vec() + } + ValueData::Allocated(data) => data, + }) + } else { + None + }; + responses.push((header, value)); + } + Ok(responses) + }) + .map_err(|e: std::io::Error| socket_err_io("Error in meta_multiget", e))?; + + // Convert raw responses to Python objects (GIL held) + let mut results = Vec::with_capacity(num_keys); + for (header, value_bytes) in raw_responses { + let value_data = value_bytes.map(ValueData::Allocated); + results.push(self.make_response(py, header, value_data)?); + } + Ok(results) + } } diff --git a/tests/test_memcache_socket.py b/tests/test_memcache_socket.py index 9c44bb4..879a7ba 100644 --- a/tests/test_memcache_socket.py +++ b/tests/test_memcache_socket.py @@ -936,6 +936,241 @@ def test_set_socket_updates_timeout(self, socket_pair): d.close() +# --- meta_multiget (batch) --- + + +class TestMetaMultiget: + def test_empty_keys(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + results = ms.meta_multiget([]) + assert results == [] + + def test_single_key_hit(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + b.sendall(b"VA 5 c1\r\nhello\r\n") + results = ms.meta_multiget([b"key1"], RequestFlags(return_cas_token=True, return_value=True)) + assert len(results) == 1 + assert isinstance(results[0], Value) + assert results[0].value == b"hello" + assert results[0].flags.cas_token == 1 + + def test_single_key_miss(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + b.sendall(b"EN\r\n") + results = ms.meta_multiget([b"key1"]) + assert len(results) == 1 + assert isinstance(results[0], Miss) + + def test_multiple_keys_all_hits(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + flags = RequestFlags(return_cas_token=True, return_value=True) + b.sendall( + b"VA 3 c1\r\nfoo\r\n" + b"VA 3 c2\r\nbar\r\n" + b"VA 3 c3\r\nbaz\r\n" + ) + results = ms.meta_multiget([b"k1", b"k2", b"k3"], flags) + assert len(results) == 3 + assert all(isinstance(r, Value) for r in results) + assert results[0].value == b"foo" + assert results[0].flags.cas_token == 1 + assert results[1].value == b"bar" + assert results[1].flags.cas_token == 2 + assert results[2].value == b"baz" + assert results[2].flags.cas_token == 3 + + def test_multiple_keys_all_misses(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + b.sendall(b"EN\r\n" * 5) + results = ms.meta_multiget([b"k1", b"k2", b"k3", b"k4", b"k5"]) + assert len(results) == 5 + assert all(isinstance(r, Miss) for r in results) + + def test_mixed_hits_and_misses(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + flags = RequestFlags(return_value=True) + b.sendall( + b"VA 3 f1\r\nfoo\r\n" + b"EN\r\n" + b"VA 3 f2\r\nbar\r\n" + b"EN\r\n" + b"VA 3 f3\r\nbaz\r\n" + ) + results = ms.meta_multiget([b"k1", b"k2", b"k3", b"k4", b"k5"], flags) + assert len(results) == 5 + assert isinstance(results[0], Value) + assert results[0].value == b"foo" + assert isinstance(results[1], Miss) + assert isinstance(results[2], Value) + assert results[2].value == b"bar" + assert isinstance(results[3], Miss) + assert isinstance(results[4], Value) + assert results[4].value == b"baz" + + def test_verifies_wire_format(self, socket_pair): + """meta_multiget sends correct wire commands.""" + a, b = socket_pair + ms = MemcacheSocket(a) + flags = RequestFlags(return_cas_token=True, cache_ttl=300) + b.sendall(b"EN\r\nEN\r\nEN\r\n") + ms.meta_multiget([b"key1", b"key2", b"key3"], flags) + data = b.recv(4096) + assert data == ( + b"mg key1 c T300\r\n" + b"mg key2 c T300\r\n" + b"mg key3 c T300\r\n" + ) + + def test_string_keys(self, socket_pair): + """String keys should work (extracted as UTF-8 bytes).""" + a, b = socket_pair + ms = MemcacheSocket(a) + flags = RequestFlags(return_value=True) + b.sendall(b"VA 2\r\nhi\r\nEN\r\n") + results = ms.meta_multiget(["mykey", "other"], flags) + assert len(results) == 2 + assert isinstance(results[0], Value) + assert results[0].value == b"hi" + assert isinstance(results[1], Miss) + # Verify wire format + data = b.recv(4096) + assert data == b"mg mykey v\r\nmg other v\r\n" + + def test_mixed_key_types(self, socket_pair): + """Mix of str and bytes keys.""" + a, b = socket_pair + ms = MemcacheSocket(a) + b.sendall(b"EN\r\nEN\r\n") + results = ms.meta_multiget(["strkey", b"byteskey"]) + assert len(results) == 2 + data = b.recv(4096) + assert data == b"mg strkey\r\nmg byteskey\r\n" + + def test_empty_key_raises(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + with pytest.raises(ValueError): + ms.meta_multiget([b"good", b"", b"also_good"]) + + def test_large_value_in_batch(self, socket_pair): + """Values larger than buffer should work in batch mode.""" + a, b = socket_pair + ms = MemcacheSocket(a, buffer_size=100) + payload = b"x" * 200 + b.sendall( + b"VA 200\r\n" + payload + b"\r\n" + b"EN\r\n" + b"VA 3\r\nfoo\r\n" + ) + results = ms.meta_multiget([b"k1", b"k2", b"k3"], RequestFlags(return_value=True)) + assert len(results) == 3 + assert isinstance(results[0], Value) + assert results[0].value == payload + assert isinstance(results[1], Miss) + assert isinstance(results[2], Value) + assert results[2].value == b"foo" + + def test_small_buffer_many_keys(self, socket_pair): + """Stress buffer reset logic with many keys and small buffer.""" + a, b = socket_pair + ms = MemcacheSocket(a, buffer_size=32) + num_keys = 50 + keys = [f"k{i}".encode() for i in range(num_keys)] + # Alternate hits and misses + response = b"" + for i in range(num_keys): + if i % 2 == 0: + response += b"VA 1\r\n" + str(i % 10).encode() + b"\r\n" + else: + response += b"EN\r\n" + b.sendall(response) + results = ms.meta_multiget(keys, RequestFlags(return_value=True)) + assert len(results) == num_keys + for i, r in enumerate(results): + if i % 2 == 0: + assert isinstance(r, Value), f"Expected Value at index {i}" + assert r.value == str(i % 10).encode() + else: + assert isinstance(r, Miss), f"Expected Miss at index {i}" + + def test_multiget_then_regular_get(self, socket_pair): + """Socket state should be clean after meta_multiget.""" + a, b = socket_pair + ms = MemcacheSocket(a) + b.sendall(b"VA 3\r\nfoo\r\nEN\r\n") + results = ms.meta_multiget([b"k1", b"k2"]) + assert len(results) == 2 + + # Regular get should still work + b.sendall(b"VA 3 c99\r\nbar\r\n") + resp = ms.meta_get(b"k3", RequestFlags(return_cas_token=True)) + assert isinstance(resp, Value) + assert resp.value == b"bar" + assert resp.flags.cas_token == 99 + + def test_noop_draining_before_multiget(self, socket_pair): + """Pending NOOPs should be drained before multiget responses.""" + a, b = socket_pair + ms = MemcacheSocket(a) + # Send a no_reply delete (injects noop) + flags_noreply = RequestFlags(no_reply=True) + ms.meta_delete(b"old_key", flags_noreply) + + # Now do a multiget — server sends noop from delete, then multiget responses + b.sendall(b"MN\r\nVA 2\r\nhi\r\nEN\r\n") + results = ms.meta_multiget([b"k1", b"k2"]) + assert len(results) == 2 + assert isinstance(results[0], Value) + assert results[0].value == b"hi" + assert isinstance(results[1], Miss) + + def test_with_flags(self, socket_pair): + """All response flags should be correctly parsed in batch mode.""" + a, b = socket_pair + ms = MemcacheSocket(a) + flags = RequestFlags( + return_cas_token=True, + return_value=True, + return_ttl=True, + return_client_flag=True, + ) + b.sendall( + b"VA 3 c10 t300 f42\r\nabc\r\n" + b"VA 2 c20 t600 f99\r\nxy\r\n" + ) + results = ms.meta_multiget([b"k1", b"k2"], flags) + assert len(results) == 2 + assert results[0].flags.cas_token == 10 + assert results[0].flags.ttl == 300 + assert results[0].flags.client_flag == 42 + assert results[0].value == b"abc" + assert results[1].flags.cas_token == 20 + assert results[1].flags.ttl == 600 + assert results[1].flags.client_flag == 99 + assert results[1].value == b"xy" + + def test_timeout_socket(self, socket_pair): + """meta_multiget works with sockets that have a timeout set.""" + a, b = socket_pair + a.settimeout(5.0) + ms = MemcacheSocket(a) + b.sendall( + b"VA 3\r\nfoo\r\n" + b"EN\r\n" + ) + results = ms.meta_multiget([b"k1", b"k2"]) + assert len(results) == 2 + assert isinstance(results[0], Value) + assert results[0].value == b"foo" + assert isinstance(results[1], Miss) + + # --- String key encoding ---