diff --git a/crates/lance-context-core/src/record.rs b/crates/lance-context-core/src/record.rs index 4e3aa49..e5729eb 100644 --- a/crates/lance-context-core/src/record.rs +++ b/crates/lance-context-core/src/record.rs @@ -14,6 +14,8 @@ pub struct StateMetadata { pub struct ContextRecord { pub id: String, pub run_id: String, + pub bot_id: Option, + pub session_id: Option, pub created_at: DateTime, pub role: String, pub state_metadata: Option, diff --git a/crates/lance-context-core/src/store.rs b/crates/lance-context-core/src/store.rs index c24da5d..7698944 100644 --- a/crates/lance-context-core/src/store.rs +++ b/crates/lance-context-core/src/store.rs @@ -404,6 +404,8 @@ impl ContextStore { Schema::new(vec![ Field::new("id", DataType::Utf8, false), Field::new("run_id", DataType::Utf8, false), + Field::new("bot_id", DataType::Utf8, true), + Field::new("session_id", DataType::Utf8, true), Field::new( "created_at", DataType::Timestamp(TimeUnit::Microsecond, None), @@ -485,6 +487,8 @@ impl ContextStore { fn records_to_batch(entries: &[ContextRecord]) -> LanceResult { let mut id_builder = StringBuilder::new(); let mut run_id_builder = StringBuilder::new(); + let mut bot_id_builder = StringBuilder::new(); + let mut session_id_builder = StringBuilder::new(); let mut created_at_builder = TimestampMicrosecondBuilder::with_capacity(entries.len()); let mut role_builder = StringDictionaryBuilder::::new(); let mut content_type_builder = StringBuilder::new(); @@ -513,6 +517,8 @@ impl ContextStore { for entry in entries { id_builder.append_value(&entry.id); run_id_builder.append_value(&entry.run_id); + bot_id_builder.append_option(entry.bot_id.as_deref()); + session_id_builder.append_option(entry.session_id.as_deref()); created_at_builder.append_value(entry.created_at.timestamp_micros()); role_builder.append(&entry.role)?; content_type_builder.append_value(&entry.content_type); @@ -593,6 +599,8 @@ impl ContextStore { let id_array: ArrayRef = Arc::new(id_builder.finish()); let run_id_array: ArrayRef = Arc::new(run_id_builder.finish()); + let bot_id_array: ArrayRef = Arc::new(bot_id_builder.finish()); + let session_id_array: ArrayRef = Arc::new(session_id_builder.finish()); let created_at_array: ArrayRef = Arc::new(created_at_builder.finish()); let role_array: ArrayRef = Arc::new(role_builder.finish()); let content_type_array: ArrayRef = Arc::new(content_type_builder.finish()); @@ -607,6 +615,8 @@ impl ContextStore { vec![ id_array, run_id_array, + bot_id_array, + session_id_array, created_at_array, role_array, state_array, @@ -664,6 +674,8 @@ fn batch_to_search_results(batch: &RecordBatch) -> LanceResult fn batch_to_records(batch: &RecordBatch) -> LanceResult> { let id_array = column_as::(batch, "id")?; let run_id_array = column_as::(batch, "run_id")?; + let bot_id_array = column_as_optional::(batch, "bot_id"); + let session_id_array = column_as_optional::(batch, "session_id"); let created_at_array = column_as::(batch, "created_at")?; let role_array = column_as::>(batch, "role")?; let state_array = column_as::(batch, "state_metadata")?; @@ -786,9 +798,27 @@ fn batch_to_records(batch: &RecordBatch) -> LanceResult> { role_values.value(key).to_string() }; + let bot_id = bot_id_array.and_then(|arr| { + if arr.is_null(row) { + None + } else { + Some(arr.value(row).to_string()) + } + }); + + let session_id = session_id_array.and_then(|arr| { + if arr.is_null(row) { + None + } else { + Some(arr.value(row).to_string()) + } + }); + results.push(ContextRecord { id: id_array.value(row).to_string(), run_id: run_id_array.value(row).to_string(), + bot_id, + session_id, created_at, role, state_metadata, @@ -836,6 +866,15 @@ where }) } +fn column_as_optional<'a, A>(batch: &'a RecordBatch, name: &str) -> Option<&'a A> +where + A: Array + 'static, +{ + batch + .column_by_name(name) + .and_then(|col| col.as_ref().as_any().downcast_ref::()) +} + #[cfg(test)] mod tests { use super::*; @@ -855,6 +894,8 @@ mod tests { ContextRecord { id: id.to_string(), run_id: format!("run-{id}"), + bot_id: None, + session_id: None, created_at: Utc::now(), role: "user".to_string(), state_metadata: Some(StateMetadata { diff --git a/python/python/lance_context/api.py b/python/python/lance_context/api.py index 484bb21..8df1d03 100644 --- a/python/python/lance_context/api.py +++ b/python/python/lance_context/api.py @@ -115,6 +115,8 @@ def _normalize_record(raw: dict[str, Any]) -> dict[str, Any]: return { "id": raw.get("id"), "run_id": raw.get("run_id"), + "bot_id": raw.get("bot_id"), + "session_id": raw.get("session_id"), "role": raw.get("role"), "content_type": raw.get("content_type"), "text": raw.get("text_payload"), @@ -237,13 +239,15 @@ def add( content_type: str | None = None, data_type: str | None = None, embedding: list[float] | None = None, + bot_id: str | None = None, + session_id: str | None = None, ) -> None: if content_type is not None and data_type is not None: raise ValueError("Specify only one of content_type or data_type") if content_type is None: content_type = data_type payload, resolved_type = _normalize_content(content, content_type) - self._inner.add(role, payload, resolved_type, embedding) + self._inner.add(role, payload, resolved_type, embedding, bot_id, session_id) def snapshot(self, label: str | None = None) -> str: return self._inner.snapshot(label) diff --git a/python/src/lib.rs b/python/src/lib.rs index 6cf765d..94dcb07 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -153,7 +153,8 @@ impl Context { self.store.version() } - #[pyo3(signature = (role, content, data_type = None, embedding = None))] + #[allow(clippy::too_many_arguments)] + #[pyo3(signature = (role, content, data_type = None, embedding = None, bot_id = None, session_id = None))] fn add( &mut self, py: Python<'_>, @@ -161,6 +162,8 @@ impl Context { content: &Bound<'_, PyAny>, data_type: Option<&str>, embedding: Option>, + bot_id: Option, + session_id: Option, ) -> PyResult<()> { let (content_type, text_payload, binary_payload, inner_content) = match content.extract::<&[u8]>() { @@ -185,6 +188,8 @@ impl Context { let record = ContextRecord { id: record_id, run_id: self.run_id.clone(), + bot_id, + session_id, created_at: Utc::now(), role: role.to_string(), state_metadata: None, @@ -345,6 +350,8 @@ fn record_to_py(py: Python<'_>, record: ContextRecord) -> PyResult { let ContextRecord { id, run_id, + bot_id, + session_id, created_at, role, state_metadata, @@ -357,6 +364,8 @@ fn record_to_py(py: Python<'_>, record: ContextRecord) -> PyResult { let dict = PyDict::new(py); dict.set_item("id", id)?; dict.set_item("run_id", run_id)?; + dict.set_item("bot_id", bot_id)?; + dict.set_item("session_id", session_id)?; dict.set_item( "created_at", created_at.to_rfc3339_opts(SecondsFormat::Micros, true), diff --git a/python/tests/test_search.py b/python/tests/test_search.py index 22c5d18..da3557e 100644 --- a/python/tests/test_search.py +++ b/python/tests/test_search.py @@ -9,10 +9,18 @@ class DummyInner: def __init__(self) -> None: self.search_calls: list[tuple[list[float], int | None]] = [] self.list_calls: list[tuple[int | None, int | None]] = [] - self.add_calls: list[tuple[str, Any, str | None, list[float] | None]] = [] - - def add(self, role: str, content: Any, data_type: str | None, embedding: list[float] | None): - self.add_calls.append((role, content, data_type, embedding)) + self.add_calls: list[tuple[str, Any, str | None, list[float] | None, str | None, str | None]] = [] + + def add( + self, + role: str, + content: Any, + data_type: str | None, + embedding: list[float] | None, + bot_id: str | None, + session_id: str | None, + ): + self.add_calls.append((role, content, data_type, embedding, bot_id, session_id)) def search(self, vector: list[float], limit: int | None): self.search_calls.append((vector, limit)) @@ -20,6 +28,8 @@ def search(self, vector: list[float], limit: int | None): { "id": "rec-1", "run_id": "run-1", + "bot_id": "support_bot", + "session_id": None, "role": "user", "content_type": "text/plain", "text_payload": "hello", @@ -37,6 +47,8 @@ def list(self, limit: int | None, offset: int | None): { "id": "rec-1", "run_id": "run-1", + "bot_id": "support_bot", + "session_id": "user_1", "role": "user", "content_type": "text/plain", "text_payload": "hello", @@ -48,6 +60,8 @@ def list(self, limit: int | None, offset: int | None): { "id": "rec-2", "run_id": "run-1", + "bot_id": None, + "session_id": None, "role": "assistant", "content_type": "text/plain", "text_payload": "world", @@ -156,11 +170,13 @@ def test_context_add_with_embedding(): ctx.add("user", "hello", embedding=embedding) assert len(dummy.add_calls) == 1 - role, content, data_type, passed_embedding = dummy.add_calls[0] + role, content, data_type, passed_embedding, bot_id, session_id = dummy.add_calls[0] assert role == "user" assert content == "hello" assert data_type is None assert passed_embedding == [0.1, 0.2, 0.3] + assert bot_id is None + assert session_id is None def test_context_add_without_embedding(): @@ -171,10 +187,12 @@ def test_context_add_without_embedding(): ctx.add("assistant", "world") assert len(dummy.add_calls) == 1 - role, content, data_type, passed_embedding = dummy.add_calls[0] + role, content, data_type, passed_embedding, bot_id, session_id = dummy.add_calls[0] assert role == "assistant" assert content == "world" assert passed_embedding is None + assert bot_id is None + assert session_id is None def test_context_add_with_content_type_and_embedding(): @@ -186,7 +204,89 @@ def test_context_add_with_content_type_and_embedding(): ctx.add("system", "prompt", content_type="text/markdown", embedding=embedding) assert len(dummy.add_calls) == 1 - role, content, data_type, passed_embedding = dummy.add_calls[0] + role, content, data_type, passed_embedding, bot_id, session_id = dummy.add_calls[0] assert role == "system" assert data_type == "text/markdown" assert passed_embedding == [0.5, 0.6] + assert bot_id is None + assert session_id is None + + +def test_context_add_with_bot_id(): + ctx = Context.__new__(Context) + dummy = DummyInner() + ctx._inner = dummy # type: ignore[attr-defined] + + ctx.add("user", "hello", bot_id="support_bot") + + assert len(dummy.add_calls) == 1 + role, content, data_type, passed_embedding, bot_id, session_id = dummy.add_calls[0] + assert role == "user" + assert content == "hello" + assert bot_id == "support_bot" + assert session_id is None + + +def test_context_add_with_session_id(): + ctx = Context.__new__(Context) + dummy = DummyInner() + ctx._inner = dummy # type: ignore[attr-defined] + + ctx.add("user", "hello", session_id="user_123") + + assert len(dummy.add_calls) == 1 + role, content, data_type, passed_embedding, bot_id, session_id = dummy.add_calls[0] + assert role == "user" + assert content == "hello" + assert bot_id is None + assert session_id == "user_123" + + +def test_context_add_with_agent_and_session_id(): + ctx = Context.__new__(Context) + dummy = DummyInner() + ctx._inner = dummy # type: ignore[attr-defined] + + ctx.add("user", "hello", bot_id="sales_bot", session_id="conv_456") + + assert len(dummy.add_calls) == 1 + role, content, data_type, passed_embedding, bot_id, session_id = dummy.add_calls[0] + assert role == "user" + assert bot_id == "sales_bot" + assert session_id == "conv_456" + + +def test_context_add_with_all_options(): + ctx = Context.__new__(Context) + dummy = DummyInner() + ctx._inner = dummy # type: ignore[attr-defined] + + embedding = [0.1, 0.2] + ctx.add("user", "hello", embedding=embedding, bot_id="bot", session_id="sess") + + assert len(dummy.add_calls) == 1 + role, content, data_type, passed_embedding, bot_id, session_id = dummy.add_calls[0] + assert role == "user" + assert passed_embedding == [0.1, 0.2] + assert bot_id == "bot" + assert session_id == "sess" + + +def test_normalize_record_with_agent_and_session_id(): + result = _normalize_record( + { + "id": "rec-1", + "created_at": "2024-01-01T00:00:00Z", + "content_type": "text/plain", + "text_payload": "hello", + "binary_payload": None, + "embedding": None, + "run_id": "run-1", + "bot_id": "support_bot", + "session_id": "user_88", + "role": "user", + "state_metadata": None, + } + ) + assert result["bot_id"] == "support_bot" + assert result["session_id"] == "user_88"