From d6f297be8dc2da92afce59a714406287d05fec73 Mon Sep 17 00:00:00 2001 From: Allen Date: Wed, 21 Jan 2026 09:22:08 -0800 Subject: [PATCH] feat: add list() method to retrieve stored entries Adds list(limit, offset) API across all layers: - ContextStore::list() in Rust core - PyO3 bindings in lib.rs - Python Context.list() wrapper Refactors existing code to share record conversion logic: - batch_to_records extracted from batch_to_search_results - record_to_py extracted from search_hit_to_py - _normalize_record extracted from _normalize_search_hit Includes unit tests for the new functionality. --- crates/lance-context-core/src/store.rs | 60 ++++++++++++++----- python/python/lance_context/api.py | 27 ++++++++- python/src/lib.rs | 25 +++++++- python/tests/test_search.py | 83 ++++++++++++++++++++++++-- 4 files changed, 173 insertions(+), 22 deletions(-) diff --git a/crates/lance-context-core/src/store.rs b/crates/lance-context-core/src/store.rs index 381ba8e..bb5a8fa 100644 --- a/crates/lance-context-core/src/store.rs +++ b/crates/lance-context-core/src/store.rs @@ -91,6 +91,27 @@ impl ContextStore { Ok(()) } + /// List all records in the dataset. + pub async fn list( + &self, + limit: Option, + offset: Option, + ) -> LanceResult> { + let mut scanner = self.dataset.scan(); + if let Some(limit) = limit { + scanner.limit(Some(limit as i64), offset.map(|o| o as i64))?; + } else if let Some(offset) = offset { + scanner.limit(None, Some(offset as i64))?; + } + + let mut stream = scanner.try_into_stream().await?; + let mut results = Vec::new(); + while let Some(batch) = stream.try_next().await? { + results.extend(batch_to_records(&batch)?); + } + Ok(results) + } + /// Perform a nearest-neighbor search over stored embeddings. pub async fn search( &self, @@ -348,15 +369,7 @@ impl ContextStore { } fn batch_to_search_results(batch: &RecordBatch) -> LanceResult> { - let id_array = column_as::(batch, "id")?; - let run_id_array = column_as::(batch, "run_id")?; - let created_at_array = column_as::(batch, "created_at")?; - let role_array = column_as::>(batch, "role")?; - let state_array = column_as::(batch, "state_metadata")?; - let content_type_array = column_as::(batch, "content_type")?; - let text_array = column_as::(batch, "text_payload")?; - let binary_array = column_as::(batch, "binary_payload")?; - let embedding_array = column_as::(batch, "embedding")?; + let records = batch_to_records(batch)?; let distance_column = batch.column_by_name("_distance").ok_or_else(|| { LanceError::from(ArrowError::InvalidArgumentError( @@ -373,6 +386,28 @@ fn batch_to_search_results(batch: &RecordBatch) -> LanceResult )) })?; + Ok(records + .into_iter() + .enumerate() + .map(|(i, record)| SearchResult { + record, + distance: distance_array.value(i), + }) + .collect()) +} + +/// Convert a record batch to context records. +fn batch_to_records(batch: &RecordBatch) -> LanceResult> { + let id_array = column_as::(batch, "id")?; + let run_id_array = column_as::(batch, "run_id")?; + let created_at_array = column_as::(batch, "created_at")?; + let role_array = column_as::>(batch, "role")?; + let state_array = column_as::(batch, "state_metadata")?; + let content_type_array = column_as::(batch, "content_type")?; + let text_array = column_as::(batch, "text_payload")?; + let binary_array = column_as::(batch, "binary_payload")?; + let embedding_array = column_as::(batch, "embedding")?; + let step_array = state_array .column(0) .as_ref() @@ -487,7 +522,7 @@ fn batch_to_search_results(batch: &RecordBatch) -> LanceResult role_values.value(key).to_string() }; - let record = ContextRecord { + results.push(ContextRecord { id: id_array.value(row).to_string(), run_id: run_id_array.value(row).to_string(), created_at, @@ -497,11 +532,6 @@ fn batch_to_search_results(batch: &RecordBatch) -> LanceResult text_payload, binary_payload, embedding, - }; - - results.push(SearchResult { - record, - distance: distance_array.value(row), }); } diff --git a/python/python/lance_context/api.py b/python/python/lance_context/api.py index 9f33d5e..2b74833 100644 --- a/python/python/lance_context/api.py +++ b/python/python/lance_context/api.py @@ -107,7 +107,8 @@ def _coerce_vector(query: Any) -> list[float]: raise TypeError("search query must be a sequence of floats") -def _normalize_search_hit(raw: dict[str, Any]) -> dict[str, Any]: +def _normalize_record(raw: dict[str, Any]) -> dict[str, Any]: + """Normalize a raw record dict from the Rust layer.""" created_at = raw.get("created_at") if isinstance(created_at, str): created_at = datetime.fromisoformat(created_at.replace("Z", "+00:00")) @@ -119,12 +120,18 @@ def _normalize_search_hit(raw: dict[str, Any]) -> dict[str, Any]: "text": raw.get("text_payload"), "binary": raw.get("binary_payload"), "embedding": raw.get("embedding"), - "distance": raw.get("distance"), "created_at": created_at, "state_metadata": raw.get("state_metadata"), } +def _normalize_search_hit(raw: dict[str, Any]) -> dict[str, Any]: + """Normalize a search hit - adds distance to the base record.""" + result = _normalize_record(raw) + result["distance"] = raw.get("distance") + return result + + class Context: def __init__( self, @@ -222,6 +229,22 @@ def search(self, query: Any, limit: int | None = None) -> list[dict[str, Any]]: results = self._inner.search(vector, limit) return [_normalize_search_hit(item) for item in results] + def list( + self, limit: int | None = None, offset: int | None = None + ) -> list[dict[str, Any]]: + """Return stored entries. + + Args: + limit: Maximum number of entries to return. If None, returns all. + offset: Number of entries to skip before returning results. + + Returns: + List of entry dicts with keys: id, run_id, role, content_type, + text, binary, embedding, created_at, state_metadata. + """ + results = self._inner.list(limit, offset) + return [_normalize_record(item) for item in results] + def __repr__(self) -> str: return ( f"Context(uri={self._inner.uri()!r}, " diff --git a/python/src/lib.rs b/python/src/lib.rs index 30d8e44..64295e9 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -191,6 +191,23 @@ impl Context { .map(|hit| search_hit_to_py(py, hit)) .collect() } + + #[pyo3(signature = (limit = None, offset = None))] + fn list( + &self, + py: Python<'_>, + limit: Option, + offset: Option, + ) -> PyResult> { + let records = self + .runtime + .block_on(self.store.list(limit, offset)) + .map_err(to_py_err)?; + records + .into_iter() + .map(|record| record_to_py(py, record)) + .collect() + } } fn new_run_id() -> String { @@ -203,6 +220,13 @@ fn new_run_id() -> String { fn search_hit_to_py(py: Python<'_>, hit: SearchResult) -> PyResult { let SearchResult { record, distance } = hit; + let dict = record_to_py(py, record)?; + let dict_ref = dict.downcast_bound::(py)?; + dict_ref.set_item("distance", distance)?; + Ok(dict) +} + +fn record_to_py(py: Python<'_>, record: ContextRecord) -> PyResult { let ContextRecord { id, run_id, @@ -243,7 +267,6 @@ fn search_hit_to_py(py: Python<'_>, hit: SearchResult) -> PyResult { None => dict.set_item("binary_payload", py.None())?, } dict.set_item("embedding", embedding)?; - dict.set_item("distance", distance)?; Ok(dict.into_pyobject(py)?.unbind().into()) } diff --git a/python/tests/test_search.py b/python/tests/test_search.py index 4e152a1..2bd8b8b 100644 --- a/python/tests/test_search.py +++ b/python/tests/test_search.py @@ -1,15 +1,16 @@ from datetime import datetime import pytest -from lance_context.api import Context, _coerce_vector, _normalize_search_hit +from lance_context.api import Context, _coerce_vector, _normalize_record, _normalize_search_hit class DummyInner: def __init__(self) -> None: - self.calls: list[tuple[list[float], int | None]] = [] + self.search_calls: list[tuple[list[float], int | None]] = [] + self.list_calls: list[tuple[int | None, int | None]] = [] def search(self, vector: list[float], limit: int | None): - self.calls.append((vector, limit)) + self.search_calls.append((vector, limit)) return [ { "id": "rec-1", @@ -25,6 +26,33 @@ def search(self, vector: list[float], limit: int | None): } ] + def list(self, limit: int | None, offset: int | None): + self.list_calls.append((limit, offset)) + return [ + { + "id": "rec-1", + "run_id": "run-1", + "role": "user", + "content_type": "text/plain", + "text_payload": "hello", + "binary_payload": None, + "embedding": [0.1, 0.2], + "created_at": "2024-01-01T12:00:00Z", + "state_metadata": {"step": 1}, + }, + { + "id": "rec-2", + "run_id": "run-1", + "role": "assistant", + "content_type": "text/plain", + "text_payload": "world", + "binary_payload": None, + "embedding": None, + "created_at": "2024-01-02T12:00:00Z", + "state_metadata": None, + }, + ] + def test_coerce_vector_from_list(): assert _coerce_vector([1, 2.5]) == [1.0, 2.5] @@ -60,8 +88,55 @@ def test_context_search_formats_results(): hits = ctx.search([0.5, 0.4], limit=3) - assert dummy.calls == [([0.5, 0.4], 3)] + assert dummy.search_calls == [([0.5, 0.4], 3)] assert hits[0]["id"] == "rec-1" assert hits[0]["text"] == "hello" assert hits[0]["binary"] is None assert isinstance(hits[0]["created_at"], datetime) + + +def test_normalize_record_without_distance(): + result = _normalize_record( + { + "id": "rec-1", + "created_at": "2024-01-01T00:00:00Z", + "content_type": "text/plain", + "text_payload": "hello", + "binary_payload": None, + "embedding": None, + "run_id": "run-1", + "role": "user", + "state_metadata": None, + } + ) + assert "distance" not in result + assert result["text"] == "hello" + assert isinstance(result["created_at"], datetime) + + +def test_context_list_returns_entries(): + ctx = Context.__new__(Context) + dummy = DummyInner() + ctx._inner = dummy # type: ignore[attr-defined] + + entries = ctx.list(limit=10, offset=5) + + assert dummy.list_calls == [(10, 5)] + assert len(entries) == 2 + assert entries[0]["id"] == "rec-1" + assert entries[0]["text"] == "hello" + assert entries[0]["role"] == "user" + assert "distance" not in entries[0] + assert entries[1]["id"] == "rec-2" + assert entries[1]["text"] == "world" + assert isinstance(entries[0]["created_at"], datetime) + + +def test_context_list_default_args(): + ctx = Context.__new__(Context) + dummy = DummyInner() + ctx._inner = dummy # type: ignore[attr-defined] + + ctx.list() + + assert dummy.list_calls == [(None, None)]