diff --git a/Cargo.lock b/Cargo.lock index 304d783..fcf9272 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -113,7 +113,7 @@ checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" [[package]] name = "meta-memcache-socket" -version = "0.2.0" +version = "0.2.1" dependencies = [ "atoi", "base64", diff --git a/Cargo.toml b/Cargo.toml index cec2e94..e8b6b42 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "meta-memcache-socket" -version = "0.2.0" +version = "0.2.1" edition = "2024" [lib] diff --git a/meta_memcache_socket.pyi b/meta_memcache_socket.pyi index f5f962d..d66b9d0 100644 --- a/meta_memcache_socket.pyi +++ b/meta_memcache_socket.pyi @@ -367,3 +367,21 @@ class MemcacheSocket: :param request_flags: The flags to use for all keys """ ... + + # Raw command passthrough + def raw_cmd( + self, + cmd: bytes, + multi_line: bool = False, + ) -> bytes: + """ + Send a raw command and return the raw response bytes. + + Appends \\r\\n to the command if not already present. + If multi_line is False, reads until \\r\\n and returns the line. + If multi_line is True, reads until END\\r\\n and returns everything before it. + + :param cmd: The raw command bytes to send + :param multi_line: Whether to expect a multi-line response (terminated by END\\r\\n) + """ + ... diff --git a/src/memcache_socket.rs b/src/memcache_socket.rs index f382b33..a1961b0 100644 --- a/src/memcache_socket.rs +++ b/src/memcache_socket.rs @@ -1,7 +1,7 @@ use std::os::fd::RawFd; use log::warn; - +use memchr::memmem; use pyo3::BoundObject; use pyo3::exceptions::{PyConnectionError, PyTimeoutError, PyValueError}; use pyo3::prelude::*; @@ -875,4 +875,82 @@ impl MemcacheSocket { } Ok(results) } + + // ----------------------------------------------------------------------- + // Raw command passthrough + // ----------------------------------------------------------------------- + + /// Send a raw command and return the raw response bytes. + /// + /// Appends \r\n to the command if not already present. + /// If `multi_line` is false, reads until \r\n and returns the line. + /// If `multi_line` is true, reads until END\r\n and returns everything before it. + /// Uses a separate buffer to avoid disturbing the main I/O state. + /// Releases the GIL during socket I/O. + #[pyo3(signature = (cmd, multi_line=false))] + pub fn raw_cmd<'py>( + &self, + py: Python<'py>, + cmd: &[u8], + multi_line: bool, + ) -> PyResult> { + // Build command with \r\n if needed + let cmd_bytes = if cmd.ends_with(b"\r\n") { + cmd.to_vec() + } else { + let mut buf = Vec::with_capacity(cmd.len() + 2); + buf.extend_from_slice(cmd); + buf.extend_from_slice(b"\r\n"); + buf + }; + + let fd = self.io.fd; + let timeout_ms = self.io.timeout_ms; + + let response = py + .detach(|| { + send_all(fd, &cmd_bytes, timeout_ms)?; + raw_recv(fd, timeout_ms, multi_line) + }) + .map_err(|e| socket_err_io("Error in raw_cmd", e))?; + + Ok(PyBytes::new(py, &response)) + } +} + +/// Receive a raw response into a standalone buffer (not the main SocketIO buffer). +/// For single-line: read until \r\n, return everything before it. +/// For multi-line: read until END\r\n, return everything before it. +fn raw_recv( + fd: RawFd, + timeout_ms: libc::c_int, + multi_line: bool, +) -> Result, std::io::Error> { + let mut buf = Vec::with_capacity(1024); + let mut tmp = [0u8; 4096]; + + loop { + let n = recv_into(fd, &mut tmp, timeout_ms)?; + if n == 0 { + return Err(std::io::Error::new( + std::io::ErrorKind::ConnectionAborted, + "Connection closed during raw_recv", + )); + } + buf.extend_from_slice(&tmp[..n]); + + if multi_line { + // Look for END\r\n — can appear at start of a line + if let Some(pos) = memmem::find(&buf, b"END\r\n") { + buf.truncate(pos); + return Ok(buf); + } + } else { + // Look for first \r\n + if let Some(pos) = memmem::find(&buf, b"\r\n") { + buf.truncate(pos); + return Ok(buf); + } + } + } } diff --git a/tests/test_memcache_socket.py b/tests/test_memcache_socket.py index 879a7ba..4bae550 100644 --- a/tests/test_memcache_socket.py +++ b/tests/test_memcache_socket.py @@ -1237,6 +1237,87 @@ def test_unicode_str_key_meta_set(self, socket_pair): assert data == b"ms " + expected_b64 + b" 3 b\r\nval\r\n" +# --- raw_cmd --- + + +class TestRawCmd: + def test_single_line_version(self, socket_pair): + """Typical single-line command like 'version'.""" + a, b = socket_pair + ms = MemcacheSocket(a) + b.sendall(b"VERSION 1.6.22\r\n") + result = ms.raw_cmd(b"version") + assert result == b"VERSION 1.6.22" + # Verify \r\n was appended + data = b.recv(1024) + assert data == b"version\r\n" + + def test_single_line_already_has_endl(self, socket_pair): + """Command already ending with \\r\\n should not get doubled.""" + a, b = socket_pair + ms = MemcacheSocket(a) + b.sendall(b"OK\r\n") + result = ms.raw_cmd(b"flush_all\r\n") + assert result == b"OK" + data = b.recv(1024) + assert data == b"flush_all\r\n" + + def test_multi_line_stats(self, socket_pair): + """Multi-line response like 'stats'.""" + a, b = socket_pair + ms = MemcacheSocket(a) + b.sendall( + b"STAT pid 12345\r\n" + b"STAT uptime 1000\r\n" + b"STAT version 1.6.22\r\n" + b"END\r\n" + ) + result = ms.raw_cmd(b"stats", multi_line=True) + assert result == ( + b"STAT pid 12345\r\n" + b"STAT uptime 1000\r\n" + b"STAT version 1.6.22\r\n" + ) + + def test_multi_line_empty(self, socket_pair): + """Multi-line response with no content before END.""" + a, b = socket_pair + ms = MemcacheSocket(a) + b.sendall(b"END\r\n") + result = ms.raw_cmd(b"stats slabs", multi_line=True) + assert result == b"" + + def test_single_line_empty_response(self, socket_pair): + """Server returns just \\r\\n.""" + a, b = socket_pair + ms = MemcacheSocket(a) + b.sendall(b"\r\n") + result = ms.raw_cmd(b"test") + assert result == b"" + + def test_nonblocking_socket(self, socket_pair): + a, b = socket_pair + a.settimeout(5.0) + ms = MemcacheSocket(a) + b.sendall(b"VERSION 1.6.22\r\n") + result = ms.raw_cmd(b"version") + assert result == b"VERSION 1.6.22" + + def test_does_not_affect_main_buffer(self, socket_pair): + """raw_cmd should not disturb main I/O state for subsequent meta commands.""" + a, b = socket_pair + ms = MemcacheSocket(a) + # raw command + b.sendall(b"VERSION 1.6.22\r\n") + ms.raw_cmd(b"version") + # meta get should still work + b.sendall(b"VA 3 c1\r\nfoo\r\n") + resp = ms.meta_get(b"mykey", RequestFlags(return_cas_token=True)) + assert isinstance(resp, Value) + assert resp.value == b"foo" + assert resp.flags.cas_token == 1 + + class TestVersionConstants: def test_constants_values(self): assert SERVER_VERSION_AWS_1_6_6 == 1