From f30afc3ab17bd3607729f9c275e62e0ea9202988 Mon Sep 17 00:00:00 2001 From: Guillermo Perez Date: Wed, 11 Mar 2026 23:06:30 +0100 Subject: [PATCH] Direct meta commands api in socket --- Cargo.lock | 38 ++- Cargo.toml | 2 + meta_memcache_socket.pyi | 50 +++- src/impl_build_cmd.rs | 19 +- src/impl_build_cmd_tests.rs | 52 ++-- src/lib.rs | 56 +++- src/memcache_socket.rs | 549 ++++++++++++++++++++++++++-------- src/request_flags.rs | 13 +- src/request_flags_tests.rs | 2 +- tests/test_memcache_socket.py | 432 ++++++++++++++++++-------- 10 files changed, 917 insertions(+), 296 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6ed298b..976edee 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,15 @@ # It is not intended for manual editing. version = 4 +[[package]] +name = "arc-swap" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a07d1f37ff60921c83bdfc7407723bdefe89b44b98a9b772f225c8f9d67141a6" +dependencies = [ + "rustversion", +] + [[package]] name = "atoi" version = "2.0.0" @@ -31,9 +40,9 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] name = "itoa" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" +checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" [[package]] name = "libc" @@ -41,6 +50,12 @@ version = "0.2.183" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5b646652bf6661599e1da8901b3b9522896f01e736bad5f723fe7a3a27f899d" +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + [[package]] name = "memchr" version = "2.8.0" @@ -55,8 +70,10 @@ dependencies = [ "base64", "itoa", "libc", + "log", "memchr", "pyo3", + "pyo3-log", ] [[package]] @@ -122,6 +139,17 @@ dependencies = [ "pyo3-build-config", ] +[[package]] +name = "pyo3-log" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26c2ec80932c5c3b2d4fbc578c9b56b2d4502098587edb8bef5b6bfcad43682e" +dependencies = [ + "arc-swap", + "log", + "pyo3", +] + [[package]] name = "pyo3-macros" version = "0.28.2" @@ -156,6 +184,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + [[package]] name = "syn" version = "2.0.117" diff --git a/Cargo.toml b/Cargo.toml index 0066997..06f3d51 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,3 +14,5 @@ itoa = "1" libc = "0.2" memchr = "2" pyo3 = { version = "0.28", features = ["extension-module"] } +log = "0.4" +pyo3-log = "0.13" diff --git a/meta_memcache_socket.pyi b/meta_memcache_socket.pyi index e00828c..f4f1396 100644 --- a/meta_memcache_socket.pyi +++ b/meta_memcache_socket.pyi @@ -1,5 +1,5 @@ -from typing import Any, Optional, Tuple, Union import socket +from typing import Any, Optional, Tuple, Union RESPONSE_VALUE: int # 1 - VALUE (VA) RESPONSE_SUCCESS: int # 2 - SUCCESS (OK or HD) @@ -151,7 +151,6 @@ class ResponseFlags: opaque: Optional[bytes] = None, ) -> None: ... def __str__(self) -> str: ... - @staticmethod def from_success_header(header: bytes) -> "ResponseFlags": """Parse response flags from a success (HD) header.""" @@ -305,4 +304,49 @@ class MemcacheSocket: def close(self) -> None: ... def sendall(self, data: bytes, with_noop: bool) -> None: ... def get_response(self) -> Union[Value, Success, Miss, NotStored, Conflict]: ... - def get_value(self, size: int) -> bytes: ... + # send_meta_* methods (for pipelining — send only, read later with get_response()) + # Mutations automatically inject NOOP when no_reply is set in request_flags. + def send_meta_get( + self, + key: bytes, + request_flags: Optional[RequestFlags] = None, + ) -> None: ... + def send_meta_set( + self, + key: bytes, + value: bytes, + request_flags: Optional[RequestFlags] = None, + ) -> None: ... + def send_meta_delete( + self, + key: bytes, + request_flags: Optional[RequestFlags] = None, + ) -> None: ... + def send_meta_arithmetic( + self, + key: bytes, + request_flags: Optional[RequestFlags] = None, + ) -> None: ... + + # meta_* methods (blocking — send + recv in one call) + def meta_get( + self, + key: bytes, + request_flags: Optional[RequestFlags] = None, + ) -> Union[Value, Success, Miss, NotStored, Conflict]: ... + def meta_set( + self, + key: bytes, + value: bytes, + request_flags: Optional[RequestFlags] = None, + ) -> Union[Value, Success, Miss, NotStored, Conflict]: ... + def meta_delete( + self, + key: bytes, + request_flags: Optional[RequestFlags] = None, + ) -> Union[Value, Success, Miss, NotStored, Conflict]: ... + def meta_arithmetic( + self, + key: bytes, + request_flags: Optional[RequestFlags] = None, + ) -> Union[Value, Success, Miss, NotStored, Conflict]: ... diff --git a/src/impl_build_cmd.rs b/src/impl_build_cmd.rs index 2cbca96..73a3a0a 100644 --- a/src/impl_build_cmd.rs +++ b/src/impl_build_cmd.rs @@ -5,13 +5,19 @@ use crate::RequestFlags; const MAX_KEY_LEN: usize = 250; const MAX_BIN_KEY_LEN: usize = 187; // 250 * 3 / 4 due to b64 encoding +pub struct BuiltCmd { + pub buf: Vec, + pub no_reply: bool, +} + pub fn impl_build_cmd( cmd: &[u8], key: &[u8], size: Option, request_flags: Option<&RequestFlags>, legacy_size_format: bool, -) -> Option> { + allow_no_reply_flag: bool, +) -> Option { if key.is_empty() || key.len() >= MAX_KEY_LEN { return None; } @@ -55,10 +61,13 @@ pub fn impl_build_cmd( buf.push(b' '); buf.push(b'b'); } - if let Some(request_flags) = request_flags { - request_flags.push_bytes(&mut buf); - } + let no_reply = if let Some(request_flags) = request_flags { + request_flags.push_bytes(&mut buf, allow_no_reply_flag); + allow_no_reply_flag && request_flags.is_no_reply() + } else { + false + }; buf.push(b'\r'); buf.push(b'\n'); - Some(buf) + Some(BuiltCmd { buf, no_reply }) } diff --git a/src/impl_build_cmd_tests.rs b/src/impl_build_cmd_tests.rs index 6ce2cf4..5379fed 100644 --- a/src/impl_build_cmd_tests.rs +++ b/src/impl_build_cmd_tests.rs @@ -30,11 +30,12 @@ mod tests { Some(b'A'), // mode (APPEND) ); - let result = impl_build_cmd(cmd, key, None, Some(&request_flags), false).unwrap(); - let string = String::from_utf8_lossy(&result); + let built = impl_build_cmd(cmd, key, None, Some(&request_flags), false, true).unwrap(); + let string = String::from_utf8_lossy(&built.buf); println!("{:?}", string); + assert!(built.no_reply); assert_eq!( - result, + built.buf, b"mg key q f c v t s l h k u I T111 R222 N333 F444 J555 D666 C777 Oopaque MA\r\n" ); } @@ -66,10 +67,11 @@ mod tests { None, // mode ); - let result = impl_build_cmd(cmd, key, None, Some(&request_flags), false).unwrap(); - let string = String::from_utf8_lossy(&result); + let built = impl_build_cmd(cmd, key, None, Some(&request_flags), false, true).unwrap(); + let string = String::from_utf8_lossy(&built.buf); println!("{:?}", string); - assert_eq!(result, b"mg key\r\n"); + assert!(!built.no_reply); + assert_eq!(built.buf, b"mg key\r\n"); } #[test] @@ -99,10 +101,10 @@ mod tests { None, // mode ); - let result = impl_build_cmd(cmd, key, None, Some(&request_flags), false).unwrap(); - let string = String::from_utf8_lossy(&result); + let built = impl_build_cmd(cmd, key, None, Some(&request_flags), false, true).unwrap(); + let string = String::from_utf8_lossy(&built.buf); println!("{:?}", string); - assert_eq!(result, b"mg S2V5X3dpdGhfYmluYXJ5AA== b\r\n"); + assert_eq!(built.buf, b"mg S2V5X3dpdGhfYmluYXJ5AA== b\r\n"); } #[test] @@ -132,50 +134,50 @@ mod tests { None, // mode ); - let result = impl_build_cmd(cmd, key, None, Some(&request_flags), false).unwrap(); - let string = String::from_utf8_lossy(&result); + let built = impl_build_cmd(cmd, key, None, Some(&request_flags), false, true).unwrap(); + let string = String::from_utf8_lossy(&built.buf); println!("{:?}", string); - assert_eq!(result, b"mg S2V5IHdpdGggc3BhY2Vz b\r\n"); + assert_eq!(built.buf, b"mg S2V5IHdpdGggc3BhY2Vz b\r\n"); } #[test] fn test_empty_key_rejected() { - assert!(impl_build_cmd(b"mg", b"", None, None, false).is_none()); + assert!(impl_build_cmd(b"mg", b"", None, None, false, true).is_none()); } #[test] fn test_key_at_max_length() { // 249 bytes is OK (< 250) let key = &vec![b'X'; 249]; - assert!(impl_build_cmd(b"mg", key, None, None, false).is_some()); + assert!(impl_build_cmd(b"mg", key, None, None, false, true).is_some()); } #[test] fn test_key_at_exact_max_length() { // 250 bytes is rejected (>= MAX_KEY_LEN) let key = &vec![b'X'; 250]; - assert!(impl_build_cmd(b"mg", key, None, None, false).is_none()); + assert!(impl_build_cmd(b"mg", key, None, None, false, true).is_none()); } #[test] fn test_binary_key_at_max_length() { // 186 binary bytes is OK (< 187 = MAX_BIN_KEY_LEN) let key = &vec![0x00u8; 186]; - assert!(impl_build_cmd(b"mg", key, None, None, false).is_some()); + assert!(impl_build_cmd(b"mg", key, None, None, false, true).is_some()); } #[test] fn test_binary_key_at_exact_max_length() { // 187 binary bytes is rejected (>= MAX_BIN_KEY_LEN) let key = &vec![0x00u8; 187]; - assert!(impl_build_cmd(b"mg", key, None, None, false).is_none()); + assert!(impl_build_cmd(b"mg", key, None, None, false, true).is_none()); } #[test] fn test_impl_build_cmd_large_key() { let cmd = b"mg"; let key = &vec![b'X'; 250]; - assert!(impl_build_cmd(cmd, key, None, None, false).is_none()); + assert!(impl_build_cmd(cmd, key, None, None, false, true).is_none()); } #[test] @@ -206,10 +208,11 @@ mod tests { None, // mode ); - let result = impl_build_cmd(cmd, key, Some(size), Some(&request_flags), false).unwrap(); - let string = String::from_utf8_lossy(&result); + let built = + impl_build_cmd(cmd, key, Some(size), Some(&request_flags), false, true).unwrap(); + let string = String::from_utf8_lossy(&built.buf); println!("{:?}", string); - assert_eq!(result, b"ms key 123 T111\r\n"); + assert_eq!(built.buf, b"ms key 123 T111\r\n"); } #[test] @@ -218,9 +221,10 @@ mod tests { let key = b"key"; let size = 123; - let result = impl_build_cmd(cmd, key, Some(size), None, true).unwrap(); - let string = String::from_utf8_lossy(&result); + let built = impl_build_cmd(cmd, key, Some(size), None, true, true).unwrap(); + let string = String::from_utf8_lossy(&built.buf); println!("{:?}", string); - assert_eq!(result, b"ms key S123\r\n"); + assert!(!built.no_reply); + assert_eq!(built.buf, b"ms key S123\r\n"); } } diff --git a/src/lib.rs b/src/lib.rs index 3596f82..af63206 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -65,8 +65,15 @@ pub fn build_cmd<'py>( request_flags: Option<&RequestFlags>, legacy_size_format: bool, ) -> PyResult> { - match impl_build_cmd(cmd, key, size, request_flags, legacy_size_format) { - Some(buf) => Ok(PyBytes::new(py, &buf)), + match impl_build_cmd( + cmd, + key, + size, + request_flags, + legacy_size_format, + /* allow_no_reply_flag */ true, + ) { + Some(built) => Ok(PyBytes::new(py, &built.buf)), None => Err(pyo3::exceptions::PyValueError::new_err("Key is too long")), } } @@ -84,8 +91,15 @@ pub fn build_meta_get<'py>( key: &[u8], request_flags: Option<&RequestFlags>, ) -> PyResult> { - match impl_build_cmd(b"mg", key, None, request_flags, false) { - Some(buf) => Ok(PyBytes::new(py, &buf)), + match impl_build_cmd( + b"mg", + key, + None, + request_flags, + /* legacy_size_format */ false, + /* allow_no_reply_flag */ true, + ) { + Some(built) => Ok(PyBytes::new(py, &built.buf)), None => Err(pyo3::exceptions::PyValueError::new_err("Key is too long")), } } @@ -107,8 +121,15 @@ pub fn build_meta_set<'py>( request_flags: Option<&RequestFlags>, legacy_size_format: bool, ) -> PyResult> { - match impl_build_cmd(b"ms", key, Some(size), request_flags, legacy_size_format) { - Some(buf) => Ok(PyBytes::new(py, &buf)), + match impl_build_cmd( + b"ms", + key, + Some(size), + request_flags, + legacy_size_format, + /* allow_no_reply_flag */ true, + ) { + Some(built) => Ok(PyBytes::new(py, &built.buf)), None => Err(pyo3::exceptions::PyValueError::new_err("Key is too long")), } } @@ -126,8 +147,15 @@ pub fn build_meta_delete<'py>( key: &[u8], request_flags: Option<&RequestFlags>, ) -> PyResult> { - match impl_build_cmd(b"md", key, None, request_flags, false) { - Some(buf) => Ok(PyBytes::new(py, &buf)), + match impl_build_cmd( + b"md", + key, + None, + request_flags, + /* legacy_size_format */ false, + /* allow_no_reply_flag */ true, + ) { + Some(built) => Ok(PyBytes::new(py, &built.buf)), None => Err(pyo3::exceptions::PyValueError::new_err("Key is too long")), } } @@ -145,14 +173,22 @@ pub fn build_meta_arithmetic<'py>( key: &[u8], request_flags: Option<&RequestFlags>, ) -> PyResult> { - match impl_build_cmd(b"ma", key, None, request_flags, false) { - Some(buf) => Ok(PyBytes::new(py, &buf)), + match impl_build_cmd( + b"ma", + key, + None, + request_flags, + /* legacy_size_format */ false, + /* allow_no_reply_flag */ true, + ) { + Some(built) => Ok(PyBytes::new(py, &built.buf)), None => Err(pyo3::exceptions::PyValueError::new_err("Key is too long")), } } #[pymodule(gil_used = false)] fn meta_memcache_socket(module: &Bound<'_, PyModule>) -> PyResult<()> { + pyo3_log::init(); // Classes module.add_class::()?; module.add_class::()?; diff --git a/src/memcache_socket.rs b/src/memcache_socket.rs index 7eb4cd2..0daba09 100644 --- a/src/memcache_socket.rs +++ b/src/memcache_socket.rs @@ -1,12 +1,17 @@ use std::os::fd::RawFd; +use log::warn; + use pyo3::BoundObject; -use pyo3::exceptions::{PyConnectionError, PyTimeoutError}; +use pyo3::exceptions::{PyConnectionError, PyTimeoutError, PyValueError}; use pyo3::prelude::*; use pyo3::types::PyBytes; use crate::constants::*; +use crate::impl_build_cmd::{BuiltCmd, impl_build_cmd}; use crate::impl_parse_header::{ParsedHeader, impl_parse_header}; +use crate::request_flags::RequestFlags; +use crate::response_flags::ResponseFlags; use crate::response_types::*; const DEFAULT_BUFFER_SIZE: usize = 4096; @@ -61,31 +66,33 @@ fn poll_fd( events: libc::c_short, timeout_ms: libc::c_int, ) -> Result<(), std::io::Error> { - let mut pfd = libc::pollfd { - fd, - events, - revents: 0, - }; - // SAFETY: pfd is a valid pollfd struct on the stack, nfds=1 - let ret = unsafe { libc::poll(&mut pfd, 1, timeout_ms) }; - if ret < 0 { - let err = std::io::Error::last_os_error(); - if err.kind() == std::io::ErrorKind::Interrupted { - return poll_fd(fd, events, timeout_ms); // signal interrupted us, retry + loop { + let mut pfd = libc::pollfd { + fd, + events, + revents: 0, + }; + // SAFETY: pfd is a valid pollfd struct on the stack, nfds=1 + let ret = unsafe { libc::poll(&mut pfd, 1, timeout_ms) }; + if ret < 0 { + let err = std::io::Error::last_os_error(); + if err.kind() == std::io::ErrorKind::Interrupted { + continue; + } + return Err(err); + } else if ret == 0 { + return Err(std::io::Error::new( + std::io::ErrorKind::TimedOut, + "timed out", + )); + } else if pfd.revents & (libc::POLLERR | libc::POLLHUP | libc::POLLNVAL) != 0 { + return Err(std::io::Error::new( + std::io::ErrorKind::ConnectionReset, + "poll error on socket", + )); + } else { + return Ok(()); } - Err(err) - } else if ret == 0 { - Err(std::io::Error::new( - std::io::ErrorKind::TimedOut, - "timed out", - )) - } else if pfd.revents & (libc::POLLERR | libc::POLLHUP | libc::POLLNVAL) != 0 { - Err(std::io::Error::new( - std::io::ErrorKind::ConnectionReset, - "poll error on socket", - )) - } else { - Ok(()) } } @@ -122,28 +129,25 @@ fn send_all(fd: RawFd, data: &[u8], timeout_ms: libc::c_int) -> Result<(), std:: Ok(()) } -/// Send data + NOOP command in a single write when possible. -/// Uses writev() to avoid concatenation on the happy path, falls back -/// to send_all() for partial writes and EAGAIN. +/// Send multiple buffers in a single writev() syscall. +/// Falls back to send_all() for partial writes. #[inline] -fn send_all_with_noop( - fd: RawFd, - data: &[u8], - timeout_ms: libc::c_int, -) -> Result<(), std::io::Error> { - let iov = [ - libc::iovec { - iov_base: data.as_ptr() as *mut libc::c_void, - iov_len: data.len(), - }, - libc::iovec { - iov_base: NOOP_CMD.as_ptr() as *mut libc::c_void, - iov_len: NOOP_CMD.len(), - }, - ]; - // SAFETY: iov array has 2 valid entries pointing to data and NOOP_CMD - let n = unsafe { libc::writev(fd, iov.as_ptr(), 2) }; - let total_len = data.len() + NOOP_CMD.len(); +fn send_iovecs(fd: RawFd, slices: &[&[u8]], timeout_ms: libc::c_int) -> Result<(), std::io::Error> { + let total_len: usize = slices.iter().map(|s| s.len()).sum(); + if total_len == 0 { + return Ok(()); + } + + let mut iovecs: Vec = Vec::with_capacity(slices.len()); + for slice in slices { + iovecs.push(libc::iovec { + iov_base: slice.as_ptr() as *mut libc::c_void, + iov_len: slice.len(), + }); + } + + // SAFETY: iovecs entries point to valid byte slices for the duration of writev + let n = unsafe { libc::writev(fd, iovecs.as_ptr(), iovecs.len() as i32) }; let written = if n >= 0 { n as usize } else { @@ -153,11 +157,23 @@ fn send_all_with_noop( _ => return Err(err), } }; - if written < total_len { - let combined: Vec = [data, NOOP_CMD].concat(); - send_all(fd, &combined[written..], timeout_ms)?; + + if written >= total_len { + return Ok(()); } - Ok(()) + + // Partial write: concatenate remaining and send_all + let mut combined: Vec = Vec::with_capacity(total_len - written); + let mut skip = written; + for slice in slices { + if skip >= slice.len() { + skip -= slice.len(); + } else { + combined.extend_from_slice(&slice[skip..]); + skip = 0; + } + } + send_all(fd, &combined, timeout_ms) } /// Recv into buffer slice, returns bytes read. Handles EAGAIN by polling. @@ -215,12 +231,18 @@ fn recv_fill(fd: RawFd, buf: &mut [u8], timeout_ms: libc::c_int) -> Result), } +enum CmdResult { + NoReply, + Response((ParsedHeader, Option)), +} + /// Inner I/O state — no Python objects, so it is Send/Ungil. /// This allows releasing the GIL during socket I/O via py.detach(). struct SocketIO { @@ -255,7 +277,7 @@ impl SocketIO { } fn get_single_header(&mut self) -> Result { - if self.pos >= self.read { + if self.read == self.pos { self.read = 0; self.pos = 0; } else if self.pos > self.reset_buffer_size { @@ -296,19 +318,36 @@ impl SocketIO { self.get_single_header() } - fn sendall_impl(&self, data: &[u8], with_noop: bool) -> Result<(), std::io::Error> { + fn send_cmd(&mut self, cmd: &[u8], with_noop: bool) -> Result<(), std::io::Error> { if with_noop { - send_all_with_noop(self.fd, data, self.timeout_ms) + send_iovecs(self.fd, &[cmd, NOOP_CMD], self.timeout_ms)?; + self.noop_expected += 1; } else { - send_all(self.fd, data, self.timeout_ms) + send_all(self.fd, cmd, self.timeout_ms)?; } + Ok(()) + } + + fn send_cmd_with_value( + &mut self, + cmd: &[u8], + value: &[u8], + with_noop: bool, + ) -> Result<(), std::io::Error> { + if with_noop { + send_iovecs(self.fd, &[cmd, value, ENDL, NOOP_CMD], self.timeout_ms)?; + self.noop_expected += 1; + } else { + send_iovecs(self.fd, &[cmd, value, ENDL], self.timeout_ms)?; + } + Ok(()) } /// Ensure value data is available for reading. + /// Advances pos past the value and ENDL on success. /// - /// For the common case (value fits in buffer), returns InBuffer — - /// the data is at buf[pos..pos+size] with ENDL validated. - /// The caller creates PyBytes directly from the buffer slice (zero-copy to Rust). + /// For the common case (value fits in buffer), returns InBuffer(start) — + /// the data is at buf[start..start+size]. /// /// For large values exceeding the buffer, returns Allocated with the data. fn ensure_value(&mut self, size: usize) -> Result { @@ -327,21 +366,23 @@ impl SocketIO { data_in_buf = self.read - self.pos; } + let data_start = self.pos; + if data_in_buf >= message_size { // Value + ENDL fully in buffer — validate ENDL in place - let data_end = self.pos + size; + let data_end = data_start + size; if self.buf[data_end] != b'\r' || self.buf[data_end + 1] != b'\n' { return Err(std::io::Error::new( std::io::ErrorKind::InvalidData, "Value not terminated with \\r\\n", )); } - // Don't advance pos yet — caller reads buf[pos..pos+size] then advances - Ok(ValueData::InBuffer) + self.pos = data_end + ENDL_LEN; + Ok(ValueData::InBuffer(data_start)) } else if data_in_buf >= size { // Value in buffer but ENDL partially/not in buffer. - // We still return InBuffer, but need to read+validate the ENDL. - let data_end = self.pos + size; + // Read and validate the ENDL from buffer/socket. + let data_end = data_start + size; let endl_in_buf = data_in_buf - size; let mut endl_buf = [0u8; ENDL_LEN]; if endl_in_buf > 0 { @@ -356,11 +397,9 @@ impl SocketIO { "Value not terminated with \\r\\n", )); } - // Discard partial ENDL bytes from the buffer's tracked range. - // The caller will advance pos by size + ENDL_LEN, which will - // overshoot read — get_single_header handles this via pos >= read. - self.read = data_end; - Ok(ValueData::InBuffer) + // ENDL was consumed from buffer/socket directly; buffer is fully consumed + self.pos = self.read; + Ok(ValueData::InBuffer(data_start)) } else { // Value doesn't fit in buffer — allocate and read into temp buffer let mut message = vec![0u8; message_size]; @@ -379,6 +418,21 @@ impl SocketIO { Ok(ValueData::Allocated(message)) } } + + /// Read and parse the next response header, including value data for + /// Value responses. All socket I/O happens in this method (no GIL needed). + fn get_response_with_value( + &mut self, + ) -> Result<(ParsedHeader, Option), std::io::Error> { + let header = self.get_header()?; + let value_data = if header.response_type == Some(RESPONSE_VALUE) { + let size = header.size.unwrap_or(0) as usize; + Some(self.ensure_value(size)?) + } else { + None + }; + Ok((header, value_data)) + } } #[pyclass] @@ -389,6 +443,87 @@ pub struct MemcacheSocket { version: u8, } +/// Private helpers +impl MemcacheSocket { + fn build_cmd( + &self, + cmd: &[u8], + key: &[u8], + size: Option, + request_flags: Option<&RequestFlags>, + ) -> Option { + let legacy_size_format = cmd == b"ms" && self.version == SERVER_VERSION_AWS_1_6_6; + let allow_no_reply_flag = cmd != b"mg"; + impl_build_cmd( + cmd, + key, + size, + request_flags, + legacy_size_format, + allow_no_reply_flag, + ) + } + + /// Convert a parsed header + optional value data into a Python response object. + fn make_response( + &self, + py: Python<'_>, + header: ParsedHeader, + value_data: Option, + ) -> PyResult> { + match header.response_type { + Some(RESPONSE_VALUE) => { + let size = header + .size + .ok_or_else(|| socket_err("Value response missing size"))?; + let flags = header + .flags + .ok_or_else(|| socket_err("Value response missing flags"))?; + let py_bytes = match value_data { + Some(ValueData::InBuffer(start)) => { + PyBytes::new(py, &self.io.buf[start..start + size as usize]) + } + Some(ValueData::Allocated(data)) => PyBytes::new(py, &data), + None => PyBytes::new(py, b""), + }; + into_py( + py, + Value::new(size, flags, Some(py_bytes.into_any().unbind())), + ) + } + Some(RESPONSE_SUCCESS) => { + let flags = header + .flags + .ok_or_else(|| socket_err("Success response missing flags"))?; + into_py(py, Success::new(flags)) + } + Some(RESPONSE_NOT_STORED) => into_py(py, NotStored::new()), + Some(RESPONSE_CONFLICT) => into_py(py, Conflict::new()), + Some(RESPONSE_MISS) => into_py(py, Miss::new()), + _ => Err(socket_err(&format!( + "Unknown response code: {:?}", + header.response_type + ))), + } + } + + /// Create a Success response with empty flags (for no_reply commands). + fn success_no_reply(py: Python<'_>) -> PyResult> { + let flags = ResponseFlags { + cas_token: None, + fetched: None, + last_access: None, + ttl: None, + client_flag: None, + win: None, + stale: false, + size: None, + opaque: None, + }; + into_py(py, Success::new(flags)) + } +} + #[pymethods] impl MemcacheSocket { #[new] @@ -410,9 +545,11 @@ impl MemcacheSocket { ) }; if ret != 0 { - // Non-fatal: log would be ideal but we don't have a logger here. - // The socket will still work with the kernel default buffer size. - let _ = ret; + // Non-fatal: the socket will still work with the kernel default buffer size. + warn!( + "SO_RCVBUF setsockopt failed (fd={}, requested={}), using kernel default", + fd, buffer_size + ); } Ok(MemcacheSocket { @@ -458,79 +595,231 @@ impl MemcacheSocket { Ok(()) } - /// Send data to the socket, optionally appending a NOOP command. + // ----------------------------------------------------------------------- + // Low-level: sendall + get_response + // ----------------------------------------------------------------------- + + /// Send raw data to the socket, optionally appending a NOOP command. /// Releases the GIL during socket I/O. pub fn sendall(&mut self, py: Python<'_>, data: &[u8], with_noop: bool) -> PyResult<()> { let io = &mut self.io; - py.detach(|| io.sendall_impl(data, with_noop)) + py.detach(|| io.send_cmd(data, with_noop)) .map_err(|e| socket_err_io("Error sending data", e))?; - if with_noop { - self.io.noop_expected += 1; - } Ok(()) } - /// Read and parse the next response header. - /// Releases the GIL during socket I/O and header parsing. + /// Read and parse the next response, including value data for Value responses. + /// For Value responses, `.value` is set to the raw bytes from the wire. + /// Releases the GIL during socket I/O. pub fn get_response(&mut self, py: Python<'_>) -> PyResult> { let io = &mut self.io; - let header = py - .detach(|| io.get_header()) - .map_err(|e| socket_err_io("Error reading header", e))?; + let (header, value_data) = py + .detach(|| io.get_response_with_value()) + .map_err(|e| socket_err_io("Error reading response", e))?; + self.make_response(py, header, value_data) + } - match header.response_type { - Some(RESPONSE_VALUE) => { - let size = header - .size - .ok_or_else(|| socket_err("Value response missing size"))?; - let flags = header - .flags - .ok_or_else(|| socket_err("Value response missing flags"))?; - into_py(py, Value::new(size, flags, None)) - } - Some(RESPONSE_SUCCESS) => { - let flags = header - .flags - .ok_or_else(|| socket_err("Success response missing flags"))?; - into_py(py, Success::new(flags)) - } - Some(RESPONSE_NOT_STORED) => into_py(py, NotStored::new()), - Some(RESPONSE_CONFLICT) => into_py(py, Conflict::new()), - Some(RESPONSE_MISS) => into_py(py, Miss::new()), - _ => Err(socket_err(&format!( - "Unknown response code: {:?}", - header.response_type - ))), + // ----------------------------------------------------------------------- + // Tier 1: send_meta_* (for pipelining — send only, read later) + // ----------------------------------------------------------------------- + + /// Send a meta get command. Use get_response() to read the result later. + /// Note: no_reply on mg only suppresses misses, hits still return data, + /// so noop is never injected for get commands. + #[pyo3(signature = (key, request_flags=None))] + pub fn send_meta_get( + &mut self, + py: Python<'_>, + key: &[u8], + request_flags: Option<&RequestFlags>, + ) -> PyResult<()> { + let cmd = self + .build_cmd(b"mg", key, None, request_flags) + .ok_or_else(|| PyValueError::new_err("Key is too long or empty"))?; + if cmd.no_reply { + return Err(socket_err( + "internal error: build_cmd produced no_reply=true for mg command", + )); } + let io = &mut self.io; + py.detach(|| io.send_cmd(&cmd.buf, false)) + .map_err(|e| socket_err_io("Error sending meta get", e))?; + Ok(()) } - /// Read value data from the socket. - /// Releases the GIL during socket I/O. - /// For the common case (value fits in buffer), creates PyBytes directly - /// from the buffer — no intermediate allocation. - pub fn get_value<'py>(&mut self, py: Python<'py>, size: u32) -> PyResult> { - let size_usize = size as usize; + /// Send a meta set command with value. Use get_response() to read the result later. + /// Uses writev() to send cmd + value + ENDL in a single syscall (zero concatenation). + /// If no_reply is set, automatically appends a NOOP command. + #[pyo3(signature = (key, value, request_flags=None))] + pub fn send_meta_set( + &mut self, + py: Python<'_>, + key: &[u8], + value: &[u8], + request_flags: Option<&RequestFlags>, + ) -> PyResult<()> { + let cmd = self + .build_cmd(b"ms", key, Some(value.len() as u32), request_flags) + .ok_or_else(|| PyValueError::new_err("Key is too long or empty"))?; + let io = &mut self.io; + py.detach(|| io.send_cmd_with_value(&cmd.buf, value, cmd.no_reply)) + .map_err(|e| socket_err_io("Error sending meta set", e))?; + Ok(()) + } + + /// Send a meta delete command. Use get_response() to read the result later. + /// If no_reply is set, automatically appends a NOOP command. + #[pyo3(signature = (key, request_flags=None))] + pub fn send_meta_delete( + &mut self, + py: Python<'_>, + key: &[u8], + request_flags: Option<&RequestFlags>, + ) -> PyResult<()> { + let cmd = self + .build_cmd(b"md", key, None, request_flags) + .ok_or_else(|| PyValueError::new_err("Key is too long or empty"))?; let io = &mut self.io; + py.detach(|| io.send_cmd(&cmd.buf, cmd.no_reply)) + .map_err(|e| socket_err_io("Error sending meta delete", e))?; + Ok(()) + } - // Phase 1: recv data without GIL - let location = py - .detach(|| io.ensure_value(size_usize)) - .map_err(|e| socket_err_io("Error receiving value", e))?; - - // Phase 2: create PyBytes with GIL - match location { - ValueData::InBuffer => { - // Common path: value is in io.buf — create PyBytes directly, no extra alloc - let data_start = self.io.pos; - let data_end = data_start + size_usize; - let result = PyBytes::new(py, &self.io.buf[data_start..data_end]); - self.io.pos = data_end + ENDL_LEN; - Ok(result) - } - ValueData::Allocated(data) => { - // Large value path: data already in Vec, create PyBytes from it - Ok(PyBytes::new(py, &data)) - } + /// Send a meta arithmetic command. Use get_response() to read the result later. + /// If no_reply is set, automatically appends a NOOP command. + #[pyo3(signature = (key, request_flags=None))] + pub fn send_meta_arithmetic( + &mut self, + py: Python<'_>, + key: &[u8], + request_flags: Option<&RequestFlags>, + ) -> PyResult<()> { + let cmd = self + .build_cmd(b"ma", key, None, request_flags) + .ok_or_else(|| PyValueError::new_err("Key is too long or empty"))?; + let io = &mut self.io; + py.detach(|| io.send_cmd(&cmd.buf, cmd.no_reply)) + .map_err(|e| socket_err_io("Error sending meta arithmetic", e))?; + Ok(()) + } + + // ----------------------------------------------------------------------- + // Tier 2: meta_* (blocking — send + recv in one call) + // ----------------------------------------------------------------------- + + /// Send a meta get command and return the response. + /// The entire send + recv happens in a single GIL-released block. + #[pyo3(signature = (key, request_flags=None))] + pub fn meta_get( + &mut self, + py: Python<'_>, + key: &[u8], + request_flags: Option<&RequestFlags>, + ) -> PyResult> { + let cmd = self + .build_cmd(b"mg", key, None, request_flags) + .ok_or_else(|| PyValueError::new_err("Key is too long or empty"))?; + if cmd.no_reply { + return Err(socket_err( + "internal error: build_cmd produced no_reply=true for mg command", + )); + } + let io = &mut self.io; + let (header, value_data) = py + .detach(|| { + io.send_cmd(&cmd.buf, false)?; + io.get_response_with_value() + }) + .map_err(|e| socket_err_io("Error in meta_get", e))?; + self.make_response(py, header, value_data) + } + + /// Send a meta set command with value and return the response. + /// For no_reply commands, sends with NOOP and returns Success immediately. + /// Otherwise, the entire send + recv happens in a single GIL-released block. + #[pyo3(signature = (key, value, request_flags=None))] + pub fn meta_set( + &mut self, + py: Python<'_>, + key: &[u8], + value: &[u8], + request_flags: Option<&RequestFlags>, + ) -> PyResult> { + let cmd = self + .build_cmd(b"ms", key, Some(value.len() as u32), request_flags) + .ok_or_else(|| PyValueError::new_err("Key is too long or empty"))?; + let io = &mut self.io; + let result = py + .detach(|| { + io.send_cmd_with_value(&cmd.buf, value, cmd.no_reply)?; + if cmd.no_reply { + Ok(CmdResult::NoReply) + } else { + Ok(CmdResult::Response(io.get_response_with_value()?)) + } + }) + .map_err(|e| socket_err_io("Error in meta_set", e))?; + match result { + CmdResult::NoReply => Self::success_no_reply(py), + CmdResult::Response((header, value_data)) => self.make_response(py, header, value_data), + } + } + + /// Send a meta delete command and return the response. + /// For no_reply commands, sends with NOOP and returns Success immediately. + #[pyo3(signature = (key, request_flags=None))] + pub fn meta_delete( + &mut self, + py: Python<'_>, + key: &[u8], + request_flags: Option<&RequestFlags>, + ) -> PyResult> { + let cmd = self + .build_cmd(b"md", key, None, request_flags) + .ok_or_else(|| PyValueError::new_err("Key is too long or empty"))?; + let io = &mut self.io; + let result = py + .detach(|| { + io.send_cmd(&cmd.buf, cmd.no_reply)?; + if cmd.no_reply { + Ok(CmdResult::NoReply) + } else { + Ok(CmdResult::Response(io.get_response_with_value()?)) + } + }) + .map_err(|e| socket_err_io("Error in meta_delete", e))?; + match result { + CmdResult::NoReply => Self::success_no_reply(py), + CmdResult::Response((header, value_data)) => self.make_response(py, header, value_data), + } + } + + /// Send a meta arithmetic command and return the response. + /// For no_reply commands, sends with NOOP and returns Success immediately. + #[pyo3(signature = (key, request_flags=None))] + pub fn meta_arithmetic( + &mut self, + py: Python<'_>, + key: &[u8], + request_flags: Option<&RequestFlags>, + ) -> PyResult> { + let cmd = self + .build_cmd(b"ma", key, None, request_flags) + .ok_or_else(|| PyValueError::new_err("Key is too long or empty"))?; + let io = &mut self.io; + let result = py + .detach(|| { + io.send_cmd(&cmd.buf, cmd.no_reply)?; + if cmd.no_reply { + Ok(CmdResult::NoReply) + } else { + Ok(CmdResult::Response(io.get_response_with_value()?)) + } + }) + .map_err(|e| socket_err_io("Error in meta_arithmetic", e))?; + match result { + CmdResult::NoReply => Self::success_no_reply(py), + CmdResult::Response((header, value_data)) => self.make_response(py, header, value_data), } } } diff --git a/src/request_flags.rs b/src/request_flags.rs index 8d2f6f6..4f129c6 100644 --- a/src/request_flags.rs +++ b/src/request_flags.rs @@ -49,9 +49,16 @@ pub struct RequestFlags { } impl RequestFlags { - pub fn push_bytes(&self, buf: &mut Vec) { + /// Check if the no_reply flag is set (crate-internal use). + pub(crate) fn is_no_reply(&self) -> bool { + self.no_reply + } + + pub fn push_bytes(&self, buf: &mut Vec, allow_no_reply_flag: bool) { let mut itoa_buf = itoa::Buffer::new(); - if self.no_reply { + // allow_no_reply_flag controls whether the wire-level `q` flag is emitted + // for no_reply flag. + if allow_no_reply_flag && self.no_reply { buf.push(b' '); buf.push(b'q'); } @@ -276,7 +283,7 @@ impl RequestFlags { pub fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> { let mut flags: Vec = Vec::with_capacity(64); - self.push_bytes(&mut flags); + self.push_bytes(&mut flags, /* allow_no_reply_flag */ true); PyBytes::new(py, &flags) } } diff --git a/src/request_flags_tests.rs b/src/request_flags_tests.rs index d7e7fdb..a587360 100644 --- a/src/request_flags_tests.rs +++ b/src/request_flags_tests.rs @@ -12,7 +12,7 @@ mod tests { fn push_to_vec(flags: &RequestFlags) -> Vec { let mut buf = Vec::new(); - flags.push_bytes(&mut buf); + flags.push_bytes(&mut buf, /* allow_no_reply_flag */ true); buf } diff --git a/tests/test_memcache_socket.py b/tests/test_memcache_socket.py index cf64439..efed415 100644 --- a/tests/test_memcache_socket.py +++ b/tests/test_memcache_socket.py @@ -13,6 +13,7 @@ MemcacheSocket, Miss, NotStored, + RequestFlags, ResponseFlags, Success, Value, @@ -208,11 +209,12 @@ def test_hd_stale(self, socket_pair): assert resp.flags.stale is True -# --- get_response: Value --- +# --- get_response: Value (now includes value bytes) --- class TestGetResponseValue: - def test_value_response(self, socket_pair): + def test_value_response_includes_bytes(self, socket_pair): + """get_response() now reads value bytes automatically.""" a, b = socket_pair ms = MemcacheSocket(a) b.sendall(b"VA 2 c1\r\nOK\r\n") @@ -220,7 +222,7 @@ def test_value_response(self, socket_pair): assert isinstance(resp, Value) assert resp.size == 2 assert resp.flags.cas_token == 1 - assert resp.value is None # Not yet read + assert resp.value == b"OK" def test_value_with_all_flags(self, socket_pair): a, b = socket_pair @@ -229,6 +231,7 @@ def test_value_with_all_flags(self, socket_pair): resp = ms.get_response() assert isinstance(resp, Value) assert resp.size == 3 + assert resp.value == b"foo" assert resp.flags.cas_token == 999 assert resp.flags.fetched is False assert resp.flags.last_access == 60 @@ -244,51 +247,10 @@ def test_value_stale_and_lost(self, socket_pair): b.sendall(b"VA 1 X Z\r\nx\r\n") resp = ms.get_response() assert isinstance(resp, Value) + assert resp.value == b"x" assert resp.flags.stale is True assert resp.flags.win is False - -# --- get_response: server version 1.6.6 --- - - -class TestGetResponse166: - def test_ok_response_as_success(self, socket_pair): - a, b = socket_pair - ms = MemcacheSocket(a, version=SERVER_VERSION_AWS_1_6_6) - b.sendall(b"OK c1\r\n") - resp = ms.get_response() - assert isinstance(resp, Success) - assert resp.flags.cas_token == 1 - - def test_value_and_ok(self, socket_pair): - a, b = socket_pair - ms = MemcacheSocket(a, version=SERVER_VERSION_AWS_1_6_6) - b.sendall(b"VA 2 c1\r\nOK\r\nOK c2\r\n") - resp = ms.get_response() - assert isinstance(resp, Value) - assert resp.size == 2 - val = ms.get_value(resp.size) - assert val == b"OK" - - resp2 = ms.get_response() - assert isinstance(resp2, Success) - assert resp2.flags.cas_token == 2 - - -# --- get_value --- - - -class TestGetValue: - def test_small_value(self, socket_pair): - a, b = socket_pair - ms = MemcacheSocket(a) - b.sendall(b"VA 5\r\nhello\r\n") - resp = ms.get_response() - assert isinstance(resp, Value) - val = ms.get_value(resp.size) - assert val == b"hello" - assert isinstance(val, bytes) - def test_empty_value(self, socket_pair): a, b = socket_pair ms = MemcacheSocket(a) @@ -296,8 +258,7 @@ def test_empty_value(self, socket_pair): resp = ms.get_response() assert isinstance(resp, Value) assert resp.size == 0 - val = ms.get_value(resp.size) - assert val == b"" + assert resp.value == b"" def test_large_value_exceeding_buffer(self, socket_pair): """Value larger than the buffer size triggers temporary allocation.""" @@ -311,9 +272,8 @@ def test_large_value_exceeding_buffer(self, socket_pair): assert resp.flags.cas_token == 1 assert resp.flags.win is True assert bytes(resp.flags.opaque) == b"xxx" - val = ms.get_value(resp.size) - assert len(val) == 200 - assert val == payload + assert len(resp.value) == 200 + assert resp.value == payload def test_value_with_incomplete_endl(self, socket_pair): """Buffer is just big enough for value but ENDL splits across reads.""" @@ -323,28 +283,17 @@ def test_value_with_incomplete_endl(self, socket_pair): resp = ms.get_response() assert isinstance(resp, Value) assert resp.size == 10 - val = ms.get_value(resp.size) - assert val == b"1234567890" + assert resp.value == b"1234567890" def test_value_with_incomplete_endl_then_response(self, socket_pair): - """After ENDL-split value read, buffer state must allow further responses. - - Regression test: the ENDL-split path in ensure_value consumed ENDL bytes - from the socket but not from the buffer. The caller then advanced pos past - read, corrupting buffer state for subsequent operations. - """ + """After ENDL-split value read, buffer state must allow further responses.""" a, b = socket_pair - # buffer_size=18: header "VA 10\r\n" is 7 bytes, value is 10 bytes, - # so value fills the buffer and \r\n splits across reads. ms = MemcacheSocket(a, buffer_size=18) b.sendall(b"VA 10\r\n1234567890\r\nEN\r\n") resp = ms.get_response() assert isinstance(resp, Value) - val = ms.get_value(resp.size) - assert val == b"1234567890" + assert resp.value == b"1234567890" - # This second get_response would panic/corrupt without the fix, - # because pos > read after the ENDL-split path. resp2 = ms.get_response() assert isinstance(resp2, Miss) @@ -355,25 +304,22 @@ def test_value_with_incomplete_endl_then_value(self, socket_pair): b.sendall(b"VA 10\r\n1234567890\r\nVA 3\r\nfoo\r\n") resp = ms.get_response() assert isinstance(resp, Value) - assert ms.get_value(resp.size) == b"1234567890" + assert resp.value == b"1234567890" resp2 = ms.get_response() assert isinstance(resp2, Value) - assert ms.get_value(resp2.size) == b"foo" + assert resp2.value == b"foo" def test_multiple_values(self, socket_pair): a, b = socket_pair ms = MemcacheSocket(a) b.sendall( - b"VA 3\r\nfoo\r\n" - b"VA 3\r\nbar\r\n" - b"VA 3\r\nbaz\r\n" + b"VA 3\r\nfoo\r\n" b"VA 3\r\nbar\r\n" b"VA 3\r\nbaz\r\n" ) for expected in [b"foo", b"bar", b"baz"]: resp = ms.get_response() assert isinstance(resp, Value) - val = ms.get_value(resp.size) - assert val == expected + assert resp.value == expected def test_value_then_miss(self, socket_pair): """Read a value, then a simple response.""" @@ -382,8 +328,7 @@ def test_value_then_miss(self, socket_pair): b.sendall(b"VA 5 f1\r\nhello\r\nEN\r\n") resp1 = ms.get_response() assert isinstance(resp1, Value) - val = ms.get_value(resp1.size) - assert val == b"hello" + assert resp1.value == b"hello" resp2 = ms.get_response() assert isinstance(resp2, Miss) @@ -401,7 +346,7 @@ def test_interleaved_responses(self, socket_pair): # Value r = ms.get_response() assert isinstance(r, Value) - assert ms.get_value(r.size) == b"hi" + assert r.value == b"hi" # Success r = ms.get_response() assert isinstance(r, Success) @@ -412,7 +357,33 @@ def test_interleaved_responses(self, socket_pair): # Value r = ms.get_response() assert isinstance(r, Value) - assert ms.get_value(r.size) == b"bye" + assert r.value == b"bye" + + +# --- get_response: server version 1.6.6 --- + + +class TestGetResponse166: + def test_ok_response_as_success(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a, version=SERVER_VERSION_AWS_1_6_6) + b.sendall(b"OK c1\r\n") + resp = ms.get_response() + assert isinstance(resp, Success) + assert resp.flags.cas_token == 1 + + def test_value_and_ok(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a, version=SERVER_VERSION_AWS_1_6_6) + b.sendall(b"VA 2 c1\r\nOK\r\nOK c2\r\n") + resp = ms.get_response() + assert isinstance(resp, Value) + assert resp.size == 2 + assert resp.value == b"OK" + + resp2 = ms.get_response() + assert isinstance(resp2, Success) + assert resp2.flags.cas_token == 2 # --- NOOP handling --- @@ -476,16 +447,6 @@ def test_closed_socket_on_get_response(self, socket_pair): with pytest.raises(ConnectionError): ms.get_response() - def test_closed_socket_on_get_value(self, socket_pair): - a, b = socket_pair - ms = MemcacheSocket(a) - b.sendall(b"VA 100\r\n") # Claim 100 bytes but close - resp = ms.get_response() - assert isinstance(resp, Value) - b.close() - with pytest.raises(ConnectionError): - ms.get_value(resp.size) - def test_closed_socket_on_sendall(self, socket_pair): a, b = socket_pair ms = MemcacheSocket(a) @@ -563,8 +524,7 @@ def test_small_buffer_values(self, socket_pair): b.sendall(b"VA 5\r\nhello\r\n") resp = ms.get_response() assert isinstance(resp, Value) - val = ms.get_value(resp.size) - assert val == b"hello" + assert resp.value == b"hello" def test_responses_spanning_buffer_boundary(self, socket_pair): """Header that arrives across multiple recv calls.""" @@ -577,19 +537,264 @@ def test_responses_spanning_buffer_boundary(self, socket_pair): resp = ms.get_response() assert isinstance(resp, Value) assert resp.flags.cas_token == 42 - val = ms.get_value(resp.size) - assert val == b"foo" + assert resp.value == b"foo" -# --- Version constants --- +# --- send_meta_* (Tier 1: pipelining) --- + + +class TestSendMeta: + def test_send_meta_get(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + flags = RequestFlags(return_cas_token=True, cache_ttl=300) + ms.send_meta_get(b"mykey", flags) + data = b.recv(1024) + assert data == b"mg mykey c T300\r\n" + + def test_send_meta_get_no_flags(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + ms.send_meta_get(b"mykey") + data = b.recv(1024) + assert data == b"mg mykey\r\n" + + def test_send_meta_set(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + flags = RequestFlags(cache_ttl=300, client_flag=0) + ms.send_meta_set(b"mykey", b"hello", flags) + data = b.recv(1024) + assert data == b"ms mykey 5 T300 F0\r\nhello\r\n" + + def test_send_meta_set_with_noop(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + ms.send_meta_set(b"mykey", b"hello") + data_without = b.recv(1024) + assert data_without == b"ms mykey 5\r\nhello\r\n" + + def test_send_meta_set_with_noop_flag(self, socket_pair): + """no_reply flag auto-injects NOOP.""" + a, b = socket_pair + ms = MemcacheSocket(a) + flags = RequestFlags(no_reply=True) + ms.send_meta_set(b"mykey", b"hi", flags) + data = b.recv(1024) + assert data == b"ms mykey 2 q\r\nhi\r\nmn\r\n" + + def test_send_meta_set_empty_value(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + ms.send_meta_set(b"mykey", b"") + data = b.recv(1024) + assert data == b"ms mykey 0\r\n\r\n" + + def test_send_meta_delete(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + flags = RequestFlags(cache_ttl=300) + ms.send_meta_delete(b"mykey", flags) + data = b.recv(1024) + assert data == b"md mykey T300\r\n" + + def test_send_meta_arithmetic(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + flags = RequestFlags(ma_delta_value=5) + ms.send_meta_arithmetic(b"mykey", flags) + data = b.recv(1024) + assert data == b"ma mykey D5\r\n" + + def test_send_meta_get_invalid_key(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + with pytest.raises(ValueError): + ms.send_meta_get(b"") + with pytest.raises(ValueError): + ms.send_meta_get(b"x" * 250) + + def test_pipeline_send_then_recv(self, socket_pair): + """Full pipeline: send multiple, then recv multiple.""" + a, b = socket_pair + ms = MemcacheSocket(a) + flags = RequestFlags(return_cas_token=True) + + # Send 3 gets + ms.send_meta_get(b"key1", flags) + ms.send_meta_get(b"key2", flags) + ms.send_meta_get(b"key3", flags) + + # Server responds + b.sendall( + b"VA 3 c1\r\nfoo\r\n" + b"EN\r\n" + b"VA 3 c3\r\nbar\r\n" + ) + + r1 = ms.get_response() + assert isinstance(r1, Value) + assert r1.value == b"foo" + assert r1.flags.cas_token == 1 + + r2 = ms.get_response() + assert isinstance(r2, Miss) + + r3 = ms.get_response() + assert isinstance(r3, Value) + assert r3.value == b"bar" + assert r3.flags.cas_token == 3 + + +# --- meta_* (Tier 3: blocking) --- + + +class TestMetaBlocking: + def test_meta_get_miss(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + # Server responds with miss after receiving get + b.sendall(b"EN\r\n") + resp = ms.meta_get(b"mykey") + assert isinstance(resp, Miss) + + def test_meta_get_value(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + flags = RequestFlags(return_cas_token=True, return_value=True) + b.sendall(b"VA 5 c42\r\nhello\r\n") + resp = ms.meta_get(b"mykey", flags) + assert isinstance(resp, Value) + assert resp.value == b"hello" + assert resp.flags.cas_token == 42 + + def test_meta_get_verifies_wire(self, socket_pair): + """meta_get sends the correct wire format.""" + a, b = socket_pair + ms = MemcacheSocket(a) + flags = RequestFlags(return_cas_token=True, cache_ttl=300) + # Need to respond so the blocking call doesn't hang + b.sendall(b"EN\r\n") + ms.meta_get(b"testkey", flags) + # Check what was sent + data = b.recv(1024) + assert data == b"mg testkey c T300\r\n" + + def test_meta_set_success(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + flags = RequestFlags(cache_ttl=300, client_flag=0) + b.sendall(b"HD c1\r\n") + resp = ms.meta_set(b"mykey", b"hello", flags) + assert isinstance(resp, Success) + assert resp.flags.cas_token == 1 + # Check wire format + data = b.recv(1024) + assert data == b"ms mykey 5 T300 F0\r\nhello\r\n" + + def test_meta_set_no_reply(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + flags = RequestFlags(no_reply=True, cache_ttl=300) + resp = ms.meta_set(b"mykey", b"hello", flags) + assert isinstance(resp, Success) + # Check wire format includes noop + data = b.recv(1024) + assert data == b"ms mykey 5 q T300\r\nhello\r\nmn\r\n" + + def test_meta_set_not_stored(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + b.sendall(b"NS\r\n") + resp = ms.meta_set(b"mykey", b"hello") + assert isinstance(resp, NotStored) + + def test_meta_delete_success(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + b.sendall(b"HD\r\n") + resp = ms.meta_delete(b"mykey") + assert isinstance(resp, Success) + data = b.recv(1024) + assert data == b"md mykey\r\n" + + def test_meta_delete_no_reply(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + flags = RequestFlags(no_reply=True) + resp = ms.meta_delete(b"mykey", flags) + assert isinstance(resp, Success) + data = b.recv(1024) + assert data == b"md mykey q\r\nmn\r\n" + + def test_meta_delete_miss(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + b.sendall(b"NF\r\n") + resp = ms.meta_delete(b"mykey") + assert isinstance(resp, Miss) + + def test_meta_arithmetic_success(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + flags = RequestFlags(ma_delta_value=5, return_value=True) + b.sendall(b"VA 2\r\n10\r\n") + resp = ms.meta_arithmetic(b"counter", flags) + assert isinstance(resp, Value) + assert resp.value == b"10" + data = b.recv(1024) + assert data == b"ma counter v D5\r\n" + + def test_meta_arithmetic_no_reply(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + flags = RequestFlags(no_reply=True, ma_delta_value=1) + resp = ms.meta_arithmetic(b"counter", flags) + assert isinstance(resp, Success) + data = b.recv(1024) + assert data == b"ma counter q D1\r\nmn\r\n" + + def test_meta_get_invalid_key(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + with pytest.raises(ValueError): + ms.meta_get(b"") + + def test_meta_set_legacy_version(self, socket_pair): + """AWS 1.6.6 uses legacy size format with S prefix.""" + a, b = socket_pair + ms = MemcacheSocket(a, version=SERVER_VERSION_AWS_1_6_6) + b.sendall(b"HD\r\n") + ms.meta_set(b"mykey", b"hello") + data = b.recv(1024) + assert data == b"ms mykey S5\r\nhello\r\n" + + def test_meta_no_reply_then_regular(self, socket_pair): + """no_reply command followed by regular command should work + (noop draining is handled correctly).""" + a, b = socket_pair + ms = MemcacheSocket(a) + # no_reply delete + flags_noreply = RequestFlags(no_reply=True) + resp1 = ms.meta_delete(b"key1", flags_noreply) + assert isinstance(resp1, Success) + + # Regular get — server sends noop response (from delete), then miss + b.sendall(b"MN\r\nEN\r\n") + resp2 = ms.meta_get(b"key2") + # The MN should be drained, and we get the EN (Miss) + assert isinstance(resp2, Miss) + + +# --- Non-blocking sockets --- class TestNonBlockingSocket: - """Test with sockets in non-blocking mode (settimeout), matching Python's socket_factory_builder.""" + """Test with sockets in non-blocking mode (settimeout).""" def test_settimeout_get_response(self, socket_pair): a, b = socket_pair - a.settimeout(5.0) # Puts socket in non-blocking mode with timeout + a.settimeout(5.0) ms = MemcacheSocket(a) b.sendall(b"EN\r\n") @@ -604,8 +809,7 @@ def test_settimeout_get_value(self, socket_pair): b.sendall(b"VA 5\r\nhello\r\n") resp = ms.get_response() assert isinstance(resp, Value) - val = ms.get_value(resp.size) - assert val == b"hello" + assert resp.value == b"hello" def test_settimeout_large_value(self, socket_pair): a, b = socket_pair @@ -616,8 +820,7 @@ def test_settimeout_large_value(self, socket_pair): b.sendall(b"VA 500\r\n" + payload + b"\r\n") resp = ms.get_response() assert isinstance(resp, Value) - val = ms.get_value(resp.size) - assert val == payload + assert resp.value == payload def test_settimeout_sendall(self, socket_pair): a, b = socket_pair @@ -644,15 +847,15 @@ def test_settimeout_pipeline(self, socket_pair): ms = MemcacheSocket(a) # Send two commands - ms.sendall(b"mg key1\r\n", False) - ms.sendall(b"mg key2\r\n", False) + ms.send_meta_get(b"key1") + ms.send_meta_get(b"key2") # Server responds b.sendall(b"VA 3 f1\r\nfoo\r\nEN\r\n") r1 = ms.get_response() assert isinstance(r1, Value) - assert ms.get_value(r1.size) == b"foo" + assert r1.value == b"foo" r2 = ms.get_response() assert isinstance(r2, Miss) @@ -668,6 +871,18 @@ def test_settimeout_noop(self, socket_pair): assert isinstance(resp, Success) assert resp.flags.cas_token == 1 + def test_settimeout_meta_blocking(self, socket_pair): + """Blocking meta_* with non-blocking sockets.""" + a, b = socket_pair + a.settimeout(5.0) + ms = MemcacheSocket(a) + + b.sendall(b"VA 5 c1\r\nhello\r\n") + resp = ms.meta_get(b"mykey", RequestFlags(return_cas_token=True, return_value=True)) + assert isinstance(resp, Value) + assert resp.value == b"hello" + assert resp.flags.cas_token == 1 + class TestSocketTimeout: """Test that Python socket timeouts are respected by the Rust implementation.""" @@ -682,22 +897,6 @@ def test_get_response_timeout(self, socket_pair): with pytest.raises(TimeoutError): ms.get_response() - def test_get_value_timeout(self, socket_pair): - """get_value() should raise TimeoutError when value data doesn't arrive.""" - a, b = socket_pair - a.settimeout(0.1) - ms = MemcacheSocket(a) - - # Send header but not the value data - b.sendall(b"VA 100\r\n") - resp = ms.get_response() - assert isinstance(resp, Value) - assert resp.size == 100 - - # Value data never arrives — should timeout - with pytest.raises((TimeoutError, ConnectionError)): - ms.get_value(resp.size) - def test_sendall_timeout(self, socket_pair): """sendall() should raise TimeoutError when send buffer is full.""" a, b = socket_pair @@ -705,7 +904,6 @@ def test_sendall_timeout(self, socket_pair): ms = MemcacheSocket(a) # Fill the send buffer until it blocks, then expect timeout. - # Use a large payload to overwhelm the socket buffer. big_data = b"x" * (1024 * 1024 * 10) # 10MB with pytest.raises((TimeoutError, ConnectionError)): for _ in range(100): @@ -714,7 +912,6 @@ def test_sendall_timeout(self, socket_pair): def test_blocking_socket_no_timeout(self, socket_pair): """Blocking socket (no settimeout) should not have poll timeout issues.""" a, b = socket_pair - # No settimeout — socket is blocking, timeout should be -1 (infinite) ms = MemcacheSocket(a) b.sendall(b"EN\r\n") @@ -746,6 +943,5 @@ def test_constants_values(self): def test_version_matches_intenum(self): """ServerVersion IntEnum values match Rust constants.""" - # These must match for backward compatibility assert SERVER_VERSION_AWS_1_6_6 == 1 assert SERVER_VERSION_STABLE == 2