diff --git a/python/python/lance_context/api.py b/python/python/lance_context/api.py index 961b21c..484bb21 100644 --- a/python/python/lance_context/api.py +++ b/python/python/lance_context/api.py @@ -236,13 +236,14 @@ def add( content: Any, content_type: str | None = None, data_type: str | None = None, + embedding: list[float] | None = None, ) -> None: if content_type is not None and data_type is not None: raise ValueError("Specify only one of content_type or data_type") if content_type is None: content_type = data_type payload, resolved_type = _normalize_content(content, content_type) - self._inner.add(role, payload, resolved_type) + self._inner.add(role, payload, resolved_type, embedding) def snapshot(self, label: str | None = None) -> str: return self._inner.snapshot(label) diff --git a/python/src/lib.rs b/python/src/lib.rs index c87972e..6cf765d 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -153,13 +153,14 @@ impl Context { self.store.version() } - #[pyo3(signature = (role, content, data_type = None))] + #[pyo3(signature = (role, content, data_type = None, embedding = None))] fn add( &mut self, py: Python<'_>, role: &str, content: &Bound<'_, PyAny>, data_type: Option<&str>, + embedding: Option>, ) -> PyResult<()> { let (content_type, text_payload, binary_payload, inner_content) = match content.extract::<&[u8]>() { @@ -190,7 +191,7 @@ impl Context { content_type, text_payload, binary_payload, - embedding: None, + embedding, }; let add_res = py.allow_threads(|| { diff --git a/python/tests/test_search.py b/python/tests/test_search.py index 2bd8b8b..22c5d18 100644 --- a/python/tests/test_search.py +++ b/python/tests/test_search.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import Any import pytest from lance_context.api import Context, _coerce_vector, _normalize_record, _normalize_search_hit @@ -8,6 +9,10 @@ class DummyInner: def __init__(self) -> None: self.search_calls: list[tuple[list[float], int | None]] = [] self.list_calls: list[tuple[int | None, int | None]] = [] + self.add_calls: list[tuple[str, Any, str | None, list[float] | None]] = [] + + def add(self, role: str, content: Any, data_type: str | None, embedding: list[float] | None): + self.add_calls.append((role, content, data_type, embedding)) def search(self, vector: list[float], limit: int | None): self.search_calls.append((vector, limit)) @@ -140,3 +145,48 @@ def test_context_list_default_args(): ctx.list() assert dummy.list_calls == [(None, None)] + + +def test_context_add_with_embedding(): + ctx = Context.__new__(Context) + dummy = DummyInner() + ctx._inner = dummy # type: ignore[attr-defined] + + embedding = [0.1, 0.2, 0.3] + ctx.add("user", "hello", embedding=embedding) + + assert len(dummy.add_calls) == 1 + role, content, data_type, passed_embedding = dummy.add_calls[0] + assert role == "user" + assert content == "hello" + assert data_type is None + assert passed_embedding == [0.1, 0.2, 0.3] + + +def test_context_add_without_embedding(): + ctx = Context.__new__(Context) + dummy = DummyInner() + ctx._inner = dummy # type: ignore[attr-defined] + + ctx.add("assistant", "world") + + assert len(dummy.add_calls) == 1 + role, content, data_type, passed_embedding = dummy.add_calls[0] + assert role == "assistant" + assert content == "world" + assert passed_embedding is None + + +def test_context_add_with_content_type_and_embedding(): + ctx = Context.__new__(Context) + dummy = DummyInner() + ctx._inner = dummy # type: ignore[attr-defined] + + embedding = [0.5, 0.6] + ctx.add("system", "prompt", content_type="text/markdown", embedding=embedding) + + assert len(dummy.add_calls) == 1 + role, content, data_type, passed_embedding = dummy.add_calls[0] + assert role == "system" + assert data_type == "text/markdown" + assert passed_embedding == [0.5, 0.6]