diff --git a/src/_ravnar/api/threads.py b/src/_ravnar/api/threads.py index 7e7e0d8..db61bdd 100644 --- a/src/_ravnar/api/threads.py +++ b/src/_ravnar/api/threads.py @@ -1,12 +1,10 @@ from __future__ import annotations import base64 -import uuid from collections.abc import Callable from typing import TYPE_CHECKING, Annotated, Any import ag_ui.core -import ag_ui.encoder import fastsse import pydantic from fastapi import Depends, HTTPException, Path, Query, status @@ -15,7 +13,7 @@ from _ravnar import schema from _ravnar.file_storage import FileHandler, WrappedMetadata from _ravnar.observability import traced -from _ravnar.utils import as_awaitable, now +from _ravnar.utils import as_awaitable tracer = trace.get_tracer(__name__) @@ -26,8 +24,10 @@ from . import AgentHandler ThreadsSortBy = str + RunsSortBy = str else: - ThreadsSortBy = schema.create_str_literal("created_at", "updated_at", default="created_at") + ThreadsSortBy = schema.create_str_literal("created_at", default="created_at") + RunsSortBy = schema.create_str_literal("created_at", default="created_at") def make_router( @@ -70,35 +70,66 @@ async def get_thread( @router.get("/{threadId}/messages") async def get_thread_messages( - id: Annotated[str, Path(alias="threadId")], + thread_id: Annotated[str, Path(alias="threadId")], user: schema.User = Depends(authenticated_user), # noqa: B008 ) -> list[schema.AugmentedMessage]: - thread = await database.get_thread(user_id=user.id, id=id, with_messages=True) - return pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_python( - thread.messages, from_attributes=True + _, _, messages = await database.get_thread_history(user_id=user.id, thread_id=thread_id, run_id=None) + return pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_python(messages, from_attributes=True) + + @router.get("/{threadId}/runs") + async def get_runs( + *, + user: schema.User = Depends(authenticated_user), # noqa: B008 + thread_id: Annotated[str, Path(alias="threadId")], + pagination: Annotated[schema.Pagination[RunsSortBy], Query()], + ) -> schema.Page[schema.Run]: + return schema.Page[schema.Run].model_validate( + await database.get_runs(user_id=user.id, thread_id=thread_id, pagination=pagination), + from_attributes=True, ) - @router.sse("/{threadId}/run", methods=["POST"], response_model=schema.Event, tags=["Runs"]) + @router.get("/{threadId}/runs/{runId}") + async def get_run( + *, + user: schema.User = Depends(authenticated_user), # noqa: B008 + thread_id: Annotated[str, Path(alias="threadId")], + run_id: Annotated[str, Path(alias="runId")], + ) -> schema.Run: + return schema.Run.model_validate(await database.get_run(id=run_id, user_id=user.id), from_attributes=True) + + @router.get("/{threadId}/runs/{runId}/messages") + async def get_run_messages( + *, + user: schema.User = Depends(authenticated_user), # noqa: B008 + thread_id: Annotated[str, Path(alias="threadId")], + run_id: Annotated[str, Path(alias="runId")], + ) -> list[schema.AugmentedMessage]: + _, _, messages = await database.get_thread_history(user_id=user.id, thread_id=thread_id, run_id=run_id) + return pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_python(messages, from_attributes=True) + + @router.sse("/{threadId}/runs", methods=["POST"], response_model=schema.Event, tags=["Runs"]) async def create_run( *, user: schema.User = Depends(authenticated_user), # noqa: B008 thread_id: Annotated[str, Path(alias="threadId")], data: schema.CreateRunData, ) -> fastsse.Response: - thread = await database.get_thread(user_id=user.id, id=thread_id, with_messages=True) + thread, parent_run, parent_messages = await database.get_thread_history( + user_id=user.id, thread_id=thread_id, run_id=data.parent_run_id + ) messages = pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_python( - thread.messages, from_attributes=True + parent_messages, from_attributes=True ) messages.extend(data.messages) await hydrate_files(messages, user=user, file_handler=file_handler) run_agent_input = ag_ui.core.RunAgentInput( - thread_id=thread.id, - run_id=str(uuid.uuid4()), - parent_run_id=None, - state=thread.state, + thread_id=thread_id, + run_id=data.id, + parent_run_id=data.parent_run_id, + state=parent_run.state if parent_run is not None else None, messages=[pydantic.TypeAdapter(ag_ui.core.Message).validate_python(m.model_dump()) for m in messages], tools=data.tools, context=data.context, @@ -106,10 +137,8 @@ async def create_run( ) async def callback(event_processor: EventProcessor) -> None: - with tracer.start_as_current_span("persist_run"): - thread.state, thread.messages = event_processor.extract() - thread.updated_at = now() - await database.update_thread(thread) + run = event_processor.extract(include_input_message_ids={m.id for m in data.messages}) + await database.create_run(run) return await agent_handler.run(thread.agent_id, run_agent_input, callback=callback) diff --git a/src/_ravnar/core.py b/src/_ravnar/core.py index 3475673..45270c7 100644 --- a/src/_ravnar/core.py +++ b/src/_ravnar/core.py @@ -145,13 +145,7 @@ async def run( self.assert_available(agent_id) agent = self._agents[agent_id] - event_processor = EventProcessor( - thread_id=run_agent_input.thread_id, - run_id=run_agent_input.run_id, - parent_run_id=run_agent_input.parent_run_id, - state=run_agent_input.state, - messages=run_agent_input.messages, - ) + event_processor = EventProcessor(run_agent_input=run_agent_input) span = tracer.start_span("AgentHandler.run") span.set_attribute("agent_id", agent_id) diff --git a/src/_ravnar/database.py b/src/_ravnar/database.py index 2b0fb78..4ab1d62 100644 --- a/src/_ravnar/database.py +++ b/src/_ravnar/database.py @@ -9,11 +9,11 @@ from fastapi import HTTPException, status from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor -from sqlalchemy import Engine, Select, asc, create_engine, desc, func, inspect, select +from sqlalchemy import Engine, Select, asc, create_engine, desc, func, inspect, literal_column, select from sqlalchemy.engine.url import make_url from sqlalchemy.exc import InvalidRequestError from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine -from sqlalchemy.orm import Session, selectinload, sessionmaker +from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm.interfaces import ORMOption from starlette.concurrency import run_in_threadpool from typing_extensions import TypedDict @@ -166,16 +166,13 @@ async def create_thread(self, *, user_id: str, id: str, name: str | None, agent_ if thread is not None: raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Thread exists") - created_at = updated_at = now() thread = orm.Thread( id=id, user_id=user_id, agent_id=agent_id, name=name, - created_at=created_at, - updated_at=updated_at, - state=None, - messages=[], + created_at=now(), + runs=[], ) session.add(thread) return thread @@ -202,10 +199,8 @@ async def get_threads(self, *, user_id: str, pagination: schema.Pagination) -> o async with self._get_session() as session: return await self._get_threads(session, user_id=user_id, pagination=pagination) - async def _get_thread(self, session: AsyncSession, *, user_id: str, id: str, with_messages: bool) -> orm.Thread: + async def _get_thread(self, session: AsyncSession, *, user_id: str, id: str) -> orm.Thread: query = select(orm.Thread).where((orm.Thread.id == id) & (orm.Thread.user_id == user_id)) - if with_messages: - query = query.options(selectinload(orm.Thread.messages)) result = await session.execute(query) thread = result.scalar_one_or_none() if thread is None: @@ -213,27 +208,117 @@ async def _get_thread(self, session: AsyncSession, *, user_id: str, id: str, wit return thread @traced - async def get_thread(self, *, user_id: str, id: str, with_messages: bool = False) -> orm.Thread: + async def get_thread(self, *, user_id: str, id: str) -> orm.Thread: async with self._get_session() as session: - return await self._get_thread(session, user_id=user_id, id=id, with_messages=with_messages) - - @traced - async def append_messages_to_thread(self, *, user_id: str, id: str, messages: list[orm.Message]) -> None: - async with self._get_session() as session: - thread = await self._get_thread(session, user_id=user_id, id=id, with_messages=False) - thread.messages.extend(messages) + return await self._get_thread(session, user_id=user_id, id=id) @traced async def rename_thread(self, *, user_id: str, id: str, name: str) -> orm.Thread: async with self._get_session() as session: - thread = await self._get_thread(session, user_id=user_id, id=id, with_messages=False) + thread = await self._get_thread(session, user_id=user_id, id=id) thread.name = name return thread @traced - async def update_thread(self, thread: orm.Thread) -> None: + async def create_run(self, run: orm.Run) -> None: + async with self._get_session() as session: + session.add(run) + + async def _get_run(self, session: AsyncSession, *, id: str, user_id: str) -> orm.Run: + query = ( + select(orm.Run) + .join(orm.Thread, orm.Run.thread_id == orm.Thread.id) + .where((orm.Run.id == id) & (orm.Thread.user_id == user_id)) + ) + result = await session.execute(query) + run = result.scalar_one_or_none() + if run is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Run not found") + return run + + @traced + async def get_run(self, *, id: str, user_id: str) -> orm.Run: + async with self._get_session() as session: + return await self._get_run(session, id=id, user_id=user_id) + + @traced + async def get_runs( + self, + *, + user_id: str, + thread_id: str, + pagination: schema.Pagination | None = None, + ) -> orm.Page[orm.Run]: + async with self._get_session() as session: + + def select_qualifier(query: Select) -> Select: + return query.join(orm.Thread, orm.Run.thread_id == orm.Thread.id).where( + (orm.Run.thread_id == thread_id) & (orm.Thread.user_id == user_id) + ) + + return await self._get_page( + session, orm_type=orm.Run, select_qualifier=select_qualifier, pagination=pagination + ) + + async def _get_thread_messages(self, session: AsyncSession, *, run_id: str) -> list[orm.Message]: + run_chain = ( + select( + orm.Run.id.label("run_id"), + orm.Run.parent_run_id.label("parent_run_id"), + literal_column("0").label("depth"), + ) + .where(orm.Run.id == run_id) + .cte(name="run_chain", recursive=True) + ) + run_chain = run_chain.union_all( + select(orm.Run.id, orm.Run.parent_run_id, (run_chain.c.depth + 1).label("depth")) + .select_from(orm.Run) + .join(run_chain, orm.Run.id == run_chain.c.parent_run_id) + ) + + ranked = ( + select( + orm.Message.uid, + func.row_number() + .over( + partition_by=orm.Message.id, + order_by=run_chain.c.depth.asc(), + ) + .label("rn"), + ) + .join(run_chain, orm.Message.run_id == run_chain.c.run_id) + .subquery() + ) + + query = ( + select(orm.Message) + .join(ranked, orm.Message.uid == ranked.c.uid) + .where(ranked.c.rn == 1) + .order_by(orm.Message.created_at.asc(), orm.Message.id.asc()) + ) + + result = await session.execute(query) + return list(result.unique().scalars().all()) + + @traced + async def get_thread_history( + self, *, user_id: str, thread_id: str, run_id: str | None + ) -> tuple[orm.Thread, orm.Run | None, list[orm.Message]]: async with self._get_session() as session: - await session.merge(thread) + thread = await self._get_thread(session, user_id=user_id, id=thread_id) + + if run_id is None: + if not thread.runs: + return thread, None, [] + run = thread.runs[-1] + else: + try: + run = next(r for r in thread.runs if r.id == run_id) + except StopIteration: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Parent run not found") from None + + messages = await self._get_thread_messages(session, run_id=run.id) + return thread, run, messages @traced async def delete_threads(self, *, user_id: str, ids: list[str]) -> None: diff --git a/src/_ravnar/events.py b/src/_ravnar/events.py index c219ed5..9be1cdc 100644 --- a/src/_ravnar/events.py +++ b/src/_ravnar/events.py @@ -4,7 +4,7 @@ import enum import time import uuid -from collections.abc import AsyncIterable, AsyncIterator +from collections.abc import AsyncIterable, AsyncIterator, Collection from datetime import UTC, datetime from typing import Any, TypeVar, cast @@ -74,21 +74,12 @@ class ReasoningData: class EventProcessor: - def __init__( - self, - *, - thread_id: str, - run_id: str, - parent_run_id: str | None, - state: ag_ui.core.State, - messages: list[ag_ui.core.Message], - ): - self._thread_id = thread_id - self._run_id = run_id - self._parent_run_id = parent_run_id - - self._state = state - self._messages = self._convert_messages(messages) + def __init__(self, *, run_agent_input: ag_ui.core.RunAgentInput): + self._run_agent_input = run_agent_input + + self._state = run_agent_input.state + self._input_messages = self._convert_input_messages(run_agent_input.messages) + self._messages: dict[str, orm.Message] = {} self._progress = RunProgress.NOT_STARTED self._text_message_data: dict[str, TextMessageData] = {} @@ -97,15 +88,19 @@ def __init__( self._reasoning_data: dict[str, ReasoningData] = {} self._thinking_message_id: str | None = None - self._logger = structlog.get_logger(thread_id=thread_id, run_id=run_id, parent_run_id=parent_run_id) + self._logger = structlog.get_logger( + thread_id=run_agent_input.thread_id, + run_id=run_agent_input.run_id, + parent_run_id=run_agent_input.parent_run_id, + ) + + def _convert_input_messages(self, messages: list[ag_ui.core.Message]) -> dict[str, orm.Message]: + message_uids = {m.id: uuid.uuid4() for m in messages} - def _convert_messages( - self, messages: list[ag_ui.core.Message], *, updated_at: datetime | None = None - ) -> dict[str, orm.Message]: tool_calls = { tc.id: orm.ToolCall( id=tc.id, - assistant_message_id=m.id, + assistant_message_id=message_uids[m.id], tool_message_id=None, name=tc.function.name, arguments=tc.function.arguments, @@ -138,10 +133,9 @@ def _convert_messages( text = None file_id = metadata.file_id input_contents.append( - orm.InputContent(user_message_id=m.id, index=i, text=text, file_id=file_id) + orm.InputContent(user_message_id=message_uids[m.id], index=i, text=text, file_id=file_id) ) data = {**m.model_dump(exclude={"content"}), "input_contents": input_contents} - print() case ag_ui.core.AssistantMessage(): data = { **m.model_dump(exclude={"tool_calls"}), @@ -149,15 +143,15 @@ def _convert_messages( } case ag_ui.core.ToolMessage(): tool_call = tool_calls[m.tool_call_id] - tool_call.tool_message_id = m.id + tool_call.tool_message_id = message_uids[m.id] data = {**m.model_dump(exclude={"tool_call_id"}), "tool_call": tool_call} case _: data = m.model_dump() - if updated_at is not None: - data["updated_at"] = updated_at - data["thread_id"] = self._thread_id - + data["uid"] = message_uids[m.id] + # setting the run ID to the current run here avoids the need to set it later in case the message is either + # promoted to the current run because it was mutated or it is actually an input message of the current run + data["run_id"] = self._run_agent_input.run_id converted_messages[m.id] = cls(**data) return converted_messages @@ -166,10 +160,10 @@ async def process_event_stream( self, event_stream: AsyncIterable[ag_ui.core.Event] ) -> AsyncIterator[ag_ui.core.Event]: span = tracer.start_span("EventProcessor.run") - span.set_attribute("thread_id", self._thread_id) - span.set_attribute("run_id", self._run_id) - if self._parent_run_id is not None: - span.set_attribute("parent_run_id", self._parent_run_id) + span.set_attribute("thread_id", self._run_agent_input.thread_id) + span.set_attribute("run_id", self._run_agent_input.run_id) + if self._run_agent_input.parent_run_id is not None: + span.set_attribute("parent_run_id", self._run_agent_input.parent_run_id) events = aiter(event_stream) try: @@ -223,13 +217,16 @@ def _process_event(self, event: ag_ui.core.Event) -> ag_ui.core.Event | None: return None if ( - event.thread_id != self._thread_id - or event.run_id != self._run_id - or event.parent_run_id != self._parent_run_id + event.thread_id != self._run_agent_input.thread_id + or event.run_id != self._run_agent_input.run_id + or event.parent_run_id != self._run_agent_input.parent_run_id ): lifecycle_data = event.model_dump(include={"thread_id", "run_id", "parent_run_id"}, mode="json") event = self._override_event( - event, thread_id=self._thread_id, run_id=self._run_id, parent_run_id=self._parent_run_id + event, + thread_id=self._run_agent_input.thread_id, + run_id=self._run_agent_input.run_id, + parent_run_id=self._run_agent_input.parent_run_id, ) attributes = { "state": "overridden", @@ -244,9 +241,11 @@ def _process_event(self, event: ag_ui.core.Event) -> ag_ui.core.Event | None: return self._trace_event(event, **attributes) case ag_ui.core.RunFinishedEvent(): - if event.thread_id != self._thread_id or event.run_id != self._run_id: + if event.thread_id != self._run_agent_input.thread_id or event.run_id != self._run_agent_input.run_id: lifecycle_data = event.model_dump(include={"thread_id", "run_id"}, mode="json") - event = self._override_event(event, thread_id=self._thread_id, run_id=self._run_id) + event = self._override_event( + event, thread_id=self._run_agent_input.thread_id, run_id=self._run_agent_input.run_id + ) self._progress = RunProgress.FINISHED return self._trace_event( event, @@ -391,7 +390,9 @@ def _process_event(self, event: ag_ui.core.Event) -> ag_ui.core.Event | None: # return ag_ui.core.MessagesSnapshotEvent(messages=messages, timestamp=event.timestamp, raw_event=event) # activity events case ag_ui.core.ActivitySnapshotEvent(): - if event.message_id in self._messages and not event.replace: + if not event.replace and ( + event.message_id in self._messages or event.message_id in self._input_messages + ): self._trace_event( event, state="dropped", @@ -402,7 +403,7 @@ def _process_event(self, event: ag_ui.core.Event) -> ag_ui.core.Event | None: return None self._messages[event.message_id] = orm.ActivityMessage( id=event.message_id, - thread_id=self._thread_id, + run_id=self._run_agent_input.run_id, created_at=parse_timestamp(event.timestamp), content=event.content, activity_type=event.activity_type, @@ -410,6 +411,8 @@ def _process_event(self, event: ag_ui.core.Event) -> ag_ui.core.Event | None: return self._trace_event(event) case ag_ui.core.ActivityDeltaEvent(): message = self._messages.get(event.message_id) + if message is None: + message = self._input_messages.pop(event.message_id, None) if message is None: self._trace_event( event, @@ -418,6 +421,7 @@ def _process_event(self, event: ag_ui.core.Event) -> ag_ui.core.Event | None: message_id=event.message_id, ) return None + if not isinstance(message, orm.ActivityMessage): self._trace_event( event, @@ -438,7 +442,6 @@ def _process_event(self, event: ag_ui.core.Event) -> ag_ui.core.Event | None: self._trace_event(event, state="dropped", reason="invalid JSON patches") return None - message.updated_at = parse_timestamp(event.timestamp) message.content = content return self._trace_event( event, @@ -549,13 +552,17 @@ def _apply_jsonpatch(document: dict[str, Any], patches: list[Any]) -> Any: except jsonpatch.JsonPatchException: return None - def extract(self) -> tuple[orm.State, list[orm.Message]]: - return self._state, self._extract_messages() - - def _extract_messages(self) -> list[orm.Message]: - tool_calls: dict[str, orm.ToolCall] = {} - tool_calls_created_at: dict[str, datetime] = {} - grouped_tool_calls: dict[str, list[orm.ToolCall]] = {} + def extract(self, *, include_input_message_ids: Collection[str]) -> orm.Run: + return orm.Run( + id=self._run_agent_input.run_id, + thread_id=self._run_agent_input.thread_id, + parent_run_id=self._run_agent_input.parent_run_id, + state=self._state, + messages=self._extract_messages(include_input_message_ids), + ) + + def _extract_messages(self, include_input_message_ids: Collection[str]) -> list[orm.Message]: + grouped_tool_calls: dict[str, list[ToolCallData]] = {} for tcd in self._tool_call_data.values(): if not tcd.finished: span = trace.get_current_span() @@ -576,21 +583,10 @@ def _extract_messages(self) -> list[orm.Message]: parent_message_id=tcd.parent_message_id, ) continue + grouped_tool_calls.setdefault(tcd.parent_message_id, []).append(tcd) - tool_call = orm.ToolCall( - id=tcd.tool_call_id, - assistant_message_id=tcd.parent_message_id, - tool_message_id=None, - name=tcd.tool_call_name, - arguments="".join(tcd.arguments_delta), - # FIXME - encrypted_value=None, - ) - tool_calls[tool_call.id] = tool_call - tool_calls_created_at[tool_call.id] = tcd.created_at - grouped_tool_calls.setdefault(tool_call.assistant_message_id, []).append(tool_call) - - messages: list[orm.Message] = list(self._messages.values()) + # Build assistant messages so we have their UUIDs for tool call FKs + assistant_messages: dict[str, orm.AssistantMessage] = {} for tmd in self._text_message_data.values(): if not tmd.finished: @@ -602,26 +598,46 @@ def _extract_messages(self) -> list[orm.Message]: self._logger.warn("text message", state="dropped", reason="unfinished", message_id=tmd.message_id) continue - messages.append( - orm.AssistantMessage( - id=tmd.message_id, - thread_id=self._thread_id, - created_at=tmd.created_at, - content="".join(tmd.content_deltas) or None, - tool_calls=grouped_tool_calls.pop(tmd.message_id, []), - ) + assistant_messages[tmd.message_id] = orm.AssistantMessage( + uid=uuid.uuid4(), + run_id=self._run_agent_input.run_id, + id=tmd.message_id, + created_at=tmd.created_at, + content="".join(tmd.content_deltas) or None, + tool_calls=[], ) - for mid, tcs in grouped_tool_calls.items(): - messages.append( - orm.AssistantMessage( - id=mid, - thread_id=self._thread_id, - created_at=min(tool_calls_created_at[tc.id] for tc in tcs), + for parent_message_id, tcds in grouped_tool_calls.items(): + if parent_message_id not in assistant_messages: + assistant_messages[parent_message_id] = orm.AssistantMessage( + uid=uuid.uuid4(), + run_id=self._run_agent_input.run_id, + id=parent_message_id, + created_at=min(tcd.created_at for tcd in tcds), content=None, - tool_calls=tcs, + tool_calls=[], ) - ) + + tool_calls: dict[str, orm.ToolCall] = {} + for parent_message_id, tcds in grouped_tool_calls.items(): + parent_msg = assistant_messages[parent_message_id] + for tcd in tcds: + tool_call = orm.ToolCall( + id=tcd.tool_call_id, + assistant_message_id=parent_msg.uid, + tool_message_id=None, + name=tcd.tool_call_name, + arguments="".join(tcd.arguments_delta), + encrypted_value=None, + ) + tool_calls[tool_call.id] = tool_call + parent_msg.tool_calls.append(tool_call) + + messages: list[orm.Message] = [ + *(m for id, m in self._input_messages.items() if id in include_input_message_ids), + *self._messages.values(), + *assistant_messages.values(), + ] for rd in self._reasoning_data.values(): if not rd.finished: @@ -635,8 +651,9 @@ def _extract_messages(self) -> list[orm.Message]: messages.append( orm.ReasoningMessage( + uid=uuid.uuid4(), + run_id=self._run_agent_input.run_id, id=rd.message_id, - thread_id=self._thread_id, created_at=rd.created_at, content="".join(rd.content_deltas), ) @@ -657,14 +674,17 @@ def _extract_messages(self) -> list[orm.Message]: tool_call_id=trd.tool_call_id, ) continue + msg_uid = uuid.uuid4() + tool_call = tool_calls[trd.tool_call_id] + tool_call.tool_message_id = msg_uid messages.append( orm.ToolMessage( + uid=msg_uid, + run_id=self._run_agent_input.run_id, id=trd.message_id, - thread_id=self._thread_id, created_at=trd.created_at, content=trd.content, - tool_call=tool_calls[trd.tool_call_id], - # FIXME + tool_call=tool_call, error=None, encrypted_value=None, ) diff --git a/src/_ravnar/orm.py b/src/_ravnar/orm.py index 503e87f..358235e 100644 --- a/src/_ravnar/orm.py +++ b/src/_ravnar/orm.py @@ -111,12 +111,31 @@ class Thread(Base, kw_only=True, repr=False): name: Mapped[str | None] created_at: Mapped[datetime] = mapped_column(UtcAwareDateTime) - updated_at: Mapped[datetime] = mapped_column(UtcAwareDateTime) - state: Mapped[State] = mapped_column(Json, nullable=True) + runs: Mapped[list[Run]] = relationship( + "Run", + back_populates="thread", + cascade="all, delete-orphan", + order_by="Run.created_at.asc()", + lazy="selectin", + ) + + +class Run(Base, kw_only=True, repr=False): + __tablename__ = "runs" + + id: Mapped[str] = mapped_column(primary_key=True) + thread_id: Mapped[str] = mapped_column(ForeignKey("threads.id", ondelete="CASCADE"), index=True) + thread: Mapped[Thread] = relationship("Thread", back_populates="runs", init=False) + parent_run_id: Mapped[str | None] = mapped_column( + ForeignKey("runs.id", ondelete="CASCADE"), index=True, default=None + ) + created_at: Mapped[datetime] = mapped_column(UtcAwareDateTime, default_factory=now) + state: Mapped[State] = mapped_column(Json, nullable=True, default=None) + messages: Mapped[list[Message]] = relationship( "Message", - back_populates="thread", + back_populates="run", cascade="all, delete-orphan", order_by="[Message.created_at.asc(), Message.id]", ) @@ -125,13 +144,12 @@ class Thread(Base, kw_only=True, repr=False): class Message(Base, kw_only=True, repr=False): __tablename__ = "messages" - id: Mapped[str] = mapped_column(primary_key=True) - - thread_id: Mapped[str] = mapped_column(ForeignKey("threads.id"), index=True) - thread: Mapped[Thread] = relationship(init=False) + uid: Mapped[uuid.UUID] = mapped_column(types.Uuid, primary_key=True, default_factory=uuid.uuid4) + run_id: Mapped[str] = mapped_column(ForeignKey("runs.id", ondelete="CASCADE"), index=True) + run: Mapped[Run] = relationship("Run", back_populates="messages", init=False) + id: Mapped[str] = mapped_column(index=True) created_at: Mapped[datetime] = mapped_column(UtcAwareDateTime) - updated_at: Mapped[datetime | None] = mapped_column(UtcAwareDateTime, default=None) role: Mapped[ag_ui.core.Role] = mapped_column( types.Enum(*get_args(ag_ui.core.Role), name="message_role", native_enum=False) @@ -220,7 +238,7 @@ class ReasoningMessage(Message, kw_only=True, repr=False): class InputContent(Base, kw_only=True, repr=False): __tablename__ = "input_contents" - user_message_id: Mapped[str] = mapped_column(ForeignKey("messages.id"), primary_key=True) + user_message_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("messages.uid", ondelete="CASCADE"), primary_key=True) user_message: Mapped[UserMessage] = relationship("UserMessage", init=False, back_populates="input_contents") index: Mapped[int] = mapped_column(primary_key=True) @@ -235,12 +253,14 @@ class ToolCall(Base, kw_only=True, repr=False): id: Mapped[str] = mapped_column(primary_key=True) - assistant_message_id: Mapped[str] = mapped_column(ForeignKey("messages.id", ondelete="CASCADE"), index=True) + assistant_message_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("messages.uid", ondelete="CASCADE"), index=True) assistant_message: Mapped[AssistantMessage] = relationship( init=False, back_populates="tool_calls", foreign_keys=[assistant_message_id] ) - tool_message_id: Mapped[str | None] = mapped_column(ForeignKey("messages.id", ondelete="CASCADE"), index=True) + tool_message_id: Mapped[uuid.UUID | None] = mapped_column( + ForeignKey("messages.uid", ondelete="CASCADE"), index=True + ) tool_message: Mapped[ToolMessage | None] = relationship( init=False, back_populates="tool_call", foreign_keys=[tool_message_id] ) diff --git a/src/_ravnar/schema/__init__.py b/src/_ravnar/schema/__init__.py index e758a17..12f2998 100644 --- a/src/_ravnar/schema/__init__.py +++ b/src/_ravnar/schema/__init__.py @@ -19,6 +19,7 @@ "Pagination", "QuickPrompt", "RenameThreadData", + "Run", "TModel", "Thread", "User", @@ -42,6 +43,7 @@ Event, QuickPrompt, RenameThreadData, + Run, Thread, ) from .misc import APIRouter, BaseModel, Page, Pagination, TModel, User, create_str_literal diff --git a/src/_ravnar/schema/api.py b/src/_ravnar/schema/api.py index 38a5d66..e37805d 100644 --- a/src/_ravnar/schema/api.py +++ b/src/_ravnar/schema/api.py @@ -17,6 +17,7 @@ "Event", "QuickPrompt", "RenameThreadData", + "Run", "Thread", ] @@ -54,13 +55,11 @@ class Thread(BaseModel): name: str | None = None agent_id: str created_at: datetime - updated_at: datetime class AugmentedMessageMixin(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) created_at: datetime = Field(default_factory=now) - updated_at: datetime | None = None @classmethod def _convert_orm_tool_call(cls, tool_call: orm.ToolCall) -> ag_ui.core.ToolCall: @@ -168,7 +167,16 @@ class CreateThreadData(BaseModel): agent_id: str +class Run(BaseModel): + id: str + thread_id: str + parent_run_id: str | None = None + created_at: datetime + + class CreateRunData(BaseModel): + id: str = Field(default_factory=lambda: str(uuid.uuid4())) + parent_run_id: str | None = None messages: list[AugmentedUserMessage | AugmentedToolMessage] tools: list[ag_ui.core.Tool] = Field(default_factory=list) context: list[ag_ui.core.Context] = Field(default_factory=list) diff --git a/tests/api/test_threads.py b/tests/api/test_threads.py index 956faaa..538fa07 100644 --- a/tests/api/test_threads.py +++ b/tests/api/test_threads.py @@ -30,8 +30,6 @@ def test_create_thread(self, app_client, id, name): assert thread.id == id assert thread.name == name assert thread.agent_id == agent_id - assert thread.updated_at == thread.created_at - expected = thread response = app_client.get(f"/api/threads/{thread.id}").raise_for_status() actual = schema.Thread.model_validate_json(response.content) @@ -263,7 +261,7 @@ def create_run(self, client, *, thread_id, data, **kwargs): with httpx_sse.connect_sse( client, "POST", - f"/api/threads/{thread_id}/run", + f"/api/threads/{thread_id}/runs", json=schema.CreateRunData.model_validate(data).model_dump(mode="json", by_alias=True, exclude_unset=True), **kwargs, ) as event_source: @@ -330,3 +328,91 @@ def test_files_smoke(self, app_client): list(event_stream) app_client.get(f"/api/threads/{thread_id}/messages").raise_for_status() + + def test_run_crud(self, app_client): + thread_id = self.create_thread(app_client).id + run_id = "run-1" + + event_stream = self.create_run( + app_client, + thread_id=thread_id, + data={"id": run_id, "messages": [{"role": "user", "content": [{"type": "text", "text": "hello"}]}]}, + ) + list(event_stream) + + response = app_client.get(f"/api/threads/{thread_id}/runs").raise_for_status() + runs_page = schema.Page[schema.Run].model_validate_json(response.content) + assert len(runs_page.items) == 1 + assert runs_page.items[0].id == run_id + + response = app_client.get(f"/api/threads/{thread_id}/runs/{run_id}").raise_for_status() + run = schema.Run.model_validate_json(response.content) + assert run.id == run_id + assert run.thread_id == thread_id + assert run.parent_run_id is None + + response = app_client.get(f"/api/threads/{thread_id}/runs/{run_id}/messages").raise_for_status() + messages = pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_json(response.content) + assert len(messages) == 2 # user message + assistant response + + response = app_client.get(f"/api/threads/{thread_id}/messages").raise_for_status() + thread_messages = pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_json(response.content) + assert len(thread_messages) == 2 + + def test_child_run(self, app_client): + thread_id = self.create_thread(app_client).id + run1_id = "run-1" + run2_id = "run-2" + + event_stream = self.create_run( + app_client, + thread_id=thread_id, + data={"id": run1_id, "messages": [{"role": "user", "content": [{"type": "text", "text": "hello"}]}]}, + ) + list(event_stream) + + event_stream = self.create_run( + app_client, + thread_id=thread_id, + data={ + "id": run2_id, + "parentRunId": run1_id, + "messages": [{"role": "user", "content": [{"type": "text", "text": "follow up"}]}], + }, + ) + list(event_stream) + + response = app_client.get(f"/api/threads/{thread_id}/runs/{run2_id}").raise_for_status() + run2 = schema.Run.model_validate_json(response.content) + assert run2.parent_run_id == run1_id + + response = app_client.get(f"/api/threads/{thread_id}/runs/{run2_id}/messages").raise_for_status() + messages = pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_json(response.content) + assert len(messages) == 4 # run1 user+assistant, run2 user+assistant + + response = app_client.get(f"/api/threads/{thread_id}/runs/{run1_id}/messages").raise_for_status() + messages1 = pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_json(response.content) + assert len(messages1) == 2 # run1 user+assistant + + def test_invalid_parent_run_id(self, app_client): + thread1_id = self.create_thread(app_client).id + thread2_id = self.create_thread(app_client).id + run_id = "run-1" + + event_stream = self.create_run( + app_client, + thread_id=thread1_id, + data={"id": run_id, "messages": [{"role": "user", "content": [{"type": "text", "text": "hello"}]}]}, + ) + list(event_stream) + + with httpx_sse.connect_sse( + app_client, + "POST", + f"/api/threads/{thread2_id}/runs", + json=schema.CreateRunData( + messages=[{"role": "user", "content": [{"type": "text", "text": "hello"}]}], + parent_run_id=run_id, + ).model_dump(mode="json", by_alias=True), + ) as event_source: + assert event_source.response.status_code == status.HTTP_404_NOT_FOUND diff --git a/tests/test_events.py b/tests/test_events.py index 4f6dc98..bc16e4d 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -24,13 +24,17 @@ def assert_equal(self, actual, expected): @pytest_cases.parametrize_with_cases("test_case", cases=test_events_cases.EventProcessingCases) async def test_event_processing(self, test_case: test_events_cases.EventProcessingCase): - event_processor = EventProcessor( + run_agent_input = ag_ui.core.RunAgentInput( thread_id=test_case.thread_id, run_id=test_case.run_id, parent_run_id=test_case.parent_run_id, state=test_case.state, messages=test_case.messages, + tools=[], + context=[], + forwarded_props=None, ) + event_processor = EventProcessor(run_agent_input=run_agent_input) input = test_case.input if test_case.handle_run_lifecycle_events: @@ -52,10 +56,11 @@ async def test_event_processing(self, test_case: test_events_cases.EventProcessi self.assert_equal(actual_event_stream, test_case.expected_event_stream) - actual_state, actual_orm_messages = event_processor.extract() + actual_run = event_processor.extract(include_input_message_ids={m.id for m in test_case.messages}) actual_messages = pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_python( - actual_orm_messages, from_attributes=True + actual_run.messages, from_attributes=True ) + actual_state = actual_run.state compyre.assert_equal(actual_state, test_case.expected_state) self.assert_equal(actual_messages, test_case.expected_messages) diff --git a/tests/test_events_cases.py b/tests/test_events_cases.py index efe9796..8eab489 100644 --- a/tests/test_events_cases.py +++ b/tests/test_events_cases.py @@ -136,7 +136,6 @@ def case_activity_message_delta(self): activity_type=activity_type, content={"baz": "boo", "hello": ["world"]}, created_at=parse_timestamp(snapshot_timestamp), - updated_at=parse_timestamp(last_patch_timestamp), ) ], )