From 119e70e1019f897a59ffed87d155b5843fba8a1b Mon Sep 17 00:00:00 2001 From: Guillermo Perez Date: Wed, 11 Mar 2026 11:05:17 +0100 Subject: [PATCH] socket in rust --- .github/workflows/CI.yml | 116 ++---- .github/workflows/tests.yml | 33 ++ Cargo.lock | 3 +- Cargo.toml | 3 +- README.md | 414 +++++++++++++++++++ meta_memcache_socket.pyi | 63 ++- pyproject.toml | 7 + src/constants.rs | 8 + src/lib.rs | 15 + src/memcache_socket.rs | 536 ++++++++++++++++++++++++ src/response_types.rs | 109 +++++ tests/test_memcache_socket.py | 751 ++++++++++++++++++++++++++++++++++ tests/test_response_types.py | 124 ++++++ 13 files changed, 2106 insertions(+), 76 deletions(-) create mode 100644 .github/workflows/tests.yml create mode 100644 README.md create mode 100644 src/memcache_socket.rs create mode 100644 src/response_types.rs create mode 100644 tests/test_memcache_socket.py create mode 100644 tests/test_response_types.py diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index c010bfb..f062318 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -1,3 +1,8 @@ +# This file is autogenerated by maturin v1.12.6 +# To update, run +# +# maturin generate-ci github --platform manylinux --platform musllinux --platform macos +# name: CI on: @@ -19,34 +24,32 @@ jobs: strategy: matrix: platform: - - runner: ubuntu-latest + - runner: ubuntu-22.04 target: x86_64 - - runner: ubuntu-latest + - runner: ubuntu-22.04 target: x86 - - runner: ubuntu-latest + - runner: ubuntu-22.04 target: aarch64 - - runner: ubuntu-latest + - runner: ubuntu-22.04 target: armv7 - - runner: ubuntu-latest + - runner: ubuntu-22.04 target: s390x - - runner: ubuntu-latest + - runner: ubuntu-22.04 target: ppc64le steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 with: - python-version: | - 3.x - 3.13t + python-version: 3.x - name: Build wheels uses: PyO3/maturin-action@v1 with: target: ${{ matrix.platform.target }} args: --release --out dist --find-interpreter - sccache: 'true' + sccache: ${{ !startsWith(github.ref, 'refs/tags/') }} manylinux: auto - name: Upload wheels - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: wheels-linux-${{ matrix.platform.target }} path: dist @@ -56,87 +59,54 @@ jobs: strategy: matrix: platform: - - runner: ubuntu-latest + - runner: ubuntu-22.04 target: x86_64 - - runner: ubuntu-latest + - runner: ubuntu-22.04 target: x86 - - runner: ubuntu-latest + - runner: ubuntu-22.04 target: aarch64 - - runner: ubuntu-latest + - runner: ubuntu-22.04 target: armv7 steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 with: - python-version: | - 3.x - 3.13t + python-version: 3.x - name: Build wheels uses: PyO3/maturin-action@v1 with: target: ${{ matrix.platform.target }} args: --release --out dist --find-interpreter - sccache: 'true' + sccache: ${{ !startsWith(github.ref, 'refs/tags/') }} manylinux: musllinux_1_2 - name: Upload wheels - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: wheels-musllinux-${{ matrix.platform.target }} path: dist - windows: - runs-on: ${{ matrix.platform.runner }} - strategy: - matrix: - platform: - - runner: windows-latest - target: x64 - - runner: windows-latest - target: x86 - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 - with: - python-version: | - 3.x - 3.13t - architecture: ${{ matrix.platform.target }} - - name: Build wheels - uses: PyO3/maturin-action@v1 - with: - target: ${{ matrix.platform.target }} - args: --release --out dist --find-interpreter - sccache: 'true' - - name: Upload wheels - uses: actions/upload-artifact@v4 - with: - name: wheels-windows-${{ matrix.platform.target }} - path: dist - macos: runs-on: ${{ matrix.platform.runner }} strategy: matrix: platform: - - runner: macos-12 + - runner: macos-15-intel target: x86_64 - - runner: macos-14 + - runner: macos-latest target: aarch64 steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 with: - python-version: | - 3.x - 3.13t + python-version: 3.x - name: Build wheels uses: PyO3/maturin-action@v1 with: target: ${{ matrix.platform.target }} args: --release --out dist --find-interpreter - sccache: 'true' + sccache: ${{ !startsWith(github.ref, 'refs/tags/') }} - name: Upload wheels - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: wheels-macos-${{ matrix.platform.target }} path: dist @@ -144,14 +114,14 @@ jobs: sdist: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Build sdist uses: PyO3/maturin-action@v1 with: command: sdist args: --out dist - name: Upload sdist - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: wheels-sdist path: dist @@ -160,7 +130,7 @@ jobs: name: Release runs-on: ubuntu-latest if: ${{ startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' }} - needs: [linux, musllinux, windows, macos, sdist] + needs: [linux, musllinux, macos, sdist] permissions: # Use to sign the release artifacts id-token: write @@ -169,16 +139,16 @@ jobs: # Used to generate artifact attestation attestations: write steps: - - uses: actions/download-artifact@v4 + - uses: actions/download-artifact@v7 - name: Generate artifact attestation - uses: actions/attest-build-provenance@v1 + uses: actions/attest-build-provenance@v3 with: subject-path: 'wheels-*/*' + - name: Install uv + if: ${{ startsWith(github.ref, 'refs/tags/') }} + uses: astral-sh/setup-uv@v7 - name: Publish to PyPI - if: "startsWith(github.ref, 'refs/tags/')" - uses: PyO3/maturin-action@v1 + if: ${{ startsWith(github.ref, 'refs/tags/') }} + run: uv publish 'wheels-*/*' env: - MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }} - with: - command: upload - args: --non-interactive --skip-existing wheels-*/* + UV_PUBLISH_TOKEN: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..7abac70 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,33 @@ +name: Tests + +on: + push: + branches: + - main + - master + pull_request: + workflow_dispatch: + +permissions: + contents: read + +jobs: + rust-tests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + - run: cargo test + + python-tests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.x" + - uses: astral-sh/setup-uv@v4 + - name: Build and install extension + run: uv run --with maturin maturin develop + - name: Run Python tests + run: uv run --with pytest pytest tests/ -v diff --git a/Cargo.lock b/Cargo.lock index 5d1cfa6..6ed298b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -49,11 +49,12 @@ checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" [[package]] name = "meta-memcache-socket" -version = "0.1.7" +version = "2.0.0-alpha.1" dependencies = [ "atoi", "base64", "itoa", + "libc", "memchr", "pyo3", ] diff --git a/Cargo.toml b/Cargo.toml index 2ce9504..0066997 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "meta-memcache-socket" -version = "0.1.7" +version = "2.0.0-alpha.1" edition = "2024" [lib] @@ -11,5 +11,6 @@ crate-type = ["cdylib"] atoi = "2.0.0" base64 = "0.22.1" itoa = "1" +libc = "0.2" memchr = "2" pyo3 = { version = "0.28", features = ["extension-module"] } diff --git a/README.md b/README.md new file mode 100644 index 0000000..a8c5207 --- /dev/null +++ b/README.md @@ -0,0 +1,414 @@ +# meta-memcache-socket + +A high-performance Rust extension for Python that provides socket I/O, command +building, and response parsing for the +[Memcache meta-protocol](https://github.com/memcached/memcached/wiki/MetaCommands). +Designed as the low-level transport layer for +[meta-memcache-py](https://github.com/RevenueCat/meta-memcache-py). + +## Key features + +- **Rust-native socket I/O** — direct `send()`/`recv()`/`poll()` syscalls, + bypassing Python's socket layer while still respecting `settimeout()` +- **GIL-free** — releases the GIL during all socket operations (`py.detach()`), + so other Python threads run freely while waiting on the network +- **Zero-copy where possible** — response values are read directly into the + internal buffer; `PyBytes` is created from the buffer slice without + intermediate allocation +- **SIMD-accelerated parsing** — uses `memchr` for fast `\r\n` scanning +- **Free-threaded Python support** — built with `gil_used = false`, compatible + with Python 3.13t (no-GIL builds) + +## Project structure + +``` +meta-memcache-socket-py/ +├── Cargo.toml # Rust package manifest +├── pyproject.toml # Python package manifest (maturin backend) +├── src/ +│ ├── lib.rs # PyO3 module entry — exports classes, functions, constants +│ ├── constants.rs # Protocol constants (response codes, set modes, NOOP, ENDL) +│ ├── memcache_socket.rs # MemcacheSocket class — socket I/O, buffering, GIL management +│ ├── request_flags.rs # RequestFlags class — mutable flags for building commands +│ ├── response_flags.rs # ResponseFlags class — immutable flags parsed from responses +│ ├── response_types.rs # Response type classes (Value, Success, Miss, NotStored, Conflict) +│ ├── impl_build_cmd.rs # Command builder — key validation, base64, flag encoding +│ ├── impl_parse_header.rs # Header parser — SIMD search, flag parsing, atoi +│ ├── impl_build_cmd_tests.rs # Rust unit tests for command building +│ ├── impl_parse_header_tests.rs # Rust unit tests for header parsing +│ ├── request_flags_tests.rs # Rust unit tests for RequestFlags +│ └── response_flags_tests.rs # Rust unit tests for ResponseFlags +├── tests/ +│ ├── test_memcache_socket.py # Python tests — socket I/O, timeouts, buffering, NOOP +│ └── test_response_types.py # Python tests — response type semantics +├── bench.py # Microbenchmarks for command building and header parsing +└── .github/workflows/CI.yml # CI — Rust tests, Python tests, cross-platform wheel builds +``` + +## Design overview + +### Architecture + +The module is a single Rust cdylib compiled with [PyO3](https://pyo3.rs/) and +packaged with [Maturin](https://www.maturin.rs/). There is no Python source +code — everything is implemented in Rust and exported to Python directly. + +The design separates into three layers: + +1. **Protocol layer** (`constants.rs`, `impl_build_cmd.rs`, + `impl_parse_header.rs`) — stateless functions that build command byte strings + and parse response headers. These know the meta-protocol grammar but nothing + about sockets. + +2. **Type layer** (`request_flags.rs`, `response_flags.rs`, + `response_types.rs`) — Python-visible classes that carry request parameters + and parsed response data. + +3. **I/O layer** (`memcache_socket.rs`) — the `MemcacheSocket` class that owns + a raw file descriptor, an internal read buffer, and a NOOP counter. All + socket operations release the GIL via `py.detach()` and use `poll()` to + handle non-blocking sockets with proper timeout support. + +### MemcacheSocket internals + +``` +┌──────────────────────────────────────────────────────────┐ +│ MemcacheSocket (Python-visible) │ +│ • _conn: Py — prevents Python GC of socket │ +│ • version: u8 — server version for compat │ +│ • io: SocketIO — all I/O state (Send + Ungil) │ +│ ├── fd: RawFd │ +│ ├── buf: Vec — ring buffer for recv'd data │ +│ ├── pos / read — read cursor / write cursor │ +│ ├── timeout_ms — poll() timeout from settimeout │ +│ └── noop_expected — pending NOOP responses to drain │ +└──────────────────────────────────────────────────────────┘ +``` + +The `SocketIO` struct contains no Python objects, so it satisfies PyO3's +`Ungil` trait and can be passed to `py.detach()` closures that release the GIL. + +**Buffer management**: the internal buffer acts as a sliding window. When `pos` +passes 75% of the buffer, remaining data is shifted to the front +(`copy_within`). Values that fit in the buffer are served directly from it +(zero-copy to Rust); values exceeding the buffer are allocated into a temporary +`Vec`. + +**NOOP handling**: when `sendall()` is called with `with_noop=True`, a `mn\r\n` +command is appended. The NOOP counter increments. On the next `get_response()`, +all responses before the corresponding `MN` are drained automatically, enabling +pipelined fire-and-forget commands. + +**Timeout handling**: at construction time (and on `set_socket()`), the Python +socket's `gettimeout()` is read and converted to milliseconds for `poll()`. If +the socket is blocking (`gettimeout()` returns `None`), poll uses `-1` +(infinite). If a timeout is set, poll respects it and raises Python's +`TimeoutError` on expiry. + +## API reference + +### MemcacheSocket + +The main class for socket communication with a Memcache server. + +```python +from meta_memcache_socket import MemcacheSocket + +# Constructor +ms = MemcacheSocket( + conn, # Python socket object + buffer_size=4096, # Internal read buffer size in bytes + version=SERVER_VERSION_STABLE, # Server version for protocol compat +) + +# Send data, optionally appending a NOOP command +ms.sendall(data: bytes, with_noop: bool) + +# Read and parse the next response header +# Returns one of: Value, Success, Miss, NotStored, Conflict +resp = ms.get_response() + +# Read value payload (call after get_response() returns a Value) +data: bytes = ms.get_value(resp.size) + +# Replace the underlying socket (e.g. after reconnect) +ms.set_socket(new_conn) + +# Close the underlying socket +ms.close() + +# Server version +ms.get_version() # -> int +``` + +### Response types + +All response types are returned by `get_response()`: + +| Class | Protocol code | Bool | Fields | +|---|---|---|---| +| `Miss` | `EN`, `NF` | `False` | — | +| `NotStored` | `NS` | `False` | — | +| `Conflict` | `EX` | `False` | — | +| `Success` | `HD`, `OK` | `True` | `flags: ResponseFlags` | +| `Value` | `VA` | `True` | `size: int`, `flags: ResponseFlags`, `value: Any` (settable) | + +`Miss`, `NotStored`, and `Conflict` are frozen and support equality. +`Value.value` is a mutable slot used by higher-level code (e.g. meta-memcache-py's +executor) to attach deserialized data. + +### ResponseFlags + +Immutable (frozen) container for flags parsed from a server response. + +```python +flags.cas_token # Optional[int] — CAS token (c) +flags.fetched # Optional[bool] — fetched from cache (h) +flags.last_access # Optional[int] — seconds since last access (l) +flags.ttl # Optional[int] — TTL in seconds, -1 = no expiry (t) +flags.client_flag # Optional[int] — user-defined flag (f) +flags.win # Optional[bool] — True=W (won), False=Z (lost) +flags.stale # bool — marked stale (X) +flags.size # Optional[int] — value size (s) +flags.opaque # Optional[bytes] — echoed opaque data (O) +``` + +### RequestFlags + +Mutable container for flags sent with commands. + +```python +from meta_memcache_socket import RequestFlags + +flags = RequestFlags( + # Boolean flags + no_reply=False, # q — don't expect a response + return_client_flag=True, # f + return_cas_token=True, # c + return_value=True, # v + return_ttl=False, # t + return_size=False, # s + return_last_access=False, # l + return_fetched=False, # h + return_key=False, # k + no_update_lru=False, # u + mark_stale=False, # I + + # Optional value flags + cache_ttl=3600, # T — TTL in seconds + recache_ttl=None, # R — recache window + vivify_on_miss_ttl=None, # N — create-on-miss TTL + client_flag=42, # F — user-defined flag + ma_initial_value=None, # J — arithmetic initial value + ma_delta_value=None, # D — arithmetic delta + cas_token=None, # C — CAS token for conditional ops + opaque=None, # O — opaque data echoed back + mode=None, # M — operation mode (set/arithmetic) +) + +flags.copy() # -> RequestFlags (deep copy) +flags.to_bytes() # -> bytes (encoded flag string) +``` + +### Command builders + +Convenience functions that build meta-protocol command byte strings. +All raise `ValueError` if the key exceeds the length limit (250 bytes, or +187 for binary keys which are base64-encoded with a `b` flag). + +```python +from meta_memcache_socket import ( + build_meta_get, + build_meta_set, + build_meta_delete, + build_meta_arithmetic, + build_cmd, +) + +# mg key [flags]\r\n +cmd = build_meta_get(key: bytes, request_flags=None) + +# ms key size [flags]\r\n +cmd = build_meta_set(key: bytes, size: int, request_flags=None, legacy_size_format=False) + +# md key [flags]\r\n +cmd = build_meta_delete(key: bytes, request_flags=None) + +# ma key [flags]\r\n +cmd = build_meta_arithmetic(key: bytes, request_flags=None) + +# Generic: {cmd} key [size] [flags]\r\n +cmd = build_cmd(cmd: bytes, key: bytes, size=None, request_flags=None, legacy_size_format=False) +``` + +### parse_header + +Low-level function to parse a response header from a buffer. Primarily used +internally by `MemcacheSocket.get_response()`, but exposed for advanced use. + +```python +from meta_memcache_socket import parse_header + +# Returns (end_pos, response_type, size, flags) or None if header is incomplete +result = parse_header( + buffer: Union[memoryview, bytearray], + start: int, + end: int, +) +``` + +### Constants + +```python +# Response type codes +RESPONSE_VALUE = 1 +RESPONSE_SUCCESS = 2 +RESPONSE_NOT_STORED = 3 +RESPONSE_CONFLICT = 4 +RESPONSE_MISS = 5 +RESPONSE_NOOP = 100 + +# Set modes (for RequestFlags.mode) +SET_MODE_SET = 83 # 'S' — default set +SET_MODE_ADD = 69 # 'E' — add (only if not exists) +SET_MODE_REPLACE = 82 # 'R' — replace (only if exists) +SET_MODE_APPEND = 65 # 'A' — append to value +SET_MODE_PREPEND = 80 # 'P' — prepend to value + +# Arithmetic modes +MA_MODE_INC = 43 # '+' — increment +MA_MODE_DEC = 45 # '-' — decrement + +# Server versions +SERVER_VERSION_AWS_1_6_6 = 1 # AWS ElastiCache 1.6.6 compat +SERVER_VERSION_STABLE = 2 # Standard memcached +``` + +## Development + +### Prerequisites + +- [Rust](https://rustup.rs/) (stable toolchain, edition 2024) +- Python >= 3.10 +- [uv](https://docs.astral.sh/uv/) (recommended) or pip + maturin + +### Building + +```bash +# Build and install into the project venv (development mode) +uv run --with maturin maturin develop + +# Build in release mode (optimized) +uv run --with maturin maturin develop --release +``` + +### Running tests + +**Rust unit tests** — tests command building, header parsing, and flag encoding: + +```bash +cargo test +``` + +**Python integration tests** — tests socket I/O, timeouts, buffering, response +types, and NOOP handling using real socket pairs: + +```bash +# Build the extension, then run pytest +uv run --with maturin maturin develop +uv run --with pytest pytest tests/ -v +``` + +### Running benchmarks + +```bash +uv run --with maturin maturin develop --release +uv run python bench.py +``` + +## Using a local build with meta-memcache-py + +When developing this package alongside +[meta-memcache-py](https://github.com/RevenueCat/meta-memcache-py), you need +meta-memcache-py to use your local build instead of the PyPI version. + +### Option 1: pip install from local path (quick iteration) + +```bash +cd /path/to/meta-memcache-py + +# Install the local build (--reinstall forces replacement of the existing version) +uv pip install -n -v /path/to/meta-memcache-socket-py --reinstall +``` + +NOTE: When using this option, any `uv run` will revert the package to the +version specified in the pyproject.toml file. + +### Option 2: pyproject.toml dependency override (persistent) + +In `meta-memcache-py`'s `pyproject.toml`, replace the PyPI dependency with a +local file reference: + +```toml +dependencies = [ + # "meta-memcache-socket>=2.0.0", # PyPI version (commented out) + "meta-memcache-socket @ file:///path/to/meta-memcache-socket-py", +] +``` + +Then sync the environment: + +```bash +uv sync +``` + +Remember to revert this before committing. + +## Releasing + +Releases are automated via GitHub Actions CI. + +### Process + +1. Update the version in `Cargo.toml`: + ```toml + [package] + version = "2.1.0" + ``` + +2. Commit and push to `main`. + +3. Create and push a git tag: + ```bash + git tag v2.1.0 + git push origin v2.1.0 + ``` + +4. The CI pipeline will: + - Run Rust and Python tests + - Build wheels for all platforms: + - Linux: x86_64, x86, aarch64, armv7, s390x, ppc64le (glibc + musl) + - macOS: x86_64 (Intel), aarch64 (Apple Silicon) + - Windows: x64, x86 + - Build for both CPython 3.x and free-threaded 3.13t + - Generate build provenance attestation + - Publish all wheels + sdist to PyPI + +The PyPI upload uses the `PYPI_API_TOKEN` repository secret. + +### Manual trigger + +The release job can also be triggered manually via GitHub's "Run workflow" +button on the CI workflow page (`workflow_dispatch`). This runs all build jobs +and generates artifacts but only publishes to PyPI if a tag is present. + +## Dependencies + +| Crate | Purpose | +|---|---| +| [pyo3](https://pyo3.rs/) 0.28 | Python ↔ Rust bindings, GIL management | +| [libc](https://docs.rs/libc) | Direct syscalls: `poll`, `send`, `recv`, `writev`, `setsockopt` | +| [memchr](https://docs.rs/memchr) | SIMD-accelerated `\r\n` scanning | +| [atoi](https://docs.rs/atoi) | Fast ASCII → integer for header parsing | +| [itoa](https://docs.rs/itoa) | Fast integer → ASCII for command building | +| [base64](https://docs.rs/base64) | Binary key encoding | diff --git a/meta_memcache_socket.pyi b/meta_memcache_socket.pyi index e3c86eb..e00828c 100644 --- a/meta_memcache_socket.pyi +++ b/meta_memcache_socket.pyi @@ -1,4 +1,5 @@ -from typing import Optional, Tuple, Union +from typing import Any, Optional, Tuple, Union +import socket RESPONSE_VALUE: int # 1 - VALUE (VA) RESPONSE_SUCCESS: int # 2 - SUCCESS (OK or HD) @@ -24,6 +25,10 @@ MA_MODE_INC: int # 43 ('+') # - "decrement" MA_MODE_DEC: int # 45 ('-') +# Server versions +SERVER_VERSION_AWS_1_6_6: int # 1 +SERVER_VERSION_STABLE: int # 2 + class RequestFlags: """ A class representing the flags for a meta-protocol request @@ -245,3 +250,59 @@ def build_meta_arithmetic( :param request_flags: The flags to use """ ... + +class Miss: + def __init__(self) -> None: ... + def __repr__(self) -> str: ... + def __bool__(self) -> bool: ... + +class NotStored: + def __init__(self) -> None: ... + def __repr__(self) -> str: ... + def __bool__(self) -> bool: ... + +class Conflict: + def __init__(self) -> None: ... + def __repr__(self) -> str: ... + def __bool__(self) -> bool: ... + +class Success: + flags: ResponseFlags + + def __init__(self, flags: ResponseFlags) -> None: ... + def __repr__(self) -> str: ... + +class Value: + size: int + flags: ResponseFlags + value: Any + + def __init__( + self, + size: int, + flags: ResponseFlags, + value: Any = None, + ) -> None: ... + def __repr__(self) -> str: ... + +class MemcacheSocket: + """ + A high-performance memcache socket that handles the meta-protocol + communication with a memcached server. + + Releases the GIL during socket I/O operations. + """ + + def __init__( + self, + conn: socket.socket, + buffer_size: int = 4096, + version: int = ..., # SERVER_VERSION_STABLE + ) -> None: ... + def __str__(self) -> str: ... + def get_version(self) -> int: ... + def set_socket(self, conn: socket.socket) -> None: ... + def close(self) -> None: ... + def sendall(self, data: bytes, with_noop: bool) -> None: ... + def get_response(self) -> Union[Value, Success, Miss, NotStored, Conflict]: ... + def get_value(self, size: int) -> bytes: ... diff --git a/pyproject.toml b/pyproject.toml index a21b363..149797c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,3 +14,10 @@ dynamic = ["version"] [tool.maturin] features = ["pyo3/extension-module"] + +[tool.uv] +cache-keys = [ + { file = "pyproject.toml" }, + { file = "Cargo.toml" }, + { file = "src/*.rs" } +] diff --git a/src/constants.rs b/src/constants.rs index bc0b12a..1225d65 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -18,3 +18,11 @@ pub const SET_MODE_REPLACE: u8 = 82; // 'R' pub const SET_MODE_SET: u8 = 83; // 'S' pub const MA_MODE_INC: u8 = 43; // '+' pub const MA_MODE_DEC: u8 = 45; // '-' + +pub const NOOP_CMD: &[u8] = b"mn\r\n"; +pub const ENDL: &[u8] = b"\r\n"; +pub const ENDL_LEN: usize = 2; + +// Server versions +pub const SERVER_VERSION_AWS_1_6_6: u8 = 1; +pub const SERVER_VERSION_STABLE: u8 = 2; diff --git a/src/lib.rs b/src/lib.rs index 1acd8b8..3596f82 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,10 +3,12 @@ mod impl_build_cmd; mod impl_build_cmd_tests; mod impl_parse_header; mod impl_parse_header_tests; +mod memcache_socket; mod request_flags; mod request_flags_tests; mod response_flags; mod response_flags_tests; +mod response_types; pub use constants::*; use impl_build_cmd::impl_build_cmd; use impl_parse_header::impl_parse_header; @@ -151,14 +153,25 @@ pub fn build_meta_arithmetic<'py>( #[pymodule(gil_used = false)] fn meta_memcache_socket(module: &Bound<'_, PyModule>) -> PyResult<()> { + // Classes module.add_class::()?; module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + + // Functions module.add_function(wrap_pyfunction!(parse_header, module)?)?; module.add_function(wrap_pyfunction!(build_cmd, module)?)?; module.add_function(wrap_pyfunction!(build_meta_get, module)?)?; module.add_function(wrap_pyfunction!(build_meta_set, module)?)?; module.add_function(wrap_pyfunction!(build_meta_delete, module)?)?; module.add_function(wrap_pyfunction!(build_meta_arithmetic, module)?)?; + + // Constants module.add("RESPONSE_VALUE", RESPONSE_VALUE)?; module.add("RESPONSE_SUCCESS", RESPONSE_SUCCESS)?; module.add("RESPONSE_NOT_STORED", RESPONSE_NOT_STORED)?; @@ -172,5 +185,7 @@ fn meta_memcache_socket(module: &Bound<'_, PyModule>) -> PyResult<()> { module.add("SET_MODE_SET", SET_MODE_SET)?; module.add("MA_MODE_INC", MA_MODE_INC)?; module.add("MA_MODE_DEC", MA_MODE_DEC)?; + module.add("SERVER_VERSION_AWS_1_6_6", SERVER_VERSION_AWS_1_6_6)?; + module.add("SERVER_VERSION_STABLE", SERVER_VERSION_STABLE)?; Ok(()) } diff --git a/src/memcache_socket.rs b/src/memcache_socket.rs new file mode 100644 index 0000000..7eb4cd2 --- /dev/null +++ b/src/memcache_socket.rs @@ -0,0 +1,536 @@ +use std::os::fd::RawFd; + +use pyo3::BoundObject; +use pyo3::exceptions::{PyConnectionError, PyTimeoutError}; +use pyo3::prelude::*; +use pyo3::types::PyBytes; + +use crate::constants::*; +use crate::impl_parse_header::{ParsedHeader, impl_parse_header}; +use crate::response_types::*; + +const DEFAULT_BUFFER_SIZE: usize = 4096; + +/// Convert a Rust pyclass into a `Py` for returning from methods +/// that return different Python types (union return). +fn into_py<'py, T: IntoPyObject<'py>>(py: Python<'py>, obj: T) -> PyResult> +where + T::Error: Into, +{ + Ok(obj + .into_pyobject(py) + .map_err(Into::into)? + .into_any() + .unbind()) +} + +fn socket_err(msg: &str) -> PyErr { + PyConnectionError::new_err(msg.to_string()) +} + +fn socket_err_io(msg: &str, source: std::io::Error) -> PyErr { + if source.kind() == std::io::ErrorKind::TimedOut { + PyTimeoutError::new_err("timed out") + } else { + PyConnectionError::new_err(format!("{msg}: {source}")) + } +} + +/// Read the timeout from a Python socket object and convert to poll() milliseconds. +/// Returns -1 for blocking sockets (timeout is None), or a positive ms value. +fn get_timeout_ms(conn: &Bound<'_, PyAny>) -> PyResult { + let timeout_obj = conn.call_method0("gettimeout")?; + if timeout_obj.is_none() { + Ok(-1) + } else { + let seconds: f64 = timeout_obj.extract()?; + // Convert seconds to milliseconds, clamping to valid range. + // A timeout of 0 means non-blocking (don't wait at all). + let ms = (seconds * 1000.0).ceil() as i64; + Ok(ms.clamp(0, libc::c_int::MAX as i64) as libc::c_int) + } +} + +/// Wait for the fd to become ready for reading/writing using poll(). +/// This handles non-blocking sockets set up via Python's settimeout(). +/// `timeout_ms` is the poll timeout: -1 for infinite (blocking sockets), +/// or a positive value in milliseconds (from Python's settimeout()). +#[inline] +fn poll_fd( + fd: RawFd, + events: libc::c_short, + timeout_ms: libc::c_int, +) -> Result<(), std::io::Error> { + let mut pfd = libc::pollfd { + fd, + events, + revents: 0, + }; + // SAFETY: pfd is a valid pollfd struct on the stack, nfds=1 + let ret = unsafe { libc::poll(&mut pfd, 1, timeout_ms) }; + if ret < 0 { + let err = std::io::Error::last_os_error(); + if err.kind() == std::io::ErrorKind::Interrupted { + return poll_fd(fd, events, timeout_ms); // signal interrupted us, retry + } + Err(err) + } else if ret == 0 { + Err(std::io::Error::new( + std::io::ErrorKind::TimedOut, + "timed out", + )) + } else if pfd.revents & (libc::POLLERR | libc::POLLHUP | libc::POLLNVAL) != 0 { + Err(std::io::Error::new( + std::io::ErrorKind::ConnectionReset, + "poll error on socket", + )) + } else { + Ok(()) + } +} + +/// Send all bytes through the fd, handling partial writes and EAGAIN. +#[inline] +fn send_all(fd: RawFd, data: &[u8], timeout_ms: libc::c_int) -> Result<(), std::io::Error> { + let mut sent = 0; + while sent < data.len() { + // SAFETY: data[sent..] is a valid byte slice, fd is a valid socket + let n = unsafe { + libc::send( + fd, + data[sent..].as_ptr() as *const libc::c_void, + data.len() - sent, + 0, + ) + }; + if n > 0 { + sent += n as usize; + } else if n < 0 { + let err = std::io::Error::last_os_error(); + match err.kind() { + std::io::ErrorKind::WouldBlock => poll_fd(fd, libc::POLLOUT, timeout_ms)?, + std::io::ErrorKind::Interrupted => continue, + _ => return Err(err), + } + } else { + return Err(std::io::Error::new( + std::io::ErrorKind::WriteZero, + "send returned 0", + )); + } + } + Ok(()) +} + +/// Send data + NOOP command in a single write when possible. +/// Uses writev() to avoid concatenation on the happy path, falls back +/// to send_all() for partial writes and EAGAIN. +#[inline] +fn send_all_with_noop( + fd: RawFd, + data: &[u8], + timeout_ms: libc::c_int, +) -> Result<(), std::io::Error> { + let iov = [ + libc::iovec { + iov_base: data.as_ptr() as *mut libc::c_void, + iov_len: data.len(), + }, + libc::iovec { + iov_base: NOOP_CMD.as_ptr() as *mut libc::c_void, + iov_len: NOOP_CMD.len(), + }, + ]; + // SAFETY: iov array has 2 valid entries pointing to data and NOOP_CMD + let n = unsafe { libc::writev(fd, iov.as_ptr(), 2) }; + let total_len = data.len() + NOOP_CMD.len(); + let written = if n >= 0 { + n as usize + } else { + let err = std::io::Error::last_os_error(); + match err.kind() { + std::io::ErrorKind::WouldBlock | std::io::ErrorKind::Interrupted => 0, + _ => return Err(err), + } + }; + if written < total_len { + let combined: Vec = [data, NOOP_CMD].concat(); + send_all(fd, &combined[written..], timeout_ms)?; + } + Ok(()) +} + +/// Recv into buffer slice, returns bytes read. Handles EAGAIN by polling. +fn recv_into(fd: RawFd, buf: &mut [u8], timeout_ms: libc::c_int) -> Result { + loop { + // SAFETY: buf is a valid mutable byte slice, fd is a valid socket + let n = unsafe { libc::recv(fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len(), 0) }; + if n > 0 { + return Ok(n as usize); + } else if n == 0 { + return Ok(0); + } else { + let err = std::io::Error::last_os_error(); + match err.kind() { + std::io::ErrorKind::WouldBlock => poll_fd(fd, libc::POLLIN, timeout_ms)?, + std::io::ErrorKind::Interrupted => continue, + _ => return Err(err), + } + } + } +} + +/// Recv filling the buffer completely. Handles EAGAIN by polling. +fn recv_fill(fd: RawFd, buf: &mut [u8], timeout_ms: libc::c_int) -> Result { + let mut total = 0; + let size = buf.len(); + while total < size { + // SAFETY: buf[total..] is a valid mutable byte slice, fd is a valid socket + let n = unsafe { + libc::recv( + fd, + buf[total..].as_mut_ptr() as *mut libc::c_void, + size - total, + libc::MSG_WAITALL, + ) + }; + if n > 0 { + total += n as usize; + } else if n == 0 { + return Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "connection closed during recv_fill", + )); + } else { + let err = std::io::Error::last_os_error(); + match err.kind() { + std::io::ErrorKind::WouldBlock => poll_fd(fd, libc::POLLIN, timeout_ms)?, + std::io::ErrorKind::Interrupted => continue, + _ => return Err(err), + } + } + } + Ok(total) +} + +/// Where the value data ended up after recv. +enum ValueData { + /// Value is in io.buf[pos..pos+size], ENDL validated. Caller advances pos. + InBuffer, + /// Value was too large for the buffer, stored in this Vec. + Allocated(Vec), +} + +/// Inner I/O state — no Python objects, so it is Send/Ungil. +/// This allows releasing the GIL during socket I/O via py.detach(). +struct SocketIO { + fd: RawFd, + buf: Vec, + buffer_size: usize, + reset_buffer_size: usize, + pos: usize, + read: usize, + noop_expected: u32, + /// poll() timeout in milliseconds. -1 for blocking sockets (no timeout), + /// positive value from Python's socket.settimeout(). + timeout_ms: libc::c_int, +} + +impl SocketIO { + fn recv_into_buffer(&mut self) -> Result { + let n = recv_into(self.fd, &mut self.buf[self.read..], self.timeout_ms)?; + if n > 0 { + self.read += n; + } + Ok(n) + } + + fn reset_buffer(&mut self) { + let remaining = self.read - self.pos; + if remaining > 0 { + self.buf.copy_within(self.pos..self.read, 0); + } + self.pos = 0; + self.read = remaining; + } + + fn get_single_header(&mut self) -> Result { + if self.pos >= self.read { + self.read = 0; + self.pos = 0; + } else if self.pos > self.reset_buffer_size { + self.reset_buffer(); + } + + loop { + if self.read != self.pos + && let Some(header) = impl_parse_header(&self.buf, self.pos, self.read) + { + self.pos = header.end_pos; + return Ok(header); + } + let n = self.recv_into_buffer()?; + if n == 0 { + return Err(std::io::Error::new( + std::io::ErrorKind::ConnectionAborted, + "Bad response. Socket might have closed unexpectedly", + )); + } + } + } + + fn read_until_noop_header(&mut self) -> Result<(), std::io::Error> { + while self.noop_expected > 0 { + let header = self.get_single_header()?; + if header.response_type == Some(RESPONSE_NOOP) { + self.noop_expected -= 1; + } + } + Ok(()) + } + + fn get_header(&mut self) -> Result { + if self.noop_expected > 0 { + self.read_until_noop_header()?; + } + self.get_single_header() + } + + fn sendall_impl(&self, data: &[u8], with_noop: bool) -> Result<(), std::io::Error> { + if with_noop { + send_all_with_noop(self.fd, data, self.timeout_ms) + } else { + send_all(self.fd, data, self.timeout_ms) + } + } + + /// Ensure value data is available for reading. + /// + /// For the common case (value fits in buffer), returns InBuffer — + /// the data is at buf[pos..pos+size] with ENDL validated. + /// The caller creates PyBytes directly from the buffer slice (zero-copy to Rust). + /// + /// For large values exceeding the buffer, returns Allocated with the data. + fn ensure_value(&mut self, size: usize) -> Result { + let message_size = size + ENDL_LEN; + + // Try to fill buffer with enough data + let mut data_in_buf = self.read - self.pos; + while data_in_buf < message_size && self.read < self.buffer_size { + let n = self.recv_into_buffer()?; + if n == 0 { + return Err(std::io::Error::new( + std::io::ErrorKind::ConnectionAborted, + "Connection closed while reading value", + )); + } + data_in_buf = self.read - self.pos; + } + + if data_in_buf >= message_size { + // Value + ENDL fully in buffer — validate ENDL in place + let data_end = self.pos + size; + if self.buf[data_end] != b'\r' || self.buf[data_end + 1] != b'\n' { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Value not terminated with \\r\\n", + )); + } + // Don't advance pos yet — caller reads buf[pos..pos+size] then advances + Ok(ValueData::InBuffer) + } else if data_in_buf >= size { + // Value in buffer but ENDL partially/not in buffer. + // We still return InBuffer, but need to read+validate the ENDL. + let data_end = self.pos + size; + let endl_in_buf = data_in_buf - size; + let mut endl_buf = [0u8; ENDL_LEN]; + if endl_in_buf > 0 { + endl_buf[..endl_in_buf].copy_from_slice(&self.buf[data_end..self.read]); + } + if endl_in_buf < ENDL_LEN { + recv_fill(self.fd, &mut endl_buf[endl_in_buf..], self.timeout_ms)?; + } + if endl_buf != *ENDL { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Value not terminated with \\r\\n", + )); + } + // Discard partial ENDL bytes from the buffer's tracked range. + // The caller will advance pos by size + ENDL_LEN, which will + // overshoot read — get_single_header handles this via pos >= read. + self.read = data_end; + Ok(ValueData::InBuffer) + } else { + // Value doesn't fit in buffer — allocate and read into temp buffer + let mut message = vec![0u8; message_size]; + message[..data_in_buf].copy_from_slice(&self.buf[self.pos..self.read]); + recv_fill(self.fd, &mut message[data_in_buf..], self.timeout_ms)?; + + if message[size] != b'\r' || message[size + 1] != b'\n' { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Value not terminated with \\r\\n", + )); + } + + self.pos = self.read; // Buffer fully consumed + message.truncate(size); + Ok(ValueData::Allocated(message)) + } + } +} + +#[pyclass] +pub struct MemcacheSocket { + io: SocketIO, + /// Hold a reference to the Python socket to prevent GC. + _conn: Py, + version: u8, +} + +#[pymethods] +impl MemcacheSocket { + #[new] + #[pyo3(signature = (conn, buffer_size=DEFAULT_BUFFER_SIZE, version=SERVER_VERSION_STABLE))] + pub fn new(conn: &Bound<'_, PyAny>, buffer_size: usize, version: u8) -> PyResult { + let fd: RawFd = conn.call_method0("fileno")?.extract()?; + let timeout_ms = get_timeout_ms(conn)?; + + // Set SO_RCVBUF — failure is non-fatal (kernel may reject the size) + let recv_buf_size: libc::c_int = buffer_size as libc::c_int; + // SAFETY: fd is a valid socket, recv_buf_size is a valid c_int on the stack + let ret = unsafe { + libc::setsockopt( + fd, + libc::SOL_SOCKET, + libc::SO_RCVBUF, + &recv_buf_size as *const libc::c_int as *const libc::c_void, + std::mem::size_of::() as libc::socklen_t, + ) + }; + if ret != 0 { + // Non-fatal: log would be ideal but we don't have a logger here. + // The socket will still work with the kernel default buffer size. + let _ = ret; + } + + Ok(MemcacheSocket { + io: SocketIO { + fd, + buf: vec![0u8; buffer_size], + buffer_size, + reset_buffer_size: buffer_size * 3 / 4, + pos: 0, + read: 0, + noop_expected: 0, + timeout_ms, + }, + _conn: conn.clone().unbind(), + version, + }) + } + + pub fn __str__(&self) -> String { + format!("", self.io.fd) + } + + pub fn get_version(&self) -> u8 { + self.version + } + + pub fn set_socket(&mut self, conn: &Bound<'_, PyAny>) -> PyResult<()> { + self.io.fd = conn.call_method0("fileno")?.extract()?; + self.io.timeout_ms = get_timeout_ms(conn)?; + self._conn = conn.clone().unbind(); + self.io.pos = 0; + self.io.read = 0; + self.io.noop_expected = 0; + Ok(()) + } + + pub fn close(&mut self, py: Python<'_>) -> PyResult<()> { + self._conn.call_method0(py, "close")?; + self.io.fd = -1; + self.io.pos = 0; + self.io.read = 0; + self.io.noop_expected = 0; + Ok(()) + } + + /// Send data to the socket, optionally appending a NOOP command. + /// Releases the GIL during socket I/O. + pub fn sendall(&mut self, py: Python<'_>, data: &[u8], with_noop: bool) -> PyResult<()> { + let io = &mut self.io; + py.detach(|| io.sendall_impl(data, with_noop)) + .map_err(|e| socket_err_io("Error sending data", e))?; + if with_noop { + self.io.noop_expected += 1; + } + Ok(()) + } + + /// Read and parse the next response header. + /// Releases the GIL during socket I/O and header parsing. + pub fn get_response(&mut self, py: Python<'_>) -> PyResult> { + let io = &mut self.io; + let header = py + .detach(|| io.get_header()) + .map_err(|e| socket_err_io("Error reading header", e))?; + + match header.response_type { + Some(RESPONSE_VALUE) => { + let size = header + .size + .ok_or_else(|| socket_err("Value response missing size"))?; + let flags = header + .flags + .ok_or_else(|| socket_err("Value response missing flags"))?; + into_py(py, Value::new(size, flags, None)) + } + Some(RESPONSE_SUCCESS) => { + let flags = header + .flags + .ok_or_else(|| socket_err("Success response missing flags"))?; + into_py(py, Success::new(flags)) + } + Some(RESPONSE_NOT_STORED) => into_py(py, NotStored::new()), + Some(RESPONSE_CONFLICT) => into_py(py, Conflict::new()), + Some(RESPONSE_MISS) => into_py(py, Miss::new()), + _ => Err(socket_err(&format!( + "Unknown response code: {:?}", + header.response_type + ))), + } + } + + /// Read value data from the socket. + /// Releases the GIL during socket I/O. + /// For the common case (value fits in buffer), creates PyBytes directly + /// from the buffer — no intermediate allocation. + pub fn get_value<'py>(&mut self, py: Python<'py>, size: u32) -> PyResult> { + let size_usize = size as usize; + let io = &mut self.io; + + // Phase 1: recv data without GIL + let location = py + .detach(|| io.ensure_value(size_usize)) + .map_err(|e| socket_err_io("Error receiving value", e))?; + + // Phase 2: create PyBytes with GIL + match location { + ValueData::InBuffer => { + // Common path: value is in io.buf — create PyBytes directly, no extra alloc + let data_start = self.io.pos; + let data_end = data_start + size_usize; + let result = PyBytes::new(py, &self.io.buf[data_start..data_end]); + self.io.pos = data_end + ENDL_LEN; + Ok(result) + } + ValueData::Allocated(data) => { + // Large value path: data already in Vec, create PyBytes from it + Ok(PyBytes::new(py, &data)) + } + } + } +} diff --git a/src/response_types.rs b/src/response_types.rs new file mode 100644 index 0000000..b301fd0 --- /dev/null +++ b/src/response_types.rs @@ -0,0 +1,109 @@ +use pyo3::prelude::*; + +use crate::response_flags::ResponseFlags; + +#[pyclass(frozen, eq, skip_from_py_object)] +#[derive(Clone, Debug, PartialEq)] +pub struct Miss {} + +#[pymethods] +impl Miss { + #[new] + pub fn new() -> Self { + Miss {} + } + + pub fn __repr__(&self) -> &'static str { + "Miss()" + } + + pub fn __bool__(&self) -> bool { + false + } +} + +#[pyclass(frozen, eq, skip_from_py_object)] +#[derive(Clone, Debug, PartialEq)] +pub struct NotStored {} + +#[pymethods] +impl NotStored { + #[new] + pub fn new() -> Self { + NotStored {} + } + + pub fn __repr__(&self) -> &'static str { + "NotStored()" + } + + pub fn __bool__(&self) -> bool { + false + } +} + +#[pyclass(frozen, eq, skip_from_py_object)] +#[derive(Clone, Debug, PartialEq)] +pub struct Conflict {} + +#[pymethods] +impl Conflict { + #[new] + pub fn new() -> Self { + Conflict {} + } + + pub fn __repr__(&self) -> &'static str { + "Conflict()" + } + + pub fn __bool__(&self) -> bool { + false + } +} + +#[pyclass(frozen, skip_from_py_object)] +#[derive(Clone, Debug)] +pub struct Success { + #[pyo3(get)] + pub flags: ResponseFlags, +} + +#[pymethods] +impl Success { + #[new] + pub fn new(flags: ResponseFlags) -> Self { + Success { flags } + } + + pub fn __repr__(&self) -> String { + format!("Success(flags={})", self.flags.__str__()) + } +} + +#[pyclass(skip_from_py_object)] +pub struct Value { + #[pyo3(get)] + pub size: u32, + #[pyo3(get)] + pub flags: ResponseFlags, + #[pyo3(get, set)] + pub value: Option>, +} + +#[pymethods] +impl Value { + #[new] + pub fn new(size: u32, flags: ResponseFlags, value: Option>) -> Self { + Value { size, flags, value } + } + + pub fn __repr__(&self) -> String { + format!( + "Value(size={}, flags={}, value={:?})", + self.size, + self.flags.__str__(), + self.value.as_ref().map(|_| "..."), + ) + } +} diff --git a/tests/test_memcache_socket.py b/tests/test_memcache_socket.py new file mode 100644 index 0000000..cf64439 --- /dev/null +++ b/tests/test_memcache_socket.py @@ -0,0 +1,751 @@ +"""Tests for the Rust MemcacheSocket class. + +Mirrors the tests from meta-memcache-py/tests/memcache_socket_test.py +but tests the Rust implementation directly. +""" + +import socket + +import pytest + +from meta_memcache_socket import ( + Conflict, + MemcacheSocket, + Miss, + NotStored, + ResponseFlags, + Success, + Value, + SERVER_VERSION_AWS_1_6_6, + SERVER_VERSION_STABLE, +) + + +@pytest.fixture +def socket_pair(): + a, b = socket.socketpair() + yield a, b + try: + a.close() + except OSError: + pass + try: + b.close() + except OSError: + pass + + +# --- Constructor and basic methods --- + + +class TestConstructor: + def test_create_default(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + assert ms.get_version() == SERVER_VERSION_STABLE + + def test_create_with_version(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a, version=SERVER_VERSION_AWS_1_6_6) + assert ms.get_version() == SERVER_VERSION_AWS_1_6_6 + + def test_create_with_buffer_size(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a, buffer_size=8192) + assert ms.get_version() == SERVER_VERSION_STABLE + + def test_str(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + s = str(ms) + assert " read after the ENDL-split path. + resp2 = ms.get_response() + assert isinstance(resp2, Miss) + + def test_value_with_incomplete_endl_then_value(self, socket_pair): + """ENDL-split followed by another value response.""" + a, b = socket_pair + ms = MemcacheSocket(a, buffer_size=18) + b.sendall(b"VA 10\r\n1234567890\r\nVA 3\r\nfoo\r\n") + resp = ms.get_response() + assert isinstance(resp, Value) + assert ms.get_value(resp.size) == b"1234567890" + + resp2 = ms.get_response() + assert isinstance(resp2, Value) + assert ms.get_value(resp2.size) == b"foo" + + def test_multiple_values(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + b.sendall( + b"VA 3\r\nfoo\r\n" + b"VA 3\r\nbar\r\n" + b"VA 3\r\nbaz\r\n" + ) + for expected in [b"foo", b"bar", b"baz"]: + resp = ms.get_response() + assert isinstance(resp, Value) + val = ms.get_value(resp.size) + assert val == expected + + def test_value_then_miss(self, socket_pair): + """Read a value, then a simple response.""" + a, b = socket_pair + ms = MemcacheSocket(a) + b.sendall(b"VA 5 f1\r\nhello\r\nEN\r\n") + resp1 = ms.get_response() + assert isinstance(resp1, Value) + val = ms.get_value(resp1.size) + assert val == b"hello" + + resp2 = ms.get_response() + assert isinstance(resp2, Miss) + + def test_interleaved_responses(self, socket_pair): + """Simulate pipelined responses: value, success, miss, value.""" + a, b = socket_pair + ms = MemcacheSocket(a) + b.sendall( + b"VA 2 f1\r\nhi\r\n" + b"HD c5\r\n" + b"EN\r\n" + b"VA 3 f2\r\nbye\r\n" + ) + # Value + r = ms.get_response() + assert isinstance(r, Value) + assert ms.get_value(r.size) == b"hi" + # Success + r = ms.get_response() + assert isinstance(r, Success) + assert r.flags.cas_token == 5 + # Miss + r = ms.get_response() + assert isinstance(r, Miss) + # Value + r = ms.get_response() + assert isinstance(r, Value) + assert ms.get_value(r.size) == b"bye" + + +# --- NOOP handling --- + + +class TestNoopHandling: + def test_noop_drains_responses(self, socket_pair): + """Responses before NOOP should be discarded.""" + a, b = socket_pair + ms = MemcacheSocket(a) + ms.sendall(b"test", with_noop=True) + + # EX (conflict) before MN should be discarded; HD after is real + b.sendall(b"EX\r\nMN\r\nHD\r\n") + resp = ms.get_response() + assert isinstance(resp, Success) + + def test_noop_no_responses_before(self, socket_pair): + """NOOP with nothing to drain.""" + a, b = socket_pair + ms = MemcacheSocket(a) + ms.sendall(b"test", with_noop=True) + + b.sendall(b"MN\r\nEN\r\n") + resp = ms.get_response() + assert isinstance(resp, Miss) + + def test_multiple_noops(self, socket_pair): + """Multiple NOOPs should all be drained.""" + a, b = socket_pair + ms = MemcacheSocket(a) + ms.sendall(b"cmd1", with_noop=True) + ms.sendall(b"cmd2", with_noop=True) + + # Two MN responses followed by the actual response + b.sendall(b"MN\r\nMN\r\nHD c1\r\n") + resp = ms.get_response() + assert isinstance(resp, Success) + assert resp.flags.cas_token == 1 + + def test_noop_with_multiple_skipped_responses(self, socket_pair): + """Multiple responses before NOOP are all discarded.""" + a, b = socket_pair + ms = MemcacheSocket(a) + ms.sendall(b"test", with_noop=True) + + # HD, NS, EX all before MN - all discarded + b.sendall(b"HD\r\nNS\r\nEX\r\nMN\r\nEN\r\n") + resp = ms.get_response() + assert isinstance(resp, Miss) + + +# --- Error handling --- + + +class TestErrorHandling: + def test_closed_socket_on_get_response(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + b.close() + with pytest.raises(ConnectionError): + ms.get_response() + + def test_closed_socket_on_get_value(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + b.sendall(b"VA 100\r\n") # Claim 100 bytes but close + resp = ms.get_response() + assert isinstance(resp, Value) + b.close() + with pytest.raises(ConnectionError): + ms.get_value(resp.size) + + def test_closed_socket_on_sendall(self, socket_pair): + a, b = socket_pair + ms = MemcacheSocket(a) + a.close() + with pytest.raises(ConnectionError): + ms.sendall(b"test\r\n", False) + + def test_close_invalidates_fd(self, socket_pair): + """After close(), operations should fail, not use a stale fd.""" + a, b = socket_pair + ms = MemcacheSocket(a) + ms.close() + with pytest.raises(ConnectionError): + ms.sendall(b"test\r\n", False) + with pytest.raises(ConnectionError): + ms.get_response() + + def test_close_resets_noop_state(self, socket_pair): + """close() should reset noop_expected so set_socket starts clean.""" + a, b = socket_pair + ms = MemcacheSocket(a) + ms.sendall(b"test", with_noop=True) + ms.close() + + # Reconnect with a new socket + c, d = socket.socketpair() + try: + ms.set_socket(c) + d.sendall(b"EN\r\n") + # Should NOT try to drain a NOOP from the previous connection + resp = ms.get_response() + assert isinstance(resp, Miss) + finally: + c.close() + d.close() + + def test_set_socket_resets_noop_state(self, socket_pair): + """set_socket() should reset noop_expected for the new connection.""" + a, b = socket_pair + ms = MemcacheSocket(a) + ms.sendall(b"test", with_noop=True) + + c, d = socket.socketpair() + try: + ms.set_socket(c) + d.sendall(b"HD c1\r\n") + resp = ms.get_response() + assert isinstance(resp, Success) + assert resp.flags.cas_token == 1 + finally: + c.close() + d.close() + + +# --- Buffer management --- + + +class TestBufferManagement: + def test_small_buffer_many_responses(self, socket_pair): + """Stress the buffer reset logic with a small buffer.""" + a, b = socket_pair + ms = MemcacheSocket(a, buffer_size=32) + + for i in range(50): + b.sendall(b"EN\r\n") + resp = ms.get_response() + assert isinstance(resp, Miss) + + def test_small_buffer_values(self, socket_pair): + """Values that just fit in a small buffer.""" + a, b = socket_pair + ms = MemcacheSocket(a, buffer_size=64) + + for i in range(20): + b.sendall(b"VA 5\r\nhello\r\n") + resp = ms.get_response() + assert isinstance(resp, Value) + val = ms.get_value(resp.size) + assert val == b"hello" + + def test_responses_spanning_buffer_boundary(self, socket_pair): + """Header that arrives across multiple recv calls.""" + a, b = socket_pair + ms = MemcacheSocket(a, buffer_size=4096) + + # Send header in two parts + b.sendall(b"VA 3 ") + b.sendall(b"c42\r\nfoo\r\n") + resp = ms.get_response() + assert isinstance(resp, Value) + assert resp.flags.cas_token == 42 + val = ms.get_value(resp.size) + assert val == b"foo" + + +# --- Version constants --- + + +class TestNonBlockingSocket: + """Test with sockets in non-blocking mode (settimeout), matching Python's socket_factory_builder.""" + + def test_settimeout_get_response(self, socket_pair): + a, b = socket_pair + a.settimeout(5.0) # Puts socket in non-blocking mode with timeout + ms = MemcacheSocket(a) + + b.sendall(b"EN\r\n") + resp = ms.get_response() + assert isinstance(resp, Miss) + + def test_settimeout_get_value(self, socket_pair): + a, b = socket_pair + a.settimeout(5.0) + ms = MemcacheSocket(a) + + b.sendall(b"VA 5\r\nhello\r\n") + resp = ms.get_response() + assert isinstance(resp, Value) + val = ms.get_value(resp.size) + assert val == b"hello" + + def test_settimeout_large_value(self, socket_pair): + a, b = socket_pair + a.settimeout(5.0) + ms = MemcacheSocket(a, buffer_size=100) + + payload = b"x" * 500 + b.sendall(b"VA 500\r\n" + payload + b"\r\n") + resp = ms.get_response() + assert isinstance(resp, Value) + val = ms.get_value(resp.size) + assert val == payload + + def test_settimeout_sendall(self, socket_pair): + a, b = socket_pair + a.settimeout(5.0) + ms = MemcacheSocket(a) + + ms.sendall(b"mg testkey\r\n", False) + data = b.recv(1024) + assert data == b"mg testkey\r\n" + + def test_settimeout_sendall_with_noop(self, socket_pair): + a, b = socket_pair + a.settimeout(5.0) + ms = MemcacheSocket(a) + + ms.sendall(b"md testkey q\r\n", True) + data = b.recv(1024) + assert data == b"md testkey q\r\nmn\r\n" + + def test_settimeout_pipeline(self, socket_pair): + """Full pipeline flow with non-blocking sockets.""" + a, b = socket_pair + a.settimeout(5.0) + ms = MemcacheSocket(a) + + # Send two commands + ms.sendall(b"mg key1\r\n", False) + ms.sendall(b"mg key2\r\n", False) + + # Server responds + b.sendall(b"VA 3 f1\r\nfoo\r\nEN\r\n") + + r1 = ms.get_response() + assert isinstance(r1, Value) + assert ms.get_value(r1.size) == b"foo" + + r2 = ms.get_response() + assert isinstance(r2, Miss) + + def test_settimeout_noop(self, socket_pair): + a, b = socket_pair + a.settimeout(5.0) + ms = MemcacheSocket(a) + + ms.sendall(b"test", with_noop=True) + b.sendall(b"EX\r\nMN\r\nHD c1\r\n") + resp = ms.get_response() + assert isinstance(resp, Success) + assert resp.flags.cas_token == 1 + + +class TestSocketTimeout: + """Test that Python socket timeouts are respected by the Rust implementation.""" + + def test_get_response_timeout(self, socket_pair): + """get_response() should raise TimeoutError when no data arrives within timeout.""" + a, b = socket_pair + a.settimeout(0.1) # 100ms timeout + ms = MemcacheSocket(a) + + # Don't send any data — should timeout + with pytest.raises(TimeoutError): + ms.get_response() + + def test_get_value_timeout(self, socket_pair): + """get_value() should raise TimeoutError when value data doesn't arrive.""" + a, b = socket_pair + a.settimeout(0.1) + ms = MemcacheSocket(a) + + # Send header but not the value data + b.sendall(b"VA 100\r\n") + resp = ms.get_response() + assert isinstance(resp, Value) + assert resp.size == 100 + + # Value data never arrives — should timeout + with pytest.raises((TimeoutError, ConnectionError)): + ms.get_value(resp.size) + + def test_sendall_timeout(self, socket_pair): + """sendall() should raise TimeoutError when send buffer is full.""" + a, b = socket_pair + a.settimeout(0.1) + ms = MemcacheSocket(a) + + # Fill the send buffer until it blocks, then expect timeout. + # Use a large payload to overwhelm the socket buffer. + big_data = b"x" * (1024 * 1024 * 10) # 10MB + with pytest.raises((TimeoutError, ConnectionError)): + for _ in range(100): + ms.sendall(big_data, False) + + def test_blocking_socket_no_timeout(self, socket_pair): + """Blocking socket (no settimeout) should not have poll timeout issues.""" + a, b = socket_pair + # No settimeout — socket is blocking, timeout should be -1 (infinite) + ms = MemcacheSocket(a) + + b.sendall(b"EN\r\n") + resp = ms.get_response() + assert isinstance(resp, Miss) + + def test_set_socket_updates_timeout(self, socket_pair): + """set_socket() should pick up the new socket's timeout.""" + a, b = socket_pair + a.settimeout(0.1) + ms = MemcacheSocket(a) + + # Create a new socket pair with no timeout + c, d = socket.socketpair() + try: + ms.set_socket(c) + d.sendall(b"EN\r\n") + resp = ms.get_response() + assert isinstance(resp, Miss) + finally: + c.close() + d.close() + + +class TestVersionConstants: + def test_constants_values(self): + assert SERVER_VERSION_AWS_1_6_6 == 1 + assert SERVER_VERSION_STABLE == 2 + + def test_version_matches_intenum(self): + """ServerVersion IntEnum values match Rust constants.""" + # These must match for backward compatibility + assert SERVER_VERSION_AWS_1_6_6 == 1 + assert SERVER_VERSION_STABLE == 2 diff --git a/tests/test_response_types.py b/tests/test_response_types.py new file mode 100644 index 0000000..9719c99 --- /dev/null +++ b/tests/test_response_types.py @@ -0,0 +1,124 @@ +"""Tests for Rust response types: Value, Success, Miss, NotStored, Conflict.""" + +from meta_memcache_socket import ( + Conflict, + Miss, + NotStored, + ResponseFlags, + Success, + Value, +) + + +class TestMiss: + def test_create(self): + m = Miss() + assert repr(m) == "Miss()" + + def test_bool_is_false(self): + assert not Miss() + assert bool(Miss()) is False + + def test_equality(self): + assert Miss() == Miss() + + def test_inequality_with_other_types(self): + assert Miss() != NotStored() + assert Miss() != Conflict() + + +class TestNotStored: + def test_create(self): + ns = NotStored() + assert repr(ns) == "NotStored()" + + def test_bool_is_false(self): + assert not NotStored() + + def test_equality(self): + assert NotStored() == NotStored() + + +class TestConflict: + def test_create(self): + c = Conflict() + assert repr(c) == "Conflict()" + + def test_bool_is_false(self): + assert not Conflict() + + def test_equality(self): + assert Conflict() == Conflict() + + +class TestSuccess: + def test_create(self): + flags = ResponseFlags() + s = Success(flags) + assert s.flags == flags + + def test_create_with_flags(self): + flags = ResponseFlags(cas_token=42, stale=True) + s = Success(flags) + assert s.flags.cas_token == 42 + assert s.flags.stale is True + + def test_repr(self): + s = Success(ResponseFlags()) + assert "Success" in repr(s) + + def test_flags_immutable(self): + """Success is frozen, flags should not be settable.""" + s = Success(ResponseFlags()) + try: + s.flags = ResponseFlags(cas_token=1) + assert False, "Should have raised AttributeError" + except AttributeError: + pass + + +class TestValue: + def test_create(self): + flags = ResponseFlags(client_flag=42) + v = Value(size=100, flags=flags, value=None) + assert v.size == 100 + assert v.flags.client_flag == 42 + assert v.value is None + + def test_value_setter(self): + """Value.value must be settable (used by executor to attach deserialized data).""" + v = Value(size=5, flags=ResponseFlags(), value=None) + assert v.value is None + v.value = b"hello" + assert v.value == b"hello" + v.value = "deserialized string" + assert v.value == "deserialized string" + v.value = {"key": "val"} + assert v.value == {"key": "val"} + v.value = None + assert v.value is None + + def test_repr(self): + v = Value(size=10, flags=ResponseFlags(), value=None) + r = repr(v) + assert "Value" in r + assert "10" in r + + def test_isinstance_checks(self): + """isinstance checks must work (used extensively by executor/commands).""" + v = Value(size=1, flags=ResponseFlags(), value=None) + s = Success(ResponseFlags()) + m = Miss() + ns = NotStored() + c = Conflict() + + assert isinstance(v, Value) + assert isinstance(s, Success) + assert isinstance(m, Miss) + assert isinstance(ns, NotStored) + assert isinstance(c, Conflict) + + # Value is NOT a subclass of Success in Rust version + assert not isinstance(v, Success) + assert not isinstance(s, Value) + assert not isinstance(m, Value)