Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions meta_memcache_socket.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -349,3 +349,21 @@ class MemcacheSocket:
key: Union[str, bytes],
request_flags: Optional[RequestFlags] = None,
) -> Union[Value, Success, Miss, NotStored, Conflict]: ...

# Batch operations
def meta_multiget(
self,
keys: list[Union[str, bytes]],
request_flags: Optional[RequestFlags] = None,
) -> list[Union[Value, Success, Miss, NotStored, Conflict]]:
"""
Send multiple meta get commands and return all responses in one batch.

Builds one ``mg`` command per key into a single buffer, sends it in a
single operation, then receives all responses in a tight Rust loop. GIL
is released during all socket I/O. Returns a list of responses in the
same order as keys.

One response is collected per key. An empty ``keys`` list returns an
empty list without touching the socket. All commands are built before
anything is sent, so an invalid key (e.g. an empty one, which raises
``ValueError``) fails the whole call before any command goes out.

:param keys: List of keys to get
:param request_flags: The flags to use for all keys
:return: One response per key, in the same order as ``keys``
"""
...
70 changes: 69 additions & 1 deletion src/memcache_socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use log::warn;
use pyo3::BoundObject;
use pyo3::exceptions::{PyConnectionError, PyTimeoutError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use pyo3::types::{PyBytes, PyList};

use crate::constants::*;
use crate::encode_key::extract_key;
Expand Down Expand Up @@ -807,4 +807,72 @@ impl MemcacheSocket {
CmdResult::Response((header, value_data)) => self.make_response(py, header, value_data),
}
}

// -----------------------------------------------------------------------
// Batch: meta_multiget (pipelined multi-key get in one Rust call)
// -----------------------------------------------------------------------

/// Send multiple meta get commands and return all responses in one batch.
///
/// Builds all commands into one buffer, sends in a single operation, then
/// receives all responses in a tight Rust loop. GIL is released during
/// all socket I/O. Returns a list of responses in the same order as keys.
#[pyo3(signature = (keys, request_flags=None))]
pub fn meta_multiget(
&mut self,
py: Python<'_>,
keys: &Bound<'_, PyList>,
request_flags: Option<&RequestFlags>,
) -> PyResult<Vec<Py<PyAny>>> {
let num_keys = keys.len();
if num_keys == 0 {
return Ok(Vec::new());
}

// Build all mg commands into one buffer (GIL held for key extraction)
let mut cmd_buf: Vec<u8> = Vec::with_capacity(num_keys * 64);
for i in 0..num_keys {
let key_obj = keys.get_item(i)?;
let cmd = self.build_cmd(b"mg", &key_obj, None, request_flags)?;
cmd_buf.extend_from_slice(&cmd.buf);
}

// Send all commands, then receive all responses (GIL released).
// Value data is copied to owned Vecs so buffer resets during
// subsequent reads don't invalidate earlier responses.
let io = &mut self.io;
let raw_responses: Vec<(ParsedHeader, Option<Vec<u8>>)> = py
.detach(|| {
send_all(io.fd, &cmd_buf, io.timeout_ms)?;
let mut responses = Vec::with_capacity(num_keys);
for _ in 0..num_keys {
let header = io.get_header()?;
let value = if header.response_type == Some(RESPONSE_VALUE) {
let size = header.size.unwrap_or(0) as usize;
let vd = io.ensure_value(size)?;
Some(match vd {
ValueData::InBuffer(start) => {
// Convert to vector, as the buffer will be
// reused and data will be overwritten
io.buf[start..start + size].to_vec()
}
ValueData::Allocated(data) => data,
})
} else {
None
};
responses.push((header, value));
}
Ok(responses)
})
.map_err(|e: std::io::Error| socket_err_io("Error in meta_multiget", e))?;

// Convert raw responses to Python objects (GIL held)
let mut results = Vec::with_capacity(num_keys);
for (header, value_bytes) in raw_responses {
let value_data = value_bytes.map(ValueData::Allocated);
results.push(self.make_response(py, header, value_data)?);
}
Ok(results)
}
}
235 changes: 235 additions & 0 deletions tests/test_memcache_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,6 +936,241 @@ def test_set_socket_updates_timeout(self, socket_pair):
d.close()


# --- meta_multiget (batch) ---


class TestMetaMultiget:
def test_empty_keys(self, socket_pair):
    """An empty key list yields an empty result list."""
    client_sock, _peer_sock = socket_pair
    mc = MemcacheSocket(client_sock)
    assert mc.meta_multiget([]) == []

def test_single_key_hit(self, socket_pair):
    """A lone VA response comes back as one Value with its CAS token."""
    client_sock, server_sock = socket_pair
    mc = MemcacheSocket(client_sock)
    server_sock.sendall(b"VA 5 c1\r\nhello\r\n")
    request = RequestFlags(return_cas_token=True, return_value=True)
    responses = mc.meta_multiget([b"key1"], request)
    assert len(responses) == 1
    hit = responses[0]
    assert isinstance(hit, Value)
    assert hit.value == b"hello"
    assert hit.flags.cas_token == 1

def test_single_key_miss(self, socket_pair):
    """A lone EN response comes back as one Miss."""
    client_sock, server_sock = socket_pair
    mc = MemcacheSocket(client_sock)
    server_sock.sendall(b"EN\r\n")
    responses = mc.meta_multiget([b"key1"])
    assert len(responses) == 1
    assert isinstance(responses[0], Miss)

def test_multiple_keys_all_hits(self, socket_pair):
    """Three VA responses map, in order, to three Values with their CAS tokens."""
    client_sock, server_sock = socket_pair
    mc = MemcacheSocket(client_sock)
    request = RequestFlags(return_cas_token=True, return_value=True)
    server_sock.sendall(
        b"VA 3 c1\r\nfoo\r\n"
        b"VA 3 c2\r\nbar\r\n"
        b"VA 3 c3\r\nbaz\r\n"
    )
    responses = mc.meta_multiget([b"k1", b"k2", b"k3"], request)
    assert len(responses) == 3
    assert all(isinstance(item, Value) for item in responses)
    # (value, cas_token) expected at each position
    expected = [(b"foo", 1), (b"bar", 2), (b"baz", 3)]
    for item, (want_value, want_cas) in zip(responses, expected):
        assert item.value == want_value
        assert item.flags.cas_token == want_cas

def test_multiple_keys_all_misses(self, socket_pair):
    """Five EN responses map to five Miss results."""
    client_sock, server_sock = socket_pair
    mc = MemcacheSocket(client_sock)
    server_sock.sendall(b"EN\r\n" * 5)
    wanted = [b"k1", b"k2", b"k3", b"k4", b"k5"]
    responses = mc.meta_multiget(wanted)
    assert len(responses) == 5
    assert all(isinstance(item, Miss) for item in responses)

def test_mixed_hits_and_misses(self, socket_pair):
    """Interleaved VA / EN responses keep their positions in the result list."""
    client_sock, server_sock = socket_pair
    mc = MemcacheSocket(client_sock)
    request = RequestFlags(return_value=True)
    server_sock.sendall(
        b"VA 3 f1\r\nfoo\r\n"
        b"EN\r\n"
        b"VA 3 f2\r\nbar\r\n"
        b"EN\r\n"
        b"VA 3 f3\r\nbaz\r\n"
    )
    responses = mc.meta_multiget([b"k1", b"k2", b"k3", b"k4", b"k5"], request)
    assert len(responses) == 5
    # None marks a position where a Miss is expected.
    expected = [b"foo", None, b"bar", None, b"baz"]
    for item, want in zip(responses, expected):
        if want is None:
            assert isinstance(item, Miss)
        else:
            assert isinstance(item, Value)
            assert item.value == want

def test_verifies_wire_format(self, socket_pair):
    """The peer receives one ``mg`` line per key, each carrying the shared flags."""
    client_sock, server_sock = socket_pair
    mc = MemcacheSocket(client_sock)
    request = RequestFlags(return_cas_token=True, cache_ttl=300)
    # Queue the replies up front so the batch call has something to read.
    server_sock.sendall(b"EN\r\nEN\r\nEN\r\n")
    mc.meta_multiget([b"key1", b"key2", b"key3"], request)
    expected_wire = (
        b"mg key1 c T300\r\n"
        b"mg key2 c T300\r\n"
        b"mg key3 c T300\r\n"
    )
    assert server_sock.recv(4096) == expected_wire

def test_string_keys(self, socket_pair):
    """``str`` keys are accepted and appear as plain bytes on the wire."""
    client_sock, server_sock = socket_pair
    mc = MemcacheSocket(client_sock)
    request = RequestFlags(return_value=True)
    server_sock.sendall(b"VA 2\r\nhi\r\nEN\r\n")
    responses = mc.meta_multiget(["mykey", "other"], request)
    assert len(responses) == 2
    first, second = responses
    assert isinstance(first, Value)
    assert first.value == b"hi"
    assert isinstance(second, Miss)
    # The commands that reached the peer must name the keys verbatim.
    sent = server_sock.recv(4096)
    assert sent == b"mg mykey v\r\nmg other v\r\n"

def test_mixed_key_types(self, socket_pair):
    """Mix of str and bytes keys.

    Both key types must be encoded into ``mg`` commands, and the two queued
    ``EN`` replies must be decoded as misses (not merely counted).
    """
    a, b = socket_pair
    ms = MemcacheSocket(a)
    b.sendall(b"EN\r\nEN\r\n")
    results = ms.meta_multiget(["strkey", b"byteskey"])
    assert len(results) == 2
    # Checking the count alone would let a wrongly decoded response through.
    assert all(isinstance(r, Miss) for r in results)
    data = b.recv(4096)
    assert data == b"mg strkey\r\nmg byteskey\r\n"

def test_empty_key_raises(self, socket_pair):
    """A batch containing an empty key is rejected with ValueError."""
    client_sock, _peer_sock = socket_pair
    mc = MemcacheSocket(client_sock)
    keys_with_blank = [b"good", b"", b"also_good"]
    with pytest.raises(ValueError):
        mc.meta_multiget(keys_with_blank)

def test_large_value_in_batch(self, socket_pair):
    """A 200-byte value is returned intact when buffer_size is only 100."""
    client_sock, server_sock = socket_pair
    mc = MemcacheSocket(client_sock, buffer_size=100)
    big_value = b"x" * 200
    wire = b"".join(
        [
            b"VA 200\r\n",
            big_value,
            b"\r\n",
            b"EN\r\n",
            b"VA 3\r\nfoo\r\n",
        ]
    )
    server_sock.sendall(wire)
    request = RequestFlags(return_value=True)
    responses = mc.meta_multiget([b"k1", b"k2", b"k3"], request)
    assert len(responses) == 3
    large, missing, small = responses
    assert isinstance(large, Value)
    assert large.value == big_value
    assert isinstance(missing, Miss)
    assert isinstance(small, Value)
    assert small.value == b"foo"
Comment thread
bisho marked this conversation as resolved.

def test_small_buffer_many_keys(self, socket_pair):
    """Fifty alternating hit/miss responses are read through a 32-byte buffer."""
    client_sock, server_sock = socket_pair
    mc = MemcacheSocket(client_sock, buffer_size=32)
    total = 50

    def reply_for(index):
        # Even positions are one-byte hits, odd positions are misses.
        if index % 2:
            return b"EN\r\n"
        return b"VA 1\r\n" + str(index % 10).encode() + b"\r\n"

    wanted = [f"k{i}".encode() for i in range(total)]
    server_sock.sendall(b"".join(reply_for(i) for i in range(total)))
    responses = mc.meta_multiget(wanted, RequestFlags(return_value=True))
    assert len(responses) == total
    for i, r in enumerate(responses):
        if i % 2:
            assert isinstance(r, Miss), f"Expected Miss at index {i}"
            continue
        assert isinstance(r, Value), f"Expected Value at index {i}"
        assert r.value == str(i % 10).encode()
Comment thread
bisho marked this conversation as resolved.

def test_multiget_then_regular_get(self, socket_pair):
    """Socket state should be clean after meta_multiget.

    Checks the batch responses themselves (not only their count), then
    confirms that a follow-up single-key meta_get parses its own reply.
    """
    a, b = socket_pair
    ms = MemcacheSocket(a)
    b.sendall(b"VA 3\r\nfoo\r\nEN\r\n")
    results = ms.meta_multiget([b"k1", b"k2"])
    assert len(results) == 2
    # Verify the batch was decoded correctly, not merely counted.
    assert isinstance(results[0], Value)
    assert results[0].value == b"foo"
    assert isinstance(results[1], Miss)

    # Regular get should still work
    b.sendall(b"VA 3 c99\r\nbar\r\n")
    resp = ms.meta_get(b"k3", RequestFlags(return_cas_token=True))
    assert isinstance(resp, Value)
    assert resp.value == b"bar"
    assert resp.flags.cas_token == 99

def test_noop_draining_before_multiget(self, socket_pair):
    """An MN reply left over from a no_reply delete must not shift batch results."""
    client_sock, server_sock = socket_pair
    mc = MemcacheSocket(client_sock)
    # A no_reply delete leaves a noop reply outstanding on the connection.
    mc.meta_delete(b"old_key", RequestFlags(no_reply=True))

    # The peer answers the noop first, then the two batch lookups.
    server_sock.sendall(b"MN\r\nVA 2\r\nhi\r\nEN\r\n")
    responses = mc.meta_multiget([b"k1", b"k2"])
    assert len(responses) == 2
    first, second = responses
    assert isinstance(first, Value)
    assert first.value == b"hi"
    assert isinstance(second, Miss)

def test_with_flags(self, socket_pair):
    """CAS token, TTL and client flag are parsed for every response in a batch."""
    client_sock, server_sock = socket_pair
    mc = MemcacheSocket(client_sock)
    request = RequestFlags(
        return_cas_token=True,
        return_value=True,
        return_ttl=True,
        return_client_flag=True,
    )
    server_sock.sendall(
        b"VA 3 c10 t300 f42\r\nabc\r\n"
        b"VA 2 c20 t600 f99\r\nxy\r\n"
    )
    responses = mc.meta_multiget([b"k1", b"k2"], request)
    assert len(responses) == 2
    # (cas_token, ttl, client_flag, value) expected at each position
    expected = [
        (10, 300, 42, b"abc"),
        (20, 600, 99, b"xy"),
    ]
    for item, (cas, ttl, client_flag, value) in zip(responses, expected):
        assert item.flags.cas_token == cas
        assert item.flags.ttl == ttl
        assert item.flags.client_flag == client_flag
        assert item.value == value

def test_timeout_socket(self, socket_pair):
    """A socket with a 5-second timeout set still completes a batch get."""
    client_sock, server_sock = socket_pair
    client_sock.settimeout(5.0)
    mc = MemcacheSocket(client_sock)
    server_sock.sendall(
        b"VA 3\r\nfoo\r\n"
        b"EN\r\n"
    )
    responses = mc.meta_multiget([b"k1", b"k2"])
    assert len(responses) == 2
    hit, miss = responses
    assert isinstance(hit, Value)
    assert hit.value == b"foo"
    assert isinstance(miss, Miss)


# --- String key encoding ---


Expand Down
Loading