From f798bf0f132c1dc1b6433507de3df7bb943d6609 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 13 May 2026 11:38:06 +0200 Subject: [PATCH 1/7] make runs persistent --- src/_ravnar/api/threads.py | 95 ++++++++++++++---- src/_ravnar/core.py | 8 +- src/_ravnar/database.py | 111 +++++++++++++++++---- src/_ravnar/events.py | 174 +++++++++++++++++++-------------- src/_ravnar/orm.py | 41 +++++--- src/_ravnar/schema/__init__.py | 2 + src/_ravnar/schema/api.py | 12 ++- tests/api/test_threads.py | 100 ++++++++++++++++++- tests/test_events.py | 11 ++- tests/test_events_cases.py | 1 - 10 files changed, 417 insertions(+), 138 deletions(-) diff --git a/src/_ravnar/api/threads.py b/src/_ravnar/api/threads.py index d74fcc1..140cfc9 100644 --- a/src/_ravnar/api/threads.py +++ b/src/_ravnar/api/threads.py @@ -1,19 +1,17 @@ from __future__ import annotations import base64 -import uuid from collections.abc import Callable from typing import TYPE_CHECKING, Annotated, Any import ag_ui.core -import ag_ui.encoder import fastsse import pydantic from fastapi import Depends, HTTPException, Path, Query, status from _ravnar import schema from _ravnar.file_storage import FileHandler, WrappedMetadata -from _ravnar.utils import as_awaitable, now +from _ravnar.utils import as_awaitable if TYPE_CHECKING: from _ravnar.database import Database @@ -22,8 +20,10 @@ from . import AgentHandler ThreadsSortBy = str + RunsSortBy = str else: - ThreadsSortBy = schema.create_str_literal("created_at", "updated_at", default="created_at") + ThreadsSortBy = schema.create_str_literal("created_at", default="created_at") + RunsSortBy = schema.create_str_literal("created_at", default="created_at") def make_router( @@ -69,23 +69,75 @@ async def get_thread_messages( id: Annotated[str, Path(alias="threadId")], user: schema.User = Depends(authenticated_user), # noqa: B008 ) -> list[schema.AugmentedMessage]: - thread = await database.get_thread(user_id=user.id, id=id, with_messages=True) - return pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_python( - thread.messages, from_attributes=True + latest_run = await database.get_latest_run(thread_id=id) + if latest_run is None: + return [] + messages = await database.get_run_messages(run_id=latest_run.id) + return pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_python(messages, from_attributes=True) + + @router.get("/{threadId}/runs") + async def get_runs( + *, + user: schema.User = Depends(authenticated_user), # noqa: B008 + thread_id: Annotated[str, Path(alias="threadId")], + pagination: Annotated[schema.Pagination[RunsSortBy], Query()], + ) -> schema.Page[schema.Run]: + return schema.Page[schema.Run].model_validate( + await database.get_runs(user_id=user.id, thread_id=thread_id, pagination=pagination), + from_attributes=True, ) - @router.sse("/{threadId}/run", methods=["POST"], response_model=schema.Event, tags=["Runs"]) + @router.get("/{threadId}/runs/{runId}") + async def get_run( + *, + user: schema.User = Depends(authenticated_user), # noqa: B008 + thread_id: Annotated[str, Path(alias="threadId")], + run_id: Annotated[str, Path(alias="runId")], + ) -> schema.Run: + return schema.Run.model_validate(await database.get_run(id=run_id, user_id=user.id), from_attributes=True) + + @router.get("/{threadId}/runs/{runId}/messages") + async def get_run_messages( + *, + user: schema.User = Depends(authenticated_user), # noqa: B008 + thread_id: Annotated[str, Path(alias="threadId")], + run_id: Annotated[str, Path(alias="runId")], + ) -> list[schema.AugmentedMessage]: + messages = await database.get_run_messages(run_id=run_id) + return pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_python(messages, from_attributes=True) + + @router.sse("/{threadId}/runs", methods=["POST"], response_model=schema.Event, tags=["Runs"]) async def create_run( *, user: schema.User = Depends(authenticated_user), # noqa: B008 thread_id: Annotated[str, Path(alias="threadId")], data: schema.CreateRunData, ) -> fastsse.Response: - thread = await database.get_thread(user_id=user.id, id=thread_id, with_messages=True) - - messages = pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_python( - thread.messages, from_attributes=True - ) + thread = await database.get_thread(user_id=user.id, id=thread_id) + + parent_run_id = data.parent_run_id + if parent_run_id is not None: + parent_run = await database.get_run(id=parent_run_id, user_id=user.id) + if parent_run.thread_id != thread_id: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + detail="parent_run_id does not belong to this thread", + ) + else: + latest_run = await database.get_latest_run(thread_id=thread_id) + parent_run_id = latest_run.id if latest_run is not None else None + + parent_messages: list[schema.AugmentedMessage] = [] + parent_state = None + if parent_run_id is not None: + orm_messages = await database.get_run_messages(run_id=parent_run_id) + parent_messages = pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_python( + orm_messages, from_attributes=True + ) + parent_run = await database.get_run(id=parent_run_id, user_id=user.id) + parent_state = parent_run.state + + messages = list(parent_messages) messages.extend(data.messages) for m in messages: @@ -107,11 +159,13 @@ async def create_run( ) input_content.metadata = WrappedMetadata(raw=input_content.metadata, file_id=file.id) + client_message_ids = {m.id for m in data.messages} + run_agent_input = ag_ui.core.RunAgentInput( thread_id=thread.id, - run_id=str(uuid.uuid4()), - parent_run_id=None, - state=thread.state, + run_id=data.id, + parent_run_id=parent_run_id, + state=parent_state, messages=[pydantic.TypeAdapter(ag_ui.core.Message).validate_python(m.model_dump()) for m in messages], tools=data.tools, context=data.context, @@ -119,9 +173,12 @@ async def create_run( ) async def callback(event_processor: EventProcessor) -> None: - thread.state, thread.messages = event_processor.extract() - thread.updated_at = now() - await database.update_thread(thread) + # Client-supplied messages are part of this run's delta, not inherited + for msg_id in client_message_ids: + if msg_id in event_processor._parent_messages: + event_processor._messages[msg_id] = event_processor._parent_messages.pop(msg_id) + run = event_processor.extract() + await database.create_run(run=run) return await agent_handler.run(thread.agent_id, run_agent_input, callback=callback) diff --git a/src/_ravnar/core.py b/src/_ravnar/core.py index aabf3d4..940c62a 100644 --- a/src/_ravnar/core.py +++ b/src/_ravnar/core.py @@ -147,13 +147,7 @@ async def run( self.assert_available(agent_id) agent = self._agents[agent_id] - event_processor = EventProcessor( - thread_id=run_agent_input.thread_id, - run_id=run_agent_input.run_id, - parent_run_id=run_agent_input.parent_run_id, - state=run_agent_input.state, - messages=run_agent_input.messages, - ) + event_processor = EventProcessor(run_input=run_agent_input) async def event_stream() -> AsyncIterator[ag_ui.core.Event]: async for event in event_processor.process_event_stream(agent.run(run_agent_input)): diff --git a/src/_ravnar/database.py b/src/_ravnar/database.py index 6fc0347..d4602eb 100644 --- a/src/_ravnar/database.py +++ b/src/_ravnar/database.py @@ -161,16 +161,13 @@ async def create_thread(self, *, user_id: str, id: str, name: str | None, agent_ if thread is not None: raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Thread exists") - created_at = updated_at = now() thread = orm.Thread( id=id, user_id=user_id, agent_id=agent_id, name=name, - created_at=created_at, - updated_at=updated_at, - state=None, - messages=[], + created_at=now(), + runs=[], ) session.add(thread) return thread @@ -196,34 +193,112 @@ async def get_threads(self, *, user_id: str, pagination: schema.Pagination) -> o async with self._get_session() as session: return await self._get_threads(session, user_id=user_id, pagination=pagination) - async def _get_thread(self, session: AsyncSession, *, user_id: str, id: str, with_messages: bool) -> orm.Thread: + async def _get_thread(self, session: AsyncSession, *, user_id: str, id: str) -> orm.Thread: query = select(orm.Thread).where((orm.Thread.id == id) & (orm.Thread.user_id == user_id)) - if with_messages: - query = query.options(selectinload(orm.Thread.messages)) result = await session.execute(query) thread = result.scalar_one_or_none() if thread is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Thread not found") return thread - async def get_thread(self, *, user_id: str, id: str, with_messages: bool = False) -> orm.Thread: + async def get_thread(self, *, user_id: str, id: str) -> orm.Thread: async with self._get_session() as session: - return await self._get_thread(session, user_id=user_id, id=id, with_messages=with_messages) - - async def append_messages_to_thread(self, *, user_id: str, id: str, messages: list[orm.Message]) -> None: - async with self._get_session() as session: - thread = await self._get_thread(session, user_id=user_id, id=id, with_messages=False) - thread.messages.extend(messages) + return await self._get_thread(session, user_id=user_id, id=id) async def rename_thread(self, *, user_id: str, id: str, name: str) -> orm.Thread: async with self._get_session() as session: - thread = await self._get_thread(session, user_id=user_id, id=id, with_messages=False) + thread = await self._get_thread(session, user_id=user_id, id=id) thread.name = name return thread - async def update_thread(self, thread: orm.Thread) -> None: + async def create_run(self, *, run: orm.Run) -> None: + async with self._get_session() as session: + session.add(run) + + async def get_run(self, *, id: str, user_id: str) -> orm.Run: + async with self._get_session() as session: + query = ( + select(orm.Run) + .join(orm.Thread, orm.Run.thread_id == orm.Thread.id) + .where((orm.Run.id == id) & (orm.Thread.user_id == user_id)) + ) + result = await session.execute(query) + run = result.scalar_one_or_none() + if run is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Run not found") + return run + + async def get_runs( + self, + *, + user_id: str, + thread_id: str, + pagination: schema.Pagination | None = None, + ) -> orm.Page[orm.Run]: + async with self._get_session() as session: + + def select_qualifier(query: Select) -> Select: + return query.join(orm.Thread, orm.Run.thread_id == orm.Thread.id).where( + (orm.Run.thread_id == thread_id) & (orm.Thread.user_id == user_id) + ) + + return await self._get_page( + session, orm_type=orm.Run, select_qualifier=select_qualifier, pagination=pagination + ) + + async def get_run_messages(self, *, run_id: str) -> list[orm.Message]: + async with self._get_session() as session: + from sqlalchemy import literal_column + + run_chain = ( + select( + orm.Run.id.label("run_id"), + orm.Run.parent_run_id.label("parent_run_id"), + literal_column("0").label("depth"), + ) + .where(orm.Run.id == run_id) + .cte(name="run_chain", recursive=True) + ) + run_chain = run_chain.union_all( + select(orm.Run.id, orm.Run.parent_run_id, (run_chain.c.depth + 1).label("depth")) + .select_from(orm.Run) + .join(run_chain, orm.Run.id == run_chain.c.parent_run_id) + ) + + ranked = ( + select( + orm.Message.uid, + func.row_number() + .over( + partition_by=orm.Message.id, + order_by=run_chain.c.depth.asc(), + ) + .label("rn"), + ) + .join(run_chain, orm.Message.run_id == run_chain.c.run_id) + .subquery() + ) + + query = ( + select(orm.Message) + .join(ranked, orm.Message.uid == ranked.c.uid) + .where(ranked.c.rn == 1) + .options( + selectinload(orm.UserMessage.input_contents).selectinload(orm.InputContent.file), + selectinload(orm.AssistantMessage.tool_calls), + selectinload(orm.ToolMessage.tool_call).selectinload(orm.ToolCall.tool_message), + ) + .order_by(orm.Message.created_at.asc(), orm.Message.id.asc()) + ) + + result = await session.execute(query) + return result.unique().scalars().all() + + async def get_latest_run(self, *, thread_id: str) -> orm.Run | None: async with self._get_session() as session: - await session.merge(thread) + query = select(orm.Run).where(orm.Run.thread_id == thread_id).order_by(orm.Run.created_at.desc()).limit(1) + result = await session.execute(query) + return result.scalar_one_or_none() async def delete_threads(self, *, user_id: str, ids: list[str]) -> None: async with self._get_session() as session: diff --git a/src/_ravnar/events.py b/src/_ravnar/events.py index 6947389..597ec55 100644 --- a/src/_ravnar/events.py +++ b/src/_ravnar/events.py @@ -74,21 +74,15 @@ class ReasoningData: class EventProcessor: - def __init__( - self, - *, - thread_id: str, - run_id: str, - parent_run_id: str | None, - state: ag_ui.core.State, - messages: list[ag_ui.core.Message], - ): - self._thread_id = thread_id - self._run_id = run_id - self._parent_run_id = parent_run_id - - self._state = state - self._messages = self._convert_messages(messages) + def __init__(self, *, run_input: ag_ui.core.RunAgentInput): + self._run_input = run_input + self._thread_id = run_input.thread_id + self._run_id = run_input.run_id + self._parent_run_id = run_input.parent_run_id + + self._state = run_input.state + self._parent_messages = self._convert_messages(run_input.messages) + self._messages: dict[str, orm.Message] = {} self._progress = RunProgress.NOT_STARTED self._text_message_data: dict[str, TextMessageData] = {} @@ -97,15 +91,17 @@ def __init__( self._reasoning_data: dict[str, ReasoningData] = {} self._thinking_message_id: str | None = None - self._logger = structlog.get_logger(thread_id=thread_id, run_id=run_id, parent_run_id=parent_run_id) + self._logger = structlog.get_logger( + thread_id=run_input.thread_id, run_id=run_input.run_id, parent_run_id=run_input.parent_run_id + ) + + def _convert_messages(self, messages: list[ag_ui.core.Message]) -> dict[str, orm.Message]: + message_uids = {m.id: uuid.uuid4() for m in messages} - def _convert_messages( - self, messages: list[ag_ui.core.Message], *, updated_at: datetime | None = None - ) -> dict[str, orm.Message]: tool_calls = { tc.id: orm.ToolCall( id=tc.id, - assistant_message_id=m.id, + assistant_message_id=message_uids[m.id], tool_message_id=None, name=tc.function.name, arguments=tc.function.arguments, @@ -139,10 +135,9 @@ def _convert_messages( text = None file_id = metadata.file_id input_contents.append( - orm.InputContent(user_message_id=m.id, index=i, text=text, file_id=file_id) + orm.InputContent(user_message_id=message_uids[m.id], index=i, text=text, file_id=file_id) ) data = {**m.model_dump(exclude={"content"}), "input_contents": input_contents} - print() case ag_ui.core.AssistantMessage(): data = { **m.model_dump(exclude={"tool_calls"}), @@ -150,15 +145,14 @@ def _convert_messages( } case ag_ui.core.ToolMessage(): tool_call = tool_calls[m.tool_call_id] - tool_call.tool_message_id = m.id + tool_call.tool_message_id = message_uids[m.id] data = {**m.model_dump(exclude={"tool_call_id"}), "tool_call": tool_call} case _: data = m.model_dump() - if updated_at is not None: - data["updated_at"] = updated_at - data["thread_id"] = self._thread_id - + data["uid"] = message_uids[m.id] + # Parent messages are never persisted; run_id is a placeholder to satisfy SQLAlchemy + data["run_id"] = "" converted_messages[m.id] = cls(**data) return converted_messages @@ -345,9 +339,12 @@ def _process_event(self, event: ag_ui.core.Event) -> ag_ui.core.Event | None: ) return None + if event.message_id in self._parent_messages and not event.replace: + return None + self._messages[event.message_id] = orm.ActivityMessage( id=event.message_id, - thread_id=self._thread_id, + run_id="", created_at=parse_timestamp(event.timestamp), content=event.content, activity_type=event.activity_type, @@ -355,13 +352,25 @@ def _process_event(self, event: ag_ui.core.Event) -> ag_ui.core.Event | None: case ag_ui.core.ActivityDeltaEvent(): message = self._messages.get(event.message_id) if message is None: - logger.error( - "event", - state="dropped", - reason="message does not exist", - message_id=event.message_id, + message = self._parent_messages.get(event.message_id) + if message is None: + logger.error( + "event", + state="dropped", + reason="message does not exist", + message_id=event.message_id, + ) + return None + # Copy inherited message to delta bucket + self._messages[event.message_id] = orm.ActivityMessage( + id=message.id, + run_id="", + created_at=message.created_at, + content=message.content, + activity_type=message.activity_type, ) - return None + message = self._messages[event.message_id] + if not isinstance(message, orm.ActivityMessage): logger.error( "event", @@ -385,7 +394,6 @@ def _process_event(self, event: ag_ui.core.Event) -> ag_ui.core.Event | None: if content is None: return None - message.updated_at = parse_timestamp(event.timestamp) message.content = content # special events # reasoning events @@ -481,13 +489,20 @@ def _apply_jsonpatch(document: dict[str, Any], patches: list[Any], *, logger: Fi ) return None - def extract(self) -> tuple[orm.State, list[orm.Message]]: - return self._state, self._extract_messages() + def extract(self) -> orm.Run: + delta_messages = list(self._messages.values()) + self._extract_messages() + for m in delta_messages: + m.run_id = self._run_input.run_id + return orm.Run( + id=self._run_input.run_id, + thread_id=self._run_input.thread_id, + parent_run_id=self._run_input.parent_run_id, + state=self._state, + messages=delta_messages, + ) def _extract_messages(self) -> list[orm.Message]: - tool_calls: dict[str, orm.ToolCall] = {} - tool_calls_created_at: dict[str, datetime] = {} - grouped_tool_calls: dict[str, list[orm.ToolCall]] = {} + grouped_tool_calls: dict[str, list[ToolCallData]] = {} for tcd in self._tool_call_data.values(): if not tcd.finished: self._logger.warn( @@ -499,47 +514,54 @@ def _extract_messages(self) -> list[orm.Message]: parent_message_id=tcd.parent_message_id, ) continue + grouped_tool_calls.setdefault(tcd.parent_message_id, []).append(tcd) - tool_call = orm.ToolCall( - id=tcd.tool_call_id, - assistant_message_id=tcd.parent_message_id, - tool_message_id=None, - name=tcd.tool_call_name, - arguments="".join(tcd.arguments_delta), - # FIXME - encrypted_value=None, - ) - tool_calls[tool_call.id] = tool_call - tool_calls_created_at[tool_call.id] = tcd.created_at - grouped_tool_calls.setdefault(tool_call.assistant_message_id, []).append(tool_call) - - messages: list[orm.Message] = list(self._messages.values()) + # Build assistant messages so we have their UUIDs for tool call FKs + assistant_messages: dict[str, orm.AssistantMessage] = {} for tmd in self._text_message_data.values(): if not tmd.finished: self._logger.warn("text message", state="dropped", reason="unfinished", message_id=tmd.message_id) continue - messages.append( - orm.AssistantMessage( - id=tmd.message_id, - thread_id=self._thread_id, - created_at=tmd.created_at, - content="".join(tmd.content_deltas) or None, - tool_calls=grouped_tool_calls.pop(tmd.message_id, []), - ) + assistant_messages[tmd.message_id] = orm.AssistantMessage( + uid=uuid.uuid4(), + run_id="", + id=tmd.message_id, + created_at=tmd.created_at, + content="".join(tmd.content_deltas) or None, + tool_calls=[], ) - for mid, tcs in grouped_tool_calls.items(): - messages.append( - orm.AssistantMessage( - id=mid, - thread_id=self._thread_id, - created_at=min(tool_calls_created_at[tc.id] for tc in tcs), + for parent_message_id, tcds in grouped_tool_calls.items(): + if parent_message_id not in assistant_messages: + assistant_messages[parent_message_id] = orm.AssistantMessage( + uid=uuid.uuid4(), + run_id="", + id=parent_message_id, + created_at=min(tcd.created_at for tcd in tcds), content=None, - tool_calls=tcs, + tool_calls=[], ) - ) + + tool_calls: dict[str, orm.ToolCall] = {} + for parent_message_id, tcds in grouped_tool_calls.items(): + parent_msg = assistant_messages[parent_message_id] + for tcd in tcds: + tool_call = orm.ToolCall( + id=tcd.tool_call_id, + assistant_message_id=parent_msg.uid, + tool_message_id=None, + name=tcd.tool_call_name, + arguments="".join(tcd.arguments_delta), + encrypted_value=None, + ) + tool_calls[tool_call.id] = tool_call + parent_msg.tool_calls.append(tool_call) + + messages: list[orm.Message] = [] + + messages.extend(assistant_messages.values()) for rd in self._reasoning_data.values(): if not rd.finished: @@ -548,8 +570,9 @@ def _extract_messages(self) -> list[orm.Message]: messages.append( orm.ReasoningMessage( + uid=uuid.uuid4(), + run_id="", id=rd.message_id, - thread_id=self._thread_id, created_at=rd.created_at, content="".join(rd.content_deltas), ) @@ -565,14 +588,17 @@ def _extract_messages(self) -> list[orm.Message]: tool_call_id=trd.tool_call_id, ) continue + msg_uid = uuid.uuid4() + tool_call = tool_calls[trd.tool_call_id] + tool_call.tool_message_id = msg_uid messages.append( orm.ToolMessage( + uid=msg_uid, + run_id="", id=trd.message_id, - thread_id=self._thread_id, created_at=trd.created_at, content=trd.content, - tool_call=tool_calls[trd.tool_call_id], - # FIXME + tool_call=tool_call, error=None, encrypted_value=None, ) diff --git a/src/_ravnar/orm.py b/src/_ravnar/orm.py index 503e87f..8156ce5 100644 --- a/src/_ravnar/orm.py +++ b/src/_ravnar/orm.py @@ -111,12 +111,30 @@ class Thread(Base, kw_only=True, repr=False): name: Mapped[str | None] created_at: Mapped[datetime] = mapped_column(UtcAwareDateTime) - updated_at: Mapped[datetime] = mapped_column(UtcAwareDateTime) - state: Mapped[State] = mapped_column(Json, nullable=True) + runs: Mapped[list[Run]] = relationship( + "Run", + back_populates="thread", + cascade="all, delete-orphan", + order_by="Run.created_at.asc()", + ) + + +class Run(Base, kw_only=True, repr=False): + __tablename__ = "runs" + + id: Mapped[str] = mapped_column(primary_key=True) + thread_id: Mapped[str] = mapped_column(ForeignKey("threads.id", ondelete="CASCADE"), index=True) + thread: Mapped[Thread] = relationship("Thread", back_populates="runs", init=False) + parent_run_id: Mapped[str | None] = mapped_column( + ForeignKey("runs.id", ondelete="CASCADE"), index=True, default=None + ) + created_at: Mapped[datetime] = mapped_column(UtcAwareDateTime, default_factory=now) + state: Mapped[State] = mapped_column(Json, nullable=True, default=None) + messages: Mapped[list[Message]] = relationship( "Message", - back_populates="thread", + back_populates="run", cascade="all, delete-orphan", order_by="[Message.created_at.asc(), Message.id]", ) @@ -125,13 +143,12 @@ class Thread(Base, kw_only=True, repr=False): class Message(Base, kw_only=True, repr=False): __tablename__ = "messages" - id: Mapped[str] = mapped_column(primary_key=True) - - thread_id: Mapped[str] = mapped_column(ForeignKey("threads.id"), index=True) - thread: Mapped[Thread] = relationship(init=False) + uid: Mapped[uuid.UUID] = mapped_column(types.Uuid, primary_key=True, default_factory=uuid.uuid4) + run_id: Mapped[str] = mapped_column(ForeignKey("runs.id", ondelete="CASCADE"), index=True) + run: Mapped[Run] = relationship("Run", back_populates="messages", init=False) + id: Mapped[str] = mapped_column(index=True) created_at: Mapped[datetime] = mapped_column(UtcAwareDateTime) - updated_at: Mapped[datetime | None] = mapped_column(UtcAwareDateTime, default=None) role: Mapped[ag_ui.core.Role] = mapped_column( types.Enum(*get_args(ag_ui.core.Role), name="message_role", native_enum=False) @@ -220,7 +237,7 @@ class ReasoningMessage(Message, kw_only=True, repr=False): class InputContent(Base, kw_only=True, repr=False): __tablename__ = "input_contents" - user_message_id: Mapped[str] = mapped_column(ForeignKey("messages.id"), primary_key=True) + user_message_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("messages.uid", ondelete="CASCADE"), primary_key=True) user_message: Mapped[UserMessage] = relationship("UserMessage", init=False, back_populates="input_contents") index: Mapped[int] = mapped_column(primary_key=True) @@ -235,12 +252,14 @@ class ToolCall(Base, kw_only=True, repr=False): id: Mapped[str] = mapped_column(primary_key=True) - assistant_message_id: Mapped[str] = mapped_column(ForeignKey("messages.id", ondelete="CASCADE"), index=True) + assistant_message_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("messages.uid", ondelete="CASCADE"), index=True) assistant_message: Mapped[AssistantMessage] = relationship( init=False, back_populates="tool_calls", foreign_keys=[assistant_message_id] ) - tool_message_id: Mapped[str | None] = mapped_column(ForeignKey("messages.id", ondelete="CASCADE"), index=True) + tool_message_id: Mapped[uuid.UUID | None] = mapped_column( + ForeignKey("messages.uid", ondelete="CASCADE"), index=True + ) tool_message: Mapped[ToolMessage | None] = relationship( init=False, back_populates="tool_call", foreign_keys=[tool_message_id] ) diff --git a/src/_ravnar/schema/__init__.py b/src/_ravnar/schema/__init__.py index e758a17..12f2998 100644 --- a/src/_ravnar/schema/__init__.py +++ b/src/_ravnar/schema/__init__.py @@ -19,6 +19,7 @@ "Pagination", "QuickPrompt", "RenameThreadData", + "Run", "TModel", "Thread", "User", @@ -42,6 +43,7 @@ Event, QuickPrompt, RenameThreadData, + Run, Thread, ) from .misc import APIRouter, BaseModel, Page, Pagination, TModel, User, create_str_literal diff --git a/src/_ravnar/schema/api.py b/src/_ravnar/schema/api.py index 38a5d66..e37805d 100644 --- a/src/_ravnar/schema/api.py +++ b/src/_ravnar/schema/api.py @@ -17,6 +17,7 @@ "Event", "QuickPrompt", "RenameThreadData", + "Run", "Thread", ] @@ -54,13 +55,11 @@ class Thread(BaseModel): name: str | None = None agent_id: str created_at: datetime - updated_at: datetime class AugmentedMessageMixin(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) created_at: datetime = Field(default_factory=now) - updated_at: datetime | None = None @classmethod def _convert_orm_tool_call(cls, tool_call: orm.ToolCall) -> ag_ui.core.ToolCall: @@ -168,7 +167,16 @@ class CreateThreadData(BaseModel): agent_id: str +class Run(BaseModel): + id: str + thread_id: str + parent_run_id: str | None = None + created_at: datetime + + class CreateRunData(BaseModel): + id: str = Field(default_factory=lambda: str(uuid.uuid4())) + parent_run_id: str | None = None messages: list[AugmentedUserMessage | AugmentedToolMessage] tools: list[ag_ui.core.Tool] = Field(default_factory=list) context: list[ag_ui.core.Context] = Field(default_factory=list) diff --git a/tests/api/test_threads.py b/tests/api/test_threads.py index 956faaa..6354822 100644 --- a/tests/api/test_threads.py +++ b/tests/api/test_threads.py @@ -30,8 +30,6 @@ def test_create_thread(self, app_client, id, name): assert thread.id == id assert thread.name == name assert thread.agent_id == agent_id - assert thread.updated_at == thread.created_at - expected = thread response = app_client.get(f"/api/threads/{thread.id}").raise_for_status() actual = schema.Thread.model_validate_json(response.content) @@ -263,7 +261,7 @@ def create_run(self, client, *, thread_id, data, **kwargs): with httpx_sse.connect_sse( client, "POST", - f"/api/threads/{thread_id}/run", + f"/api/threads/{thread_id}/runs", json=schema.CreateRunData.model_validate(data).model_dump(mode="json", by_alias=True, exclude_unset=True), **kwargs, ) as event_source: @@ -330,3 +328,99 @@ def test_files_smoke(self, app_client): list(event_stream) app_client.get(f"/api/threads/{thread_id}/messages").raise_for_status() + + def test_run_crud(self, app_client): + thread_id = self.create_thread(app_client).id + run_id = "run-1" + + event_stream = self.create_run( + app_client, + thread_id=thread_id, + data={"id": run_id, "messages": [{"role": "user", "content": [{"type": "text", "text": "hello"}]}]}, + ) + list(event_stream) + + # List runs + response = app_client.get(f"/api/threads/{thread_id}/runs").raise_for_status() + runs_page = schema.Page[schema.Run].model_validate_json(response.content) + assert len(runs_page.items) == 1 + assert runs_page.items[0].id == run_id + + # Get run + response = app_client.get(f"/api/threads/{thread_id}/runs/{run_id}").raise_for_status() + run = schema.Run.model_validate_json(response.content) + assert run.id == run_id + assert run.thread_id == thread_id + assert run.parent_run_id is None + + # Get run messages + response = app_client.get(f"/api/threads/{thread_id}/runs/{run_id}/messages").raise_for_status() + messages = pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_json(response.content) + assert len(messages) == 2 # user message + assistant response + + # Backward compat thread messages returns latest run snapshot + response = app_client.get(f"/api/threads/{thread_id}/messages").raise_for_status() + thread_messages = pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_json(response.content) + assert len(thread_messages) == 2 + + def test_child_run(self, app_client): + thread_id = self.create_thread(app_client).id + run1_id = "run-1" + run2_id = "run-2" + + event_stream = self.create_run( + app_client, + thread_id=thread_id, + data={"id": run1_id, "messages": [{"role": "user", "content": [{"type": "text", "text": "hello"}]}]}, + ) + list(event_stream) + + # Create child run with explicit parent_run_id + event_stream = self.create_run( + app_client, + thread_id=thread_id, + data={ + "id": run2_id, + "parentRunId": run1_id, + "messages": [{"role": "user", "content": [{"type": "text", "text": "follow up"}]}], + }, + ) + list(event_stream) + + response = app_client.get(f"/api/threads/{thread_id}/runs/{run2_id}").raise_for_status() + run2 = schema.Run.model_validate_json(response.content) + assert run2.parent_run_id == run1_id + + # Run2 messages include run1 messages + new ones + response = app_client.get(f"/api/threads/{thread_id}/runs/{run2_id}/messages").raise_for_status() + messages = pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_json(response.content) + assert len(messages) == 4 # run1 user+assistant, run2 user+assistant + + # Run1 messages unchanged + response = app_client.get(f"/api/threads/{thread_id}/runs/{run1_id}/messages").raise_for_status() + messages1 = pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_json(response.content) + assert len(messages1) == 2 + + def test_invalid_parent_run_id(self, app_client): + thread1_id = self.create_thread(app_client).id + thread2_id = self.create_thread(app_client).id + run_id = "run-1" + + event_stream = self.create_run( + app_client, + thread_id=thread1_id, + data={"id": run_id, "messages": [{"role": "user", "content": [{"type": "text", "text": "hello"}]}]}, + ) + list(event_stream) + + # Try to use run from thread1 as parent for thread2 + with httpx_sse.connect_sse( + app_client, + "POST", + f"/api/threads/{thread2_id}/runs", + json=schema.CreateRunData( + messages=[{"role": "user", "content": [{"type": "text", "text": "hello"}]}], + parent_run_id=run_id, + ).model_dump(mode="json", by_alias=True), + ) as event_source: + assert event_source.response.status_code == 422 diff --git a/tests/test_events.py b/tests/test_events.py index 4f6dc98..8f0f4b0 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -24,13 +24,17 @@ def assert_equal(self, actual, expected): @pytest_cases.parametrize_with_cases("test_case", cases=test_events_cases.EventProcessingCases) async def test_event_processing(self, test_case: test_events_cases.EventProcessingCase): - event_processor = EventProcessor( + run_input = ag_ui.core.RunAgentInput( thread_id=test_case.thread_id, run_id=test_case.run_id, parent_run_id=test_case.parent_run_id, state=test_case.state, messages=test_case.messages, + tools=[], + context=[], + forwarded_props=None, ) + event_processor = EventProcessor(run_input=run_input) input = test_case.input if test_case.handle_run_lifecycle_events: @@ -52,10 +56,11 @@ async def test_event_processing(self, test_case: test_events_cases.EventProcessi self.assert_equal(actual_event_stream, test_case.expected_event_stream) - actual_state, actual_orm_messages = event_processor.extract() + actual_run = event_processor.extract() actual_messages = pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_python( - actual_orm_messages, from_attributes=True + actual_run.messages, from_attributes=True ) + actual_state = actual_run.state compyre.assert_equal(actual_state, test_case.expected_state) self.assert_equal(actual_messages, test_case.expected_messages) diff --git a/tests/test_events_cases.py b/tests/test_events_cases.py index efe9796..8eab489 100644 --- a/tests/test_events_cases.py +++ b/tests/test_events_cases.py @@ -136,7 +136,6 @@ def case_activity_message_delta(self): activity_type=activity_type, content={"baz": "boo", "hello": ["world"]}, created_at=parse_timestamp(snapshot_timestamp), - updated_at=parse_timestamp(last_patch_timestamp), ) ], ) From c0c8cc860e5a0136cab3a39bd3df4eddd935bdf3 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 13 May 2026 14:35:07 +0000 Subject: [PATCH 2/7] Add get_run_history, rename EventProcessor arg to run_agent_input, remove redundant attrs --- src/_ravnar/api/threads.py | 10 +++--- src/_ravnar/core.py | 2 +- src/_ravnar/database.py | 72 ++++++++++++++++++++++++++++++++++++++ src/_ravnar/events.py | 47 +++++++++++++++---------- tests/test_events.py | 4 +-- 5 files changed, 108 insertions(+), 27 deletions(-) diff --git a/src/_ravnar/api/threads.py b/src/_ravnar/api/threads.py index 140cfc9..83e2de5 100644 --- a/src/_ravnar/api/threads.py +++ b/src/_ravnar/api/threads.py @@ -127,15 +127,15 @@ async def create_run( latest_run = await database.get_latest_run(thread_id=thread_id) parent_run_id = latest_run.id if latest_run is not None else None + parent_run: orm.Run | None = None parent_messages: list[schema.AugmentedMessage] = [] - parent_state = None if parent_run_id is not None: - orm_messages = await database.get_run_messages(run_id=parent_run_id) + parent_run, orm_messages = await database.get_run_history( + run_id=parent_run_id, user_id=user.id, thread_id=thread_id + ) parent_messages = pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_python( orm_messages, from_attributes=True ) - parent_run = await database.get_run(id=parent_run_id, user_id=user.id) - parent_state = parent_run.state messages = list(parent_messages) messages.extend(data.messages) @@ -165,7 +165,7 @@ async def create_run( thread_id=thread.id, run_id=data.id, parent_run_id=parent_run_id, - state=parent_state, + state=parent_run.state if parent_run is not None else None, messages=[pydantic.TypeAdapter(ag_ui.core.Message).validate_python(m.model_dump()) for m in messages], tools=data.tools, context=data.context, diff --git a/src/_ravnar/core.py b/src/_ravnar/core.py index 940c62a..32f876a 100644 --- a/src/_ravnar/core.py +++ b/src/_ravnar/core.py @@ -147,7 +147,7 @@ async def run( self.assert_available(agent_id) agent = self._agents[agent_id] - event_processor = EventProcessor(run_input=run_agent_input) + event_processor = EventProcessor(run_agent_input=run_agent_input) async def event_stream() -> AsyncIterator[ag_ui.core.Event]: async for event in event_processor.process_event_stream(agent.run(run_agent_input)): diff --git a/src/_ravnar/database.py b/src/_ravnar/database.py index d4602eb..17e73e8 100644 --- a/src/_ravnar/database.py +++ b/src/_ravnar/database.py @@ -300,6 +300,78 @@ async def get_latest_run(self, *, thread_id: str) -> orm.Run | None: result = await session.execute(query) return result.scalar_one_or_none() + async def get_run_history( + self, *, run_id: str, user_id: str, thread_id: str + ) -> tuple[orm.Run | None, list[orm.Message]]: + """Return the parent run and full message snapshot for a run chain. + + Validates that the resolved run belongs to the given thread and user. + """ + async with self._get_session() as session: + # Resolve the target run and validate ownership in one query + query = ( + select(orm.Run) + .join(orm.Thread, orm.Run.thread_id == orm.Thread.id) + .where( + (orm.Run.id == run_id) + & (orm.Run.thread_id == thread_id) + & (orm.Thread.user_id == user_id) + ) + ) + result = await session.execute(query) + target_run = result.scalar_one_or_none() + if target_run is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Run not found") + + # Build the recursive CTE for the run chain + from sqlalchemy import literal_column + + run_chain = ( + select( + orm.Run.id.label("run_id"), + orm.Run.parent_run_id.label("parent_run_id"), + literal_column("0").label("depth"), + ) + .where(orm.Run.id == run_id) + .cte(name="run_chain", recursive=True) + ) + run_chain = run_chain.union_all( + select(orm.Run.id, orm.Run.parent_run_id, (run_chain.c.depth + 1).label("depth")) + .select_from(orm.Run) + .join(run_chain, orm.Run.id == run_chain.c.parent_run_id) + ) + + ranked = ( + select( + orm.Message.uid, + func.row_number() + .over( + partition_by=orm.Message.id, + order_by=run_chain.c.depth.asc(), + ) + .label("rn"), + ) + .join(run_chain, orm.Message.run_id == run_chain.c.run_id) + .subquery() + ) + + messages_query = ( + select(orm.Message) + .join(ranked, orm.Message.uid == ranked.c.uid) + .where(ranked.c.rn == 1) + .options( + selectinload(orm.UserMessage.input_contents).selectinload(orm.InputContent.file), + selectinload(orm.AssistantMessage.tool_calls), + selectinload(orm.ToolMessage.tool_call).selectinload(orm.ToolCall.tool_message), + ) + .order_by(orm.Message.created_at.asc(), orm.Message.id.asc()) + ) + + messages_result = await session.execute(messages_query) + messages = messages_result.unique().scalars().all() + + return target_run, list(messages) + async def delete_threads(self, *, user_id: str, ids: list[str]) -> None: async with self._get_session() as session: single_page = await self._get_threads(session, user_id=user_id, ids=ids) diff --git a/src/_ravnar/events.py b/src/_ravnar/events.py index 597ec55..ff132e1 100644 --- a/src/_ravnar/events.py +++ b/src/_ravnar/events.py @@ -74,14 +74,11 @@ class ReasoningData: class EventProcessor: - def __init__(self, *, run_input: ag_ui.core.RunAgentInput): - self._run_input = run_input - self._thread_id = run_input.thread_id - self._run_id = run_input.run_id - self._parent_run_id = run_input.parent_run_id - - self._state = run_input.state - self._parent_messages = self._convert_messages(run_input.messages) + def __init__(self, *, run_agent_input: ag_ui.core.RunAgentInput): + self._run_agent_input = run_agent_input + + self._state = run_agent_input.state + self._parent_messages = self._convert_messages(run_agent_input.messages) self._messages: dict[str, orm.Message] = {} self._progress = RunProgress.NOT_STARTED @@ -92,7 +89,9 @@ def __init__(self, *, run_input: ag_ui.core.RunAgentInput): self._thinking_message_id: str | None = None self._logger = structlog.get_logger( - thread_id=run_input.thread_id, run_id=run_input.run_id, parent_run_id=run_input.parent_run_id + thread_id=run_agent_input.thread_id, + run_id=run_agent_input.run_id, + parent_run_id=run_agent_input.parent_run_id, ) def _convert_messages(self, messages: list[ag_ui.core.Message]) -> dict[str, orm.Message]: @@ -201,9 +200,9 @@ def _process_event(self, event: ag_ui.core.Event) -> ag_ui.core.Event | None: logger.warn("event", state="dropped", reason="already started") return None if ( - event.thread_id != self._thread_id - or event.run_id != self._run_id - or event.parent_run_id != self._parent_run_id + event.thread_id != self._run_agent_input.thread_id + or event.run_id != self._run_agent_input.run_id + or event.parent_run_id != self._run_agent_input.parent_run_id ): logger.warn( "event", @@ -214,18 +213,28 @@ def _process_event(self, event: ag_ui.core.Event) -> ag_ui.core.Event | None: ), ) event = self._override_event( - event, thread_id=self._thread_id, run_id=self._run_id, parent_run_id=self._parent_run_id + event, + thread_id=self._run_agent_input.thread_id, + run_id=self._run_agent_input.run_id, + parent_run_id=self._run_agent_input.parent_run_id, ) self._progress = RunProgress.STARTED case ag_ui.core.RunFinishedEvent(): - if event.thread_id != self._thread_id or event.run_id != self._run_id: + if ( + event.thread_id != self._run_agent_input.thread_id + or event.run_id != self._run_agent_input.run_id + ): logger.warn( "event", state="overridden", reason="mismatching lifecycle data", event_lifecycle_data=event.model_dump(include={"thread_id", "run_id"}, mode="json"), ) - event = self._override_event(event, thread_id=self._thread_id, run_id=self._run_id) + event = self._override_event( + event, + thread_id=self._run_agent_input.thread_id, + run_id=self._run_agent_input.run_id, + ) self._progress = RunProgress.FINISHED case ag_ui.core.RunErrorEvent(): self._progress = RunProgress.FINISHED @@ -492,11 +501,11 @@ def _apply_jsonpatch(document: dict[str, Any], patches: list[Any], *, logger: Fi def extract(self) -> orm.Run: delta_messages = list(self._messages.values()) + self._extract_messages() for m in delta_messages: - m.run_id = self._run_input.run_id + m.run_id = self._run_agent_input.run_id return orm.Run( - id=self._run_input.run_id, - thread_id=self._run_input.thread_id, - parent_run_id=self._run_input.parent_run_id, + id=self._run_agent_input.run_id, + thread_id=self._run_agent_input.thread_id, + parent_run_id=self._run_agent_input.parent_run_id, state=self._state, messages=delta_messages, ) diff --git a/tests/test_events.py b/tests/test_events.py index 8f0f4b0..0e5794b 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -24,7 +24,7 @@ def assert_equal(self, actual, expected): @pytest_cases.parametrize_with_cases("test_case", cases=test_events_cases.EventProcessingCases) async def test_event_processing(self, test_case: test_events_cases.EventProcessingCase): - run_input = ag_ui.core.RunAgentInput( + run_agent_input = ag_ui.core.RunAgentInput( thread_id=test_case.thread_id, run_id=test_case.run_id, parent_run_id=test_case.parent_run_id, @@ -34,7 +34,7 @@ async def test_event_processing(self, test_case: test_events_cases.EventProcessi context=[], forwarded_props=None, ) - event_processor = EventProcessor(run_input=run_input) + event_processor = EventProcessor(run_agent_input=run_agent_input) input = test_case.input if test_case.handle_run_lifecycle_events: From aa9aad25455ca0ac47d676545152c10c182c24ca Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 13 May 2026 23:11:22 +0200 Subject: [PATCH 3/7] simplify DB setup --- src/_ravnar/api/threads.py | 45 +++------ src/_ravnar/database.py | 194 +++++++++++++------------------------ src/_ravnar/orm.py | 1 + tests/api/test_threads.py | 2 +- 4 files changed, 79 insertions(+), 163 deletions(-) diff --git a/src/_ravnar/api/threads.py b/src/_ravnar/api/threads.py index 83e2de5..b953dc8 100644 --- a/src/_ravnar/api/threads.py +++ b/src/_ravnar/api/threads.py @@ -66,13 +66,10 @@ async def get_thread( @router.get("/{threadId}/messages") async def get_thread_messages( - id: Annotated[str, Path(alias="threadId")], + thread_id: Annotated[str, Path(alias="threadId")], user: schema.User = Depends(authenticated_user), # noqa: B008 ) -> list[schema.AugmentedMessage]: - latest_run = await database.get_latest_run(thread_id=id) - if latest_run is None: - return [] - messages = await database.get_run_messages(run_id=latest_run.id) + _, _, messages = await database.get_thread_history(user_id=user.id, thread_id=thread_id, run_id=None) return pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_python(messages, from_attributes=True) @router.get("/{threadId}/runs") @@ -103,7 +100,7 @@ async def get_run_messages( thread_id: Annotated[str, Path(alias="threadId")], run_id: Annotated[str, Path(alias="runId")], ) -> list[schema.AugmentedMessage]: - messages = await database.get_run_messages(run_id=run_id) + _, _, messages = await database.get_thread_history(user_id=user.id, thread_id=thread_id, run_id=run_id) return pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_python(messages, from_attributes=True) @router.sse("/{threadId}/runs", methods=["POST"], response_model=schema.Event, tags=["Runs"]) @@ -113,31 +110,13 @@ async def create_run( thread_id: Annotated[str, Path(alias="threadId")], data: schema.CreateRunData, ) -> fastsse.Response: - thread = await database.get_thread(user_id=user.id, id=thread_id) - - parent_run_id = data.parent_run_id - if parent_run_id is not None: - parent_run = await database.get_run(id=parent_run_id, user_id=user.id) - if parent_run.thread_id != thread_id: - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, - detail="parent_run_id does not belong to this thread", - ) - else: - latest_run = await database.get_latest_run(thread_id=thread_id) - parent_run_id = latest_run.id if latest_run is not None else None - - parent_run: orm.Run | None = None - parent_messages: list[schema.AugmentedMessage] = [] - if parent_run_id is not None: - parent_run, orm_messages = await database.get_run_history( - run_id=parent_run_id, user_id=user.id, thread_id=thread_id - ) - parent_messages = pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_python( - orm_messages, from_attributes=True - ) - - messages = list(parent_messages) + thread, parent_run, parent_messages = await database.get_thread_history( + user_id=user.id, thread_id=thread_id, run_id=data.parent_run_id + ) + + messages = pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_python( + parent_messages, from_attributes=True + ) messages.extend(data.messages) for m in messages: @@ -162,9 +141,9 @@ async def create_run( client_message_ids = {m.id for m in data.messages} run_agent_input = ag_ui.core.RunAgentInput( - thread_id=thread.id, + thread_id=thread_id, run_id=data.id, - parent_run_id=parent_run_id, + parent_run_id=data.parent_run_id, state=parent_run.state if parent_run is not None else None, messages=[pydantic.TypeAdapter(ag_ui.core.Message).validate_python(m.model_dump()) for m in messages], tools=data.tools, diff --git a/src/_ravnar/database.py b/src/_ravnar/database.py index 17e73e8..53b318c 100644 --- a/src/_ravnar/database.py +++ b/src/_ravnar/database.py @@ -9,11 +9,11 @@ from fastapi import HTTPException, status from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor -from sqlalchemy import Engine, Select, asc, create_engine, desc, func, inspect, select +from sqlalchemy import Engine, Select, asc, create_engine, desc, func, inspect, literal_column, select from sqlalchemy.engine.url import make_url from sqlalchemy.exc import InvalidRequestError from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine -from sqlalchemy.orm import Session, selectinload, sessionmaker +from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm.interfaces import ORMOption from starlette.concurrency import run_in_threadpool from typing_extensions import TypedDict @@ -215,18 +215,21 @@ async def create_run(self, *, run: orm.Run) -> None: async with self._get_session() as session: session.add(run) + async def _get_run(self, session: AsyncSession, *, id: str, user_id: str) -> orm.Run: + query = ( + select(orm.Run) + .join(orm.Thread, orm.Run.thread_id == orm.Thread.id) + .where((orm.Run.id == id) & (orm.Thread.user_id == user_id)) + ) + result = await session.execute(query) + run = result.scalar_one_or_none() + if run is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Run not found") + return run + async def get_run(self, *, id: str, user_id: str) -> orm.Run: async with self._get_session() as session: - query = ( - select(orm.Run) - .join(orm.Thread, orm.Run.thread_id == orm.Thread.id) - .where((orm.Run.id == id) & (orm.Thread.user_id == user_id)) - ) - result = await session.execute(query) - run = result.scalar_one_or_none() - if run is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Run not found") - return run + return await self._get_run(session, id=id, user_id=user_id) async def get_runs( self, @@ -246,131 +249,64 @@ def select_qualifier(query: Select) -> Select: session, orm_type=orm.Run, select_qualifier=select_qualifier, pagination=pagination ) - async def get_run_messages(self, *, run_id: str) -> list[orm.Message]: - async with self._get_session() as session: - from sqlalchemy import literal_column - - run_chain = ( - select( - orm.Run.id.label("run_id"), - orm.Run.parent_run_id.label("parent_run_id"), - literal_column("0").label("depth"), - ) - .where(orm.Run.id == run_id) - .cte(name="run_chain", recursive=True) - ) - run_chain = run_chain.union_all( - select(orm.Run.id, orm.Run.parent_run_id, (run_chain.c.depth + 1).label("depth")) - .select_from(orm.Run) - .join(run_chain, orm.Run.id == run_chain.c.parent_run_id) - ) - - ranked = ( - select( - orm.Message.uid, - func.row_number() - .over( - partition_by=orm.Message.id, - order_by=run_chain.c.depth.asc(), - ) - .label("rn"), - ) - .join(run_chain, orm.Message.run_id == run_chain.c.run_id) - .subquery() + async def _get_thread_messages(self, session: AsyncSession, *, run_id: str) -> list[orm.Message]: + run_chain = ( + select( + orm.Run.id.label("run_id"), + orm.Run.parent_run_id.label("parent_run_id"), + literal_column("0").label("depth"), ) + .where(orm.Run.id == run_id) + .cte(name="run_chain", recursive=True) + ) + run_chain = run_chain.union_all( + select(orm.Run.id, orm.Run.parent_run_id, (run_chain.c.depth + 1).label("depth")) + .select_from(orm.Run) + .join(run_chain, orm.Run.id == run_chain.c.parent_run_id) + ) - query = ( - select(orm.Message) - .join(ranked, orm.Message.uid == ranked.c.uid) - .where(ranked.c.rn == 1) - .options( - selectinload(orm.UserMessage.input_contents).selectinload(orm.InputContent.file), - selectinload(orm.AssistantMessage.tool_calls), - selectinload(orm.ToolMessage.tool_call).selectinload(orm.ToolCall.tool_message), + ranked = ( + select( + orm.Message.uid, + func.row_number() + .over( + partition_by=orm.Message.id, + order_by=run_chain.c.depth.asc(), ) - .order_by(orm.Message.created_at.asc(), orm.Message.id.asc()) + .label("rn"), ) + .join(run_chain, orm.Message.run_id == run_chain.c.run_id) + .subquery() + ) - result = await session.execute(query) - return result.unique().scalars().all() - - async def get_latest_run(self, *, thread_id: str) -> orm.Run | None: - async with self._get_session() as session: - query = select(orm.Run).where(orm.Run.thread_id == thread_id).order_by(orm.Run.created_at.desc()).limit(1) - result = await session.execute(query) - return result.scalar_one_or_none() + query = ( + select(orm.Message) + .join(ranked, orm.Message.uid == ranked.c.uid) + .where(ranked.c.rn == 1) + .order_by(orm.Message.created_at.asc(), orm.Message.id.asc()) + ) - async def get_run_history( - self, *, run_id: str, user_id: str, thread_id: str - ) -> tuple[orm.Run | None, list[orm.Message]]: - """Return the parent run and full message snapshot for a run chain. + result = await session.execute(query) + return result.unique().scalars().all() - Validates that the resolved run belongs to the given thread and user. - """ + async def get_thread_history( + self, *, user_id: str, thread_id: str, run_id: str | None + ) -> tuple[orm.Thread, orm.Run | None, list[orm.Message]]: async with self._get_session() as session: - # Resolve the target run and validate ownership in one query - query = ( - select(orm.Run) - .join(orm.Thread, orm.Run.thread_id == orm.Thread.id) - .where( - (orm.Run.id == run_id) - & (orm.Run.thread_id == thread_id) - & (orm.Thread.user_id == user_id) - ) - ) - result = await session.execute(query) - target_run = result.scalar_one_or_none() - if target_run is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Run not found") - - # Build the recursive CTE for the run chain - from sqlalchemy import literal_column - - run_chain = ( - select( - orm.Run.id.label("run_id"), - orm.Run.parent_run_id.label("parent_run_id"), - literal_column("0").label("depth"), - ) - .where(orm.Run.id == run_id) - .cte(name="run_chain", recursive=True) - ) - run_chain = run_chain.union_all( - select(orm.Run.id, orm.Run.parent_run_id, (run_chain.c.depth + 1).label("depth")) - .select_from(orm.Run) - .join(run_chain, orm.Run.id == run_chain.c.parent_run_id) - ) - - ranked = ( - select( - orm.Message.uid, - func.row_number() - .over( - partition_by=orm.Message.id, - order_by=run_chain.c.depth.asc(), - ) - .label("rn"), - ) - .join(run_chain, orm.Message.run_id == run_chain.c.run_id) - .subquery() - ) - - messages_query = ( - select(orm.Message) - .join(ranked, orm.Message.uid == ranked.c.uid) - .where(ranked.c.rn == 1) - .options( - selectinload(orm.UserMessage.input_contents).selectinload(orm.InputContent.file), - selectinload(orm.AssistantMessage.tool_calls), - selectinload(orm.ToolMessage.tool_call).selectinload(orm.ToolCall.tool_message), - ) - .order_by(orm.Message.created_at.asc(), orm.Message.id.asc()) - ) - - messages_result = await session.execute(messages_query) - messages = messages_result.unique().scalars().all() - - return target_run, list(messages) + thread = await self._get_thread(session, user_id=user_id, id=thread_id) + + if run_id is None: + if not thread.runs: + return thread, None, [] + run = thread.runs[-1] + else: + try: + run = next(r for r in thread.runs if r.id == run_id) + except StopIteration: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Parent run not found") from None + + messages = await self._get_thread_messages(session, run_id=run.id) + return thread, run, messages async def delete_threads(self, *, user_id: str, ids: list[str]) -> None: async with self._get_session() as session: diff --git a/src/_ravnar/orm.py b/src/_ravnar/orm.py index 8156ce5..358235e 100644 --- a/src/_ravnar/orm.py +++ b/src/_ravnar/orm.py @@ -117,6 +117,7 @@ class Thread(Base, kw_only=True, repr=False): back_populates="thread", cascade="all, delete-orphan", order_by="Run.created_at.asc()", + lazy="selectin", ) diff --git a/tests/api/test_threads.py b/tests/api/test_threads.py index 6354822..7abcfca 100644 --- a/tests/api/test_threads.py +++ b/tests/api/test_threads.py @@ -423,4 +423,4 @@ def test_invalid_parent_run_id(self, app_client): parent_run_id=run_id, ).model_dump(mode="json", by_alias=True), ) as event_source: - assert event_source.response.status_code == 422 + assert event_source.response.status_code == 404 From 83d4163b1e5b45b11196f9d3246ad7baa9aaa3be Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 13 May 2026 23:12:43 +0200 Subject: [PATCH 4/7] fix mypy --- src/_ravnar/database.py | 2 +- src/_ravnar/events.py | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/_ravnar/database.py b/src/_ravnar/database.py index 53b318c..2bb17dc 100644 --- a/src/_ravnar/database.py +++ b/src/_ravnar/database.py @@ -287,7 +287,7 @@ async def _get_thread_messages(self, session: AsyncSession, *, run_id: str) -> l ) result = await session.execute(query) - return result.unique().scalars().all() + return list(result.unique().scalars().all()) async def get_thread_history( self, *, user_id: str, thread_id: str, run_id: str | None diff --git a/src/_ravnar/events.py b/src/_ravnar/events.py index ff132e1..8cde375 100644 --- a/src/_ravnar/events.py +++ b/src/_ravnar/events.py @@ -220,10 +220,7 @@ def _process_event(self, event: ag_ui.core.Event) -> ag_ui.core.Event | None: ) self._progress = RunProgress.STARTED case ag_ui.core.RunFinishedEvent(): - if ( - event.thread_id != self._run_agent_input.thread_id - or event.run_id != self._run_agent_input.run_id - ): + if event.thread_id != self._run_agent_input.thread_id or event.run_id != self._run_agent_input.run_id: logger.warn( "event", state="overridden", @@ -362,6 +359,7 @@ def _process_event(self, event: ag_ui.core.Event) -> ag_ui.core.Event | None: message = self._messages.get(event.message_id) if message is None: message = self._parent_messages.get(event.message_id) + assert isinstance(message, orm.ActivityMessage) if message is None: logger.error( "event", From 5661750137d413235b0fa7827be3c1ff39a72a88 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 13 May 2026 23:18:52 +0200 Subject: [PATCH 5/7] cleanup tests --- tests/api/test_threads.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/tests/api/test_threads.py b/tests/api/test_threads.py index 7abcfca..ef1d2c5 100644 --- a/tests/api/test_threads.py +++ b/tests/api/test_threads.py @@ -340,25 +340,21 @@ def test_run_crud(self, app_client): ) list(event_stream) - # List runs response = app_client.get(f"/api/threads/{thread_id}/runs").raise_for_status() runs_page = schema.Page[schema.Run].model_validate_json(response.content) assert len(runs_page.items) == 1 assert runs_page.items[0].id == run_id - # Get run response = app_client.get(f"/api/threads/{thread_id}/runs/{run_id}").raise_for_status() run = schema.Run.model_validate_json(response.content) assert run.id == run_id assert run.thread_id == thread_id assert run.parent_run_id is None - # Get run messages response = app_client.get(f"/api/threads/{thread_id}/runs/{run_id}/messages").raise_for_status() messages = pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_json(response.content) assert len(messages) == 2 # user message + assistant response - # Backward compat thread messages returns latest run snapshot response = app_client.get(f"/api/threads/{thread_id}/messages").raise_for_status() thread_messages = pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_json(response.content) assert len(thread_messages) == 2 @@ -375,7 +371,6 @@ def test_child_run(self, app_client): ) list(event_stream) - # Create child run with explicit parent_run_id event_stream = self.create_run( app_client, thread_id=thread_id, @@ -391,15 +386,13 @@ def test_child_run(self, app_client): run2 = schema.Run.model_validate_json(response.content) assert run2.parent_run_id == run1_id - # Run2 messages include run1 messages + new ones response = app_client.get(f"/api/threads/{thread_id}/runs/{run2_id}/messages").raise_for_status() messages = pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_json(response.content) assert len(messages) == 4 # run1 user+assistant, run2 user+assistant - # Run1 messages unchanged response = app_client.get(f"/api/threads/{thread_id}/runs/{run1_id}/messages").raise_for_status() messages1 = pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_json(response.content) - assert len(messages1) == 2 + assert len(messages1) == 2 # run1 user+assistant def test_invalid_parent_run_id(self, app_client): thread1_id = self.create_thread(app_client).id @@ -413,7 +406,6 @@ def test_invalid_parent_run_id(self, app_client): ) list(event_stream) - # Try to use run from thread1 as parent for thread2 with httpx_sse.connect_sse( app_client, "POST", From 9a4cee8cff1e22e40b8e9d29461a5e62a1008260 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 13 May 2026 23:56:50 +0200 Subject: [PATCH 6/7] cleanup event processor --- src/_ravnar/api/threads.py | 10 ++---- src/_ravnar/database.py | 2 +- src/_ravnar/events.py | 71 ++++++++++++++++---------------------- tests/api/test_threads.py | 2 +- 4 files changed, 34 insertions(+), 51 deletions(-) diff --git a/src/_ravnar/api/threads.py b/src/_ravnar/api/threads.py index b953dc8..56bf6db 100644 --- a/src/_ravnar/api/threads.py +++ b/src/_ravnar/api/threads.py @@ -138,8 +138,6 @@ async def create_run( ) input_content.metadata = WrappedMetadata(raw=input_content.metadata, file_id=file.id) - client_message_ids = {m.id for m in data.messages} - run_agent_input = ag_ui.core.RunAgentInput( thread_id=thread_id, run_id=data.id, @@ -152,12 +150,8 @@ async def create_run( ) async def callback(event_processor: EventProcessor) -> None: - # Client-supplied messages are part of this run's delta, not inherited - for msg_id in client_message_ids: - if msg_id in event_processor._parent_messages: - event_processor._messages[msg_id] = event_processor._parent_messages.pop(msg_id) - run = event_processor.extract() - await database.create_run(run=run) + run = event_processor.extract(include_input_message_ids={m.id for m in data.messages}) + await database.create_run(run) return await agent_handler.run(thread.agent_id, run_agent_input, callback=callback) diff --git a/src/_ravnar/database.py b/src/_ravnar/database.py index 2bb17dc..cae7c2c 100644 --- a/src/_ravnar/database.py +++ b/src/_ravnar/database.py @@ -211,7 +211,7 @@ async def rename_thread(self, *, user_id: str, id: str, name: str) -> orm.Thread thread.name = name return thread - async def create_run(self, *, run: orm.Run) -> None: + async def create_run(self, run: orm.Run) -> None: async with self._get_session() as session: session.add(run) diff --git a/src/_ravnar/events.py b/src/_ravnar/events.py index 8cde375..79c4667 100644 --- a/src/_ravnar/events.py +++ b/src/_ravnar/events.py @@ -4,7 +4,7 @@ import enum import time import uuid -from collections.abc import AsyncIterable, AsyncIterator +from collections.abc import AsyncIterable, AsyncIterator, Collection from datetime import UTC, datetime from typing import TYPE_CHECKING, Any, TypeVar, cast @@ -78,7 +78,7 @@ def __init__(self, *, run_agent_input: ag_ui.core.RunAgentInput): self._run_agent_input = run_agent_input self._state = run_agent_input.state - self._parent_messages = self._convert_messages(run_agent_input.messages) + self._input_messages = self._convert_input_messages(run_agent_input.messages) self._messages: dict[str, orm.Message] = {} self._progress = RunProgress.NOT_STARTED @@ -94,7 +94,7 @@ def __init__(self, *, run_agent_input: ag_ui.core.RunAgentInput): parent_run_id=run_agent_input.parent_run_id, ) - def _convert_messages(self, messages: list[ag_ui.core.Message]) -> dict[str, orm.Message]: + def _convert_input_messages(self, messages: list[ag_ui.core.Message]) -> dict[str, orm.Message]: message_uids = {m.id: uuid.uuid4() for m in messages} tool_calls = { @@ -150,8 +150,9 @@ def _convert_messages(self, messages: list[ag_ui.core.Message]) -> dict[str, orm data = m.model_dump() data["uid"] = message_uids[m.id] - # Parent messages are never persisted; run_id is a placeholder to satisfy SQLAlchemy - data["run_id"] = "" + # setting the run ID to the current run here avoids the need to set it later in case the message is either + # promoted to the current run because it was mutated or it is actually an input message of the current run + data["run_id"] = self._run_agent_input.run_id converted_messages[m.id] = cls(**data) return converted_messages @@ -335,7 +336,9 @@ def _process_event(self, event: ag_ui.core.Event) -> ag_ui.core.Event | None: # return ag_ui.core.MessagesSnapshotEvent(messages=messages, timestamp=event.timestamp, raw_event=event) # activity events case ag_ui.core.ActivitySnapshotEvent(): - if event.message_id in self._messages and not event.replace: + if not event.replace and ( + event.message_id in self._messages or event.message_id in self._input_messages + ): logger.info( "event", state="dropped", @@ -345,12 +348,9 @@ def _process_event(self, event: ag_ui.core.Event) -> ag_ui.core.Event | None: ) return None - if event.message_id in self._parent_messages and not event.replace: - return None - self._messages[event.message_id] = orm.ActivityMessage( id=event.message_id, - run_id="", + run_id=self._run_agent_input.run_id, created_at=parse_timestamp(event.timestamp), content=event.content, activity_type=event.activity_type, @@ -358,25 +358,15 @@ def _process_event(self, event: ag_ui.core.Event) -> ag_ui.core.Event | None: case ag_ui.core.ActivityDeltaEvent(): message = self._messages.get(event.message_id) if message is None: - message = self._parent_messages.get(event.message_id) - assert isinstance(message, orm.ActivityMessage) - if message is None: - logger.error( - "event", - state="dropped", - reason="message does not exist", - message_id=event.message_id, - ) - return None - # Copy inherited message to delta bucket - self._messages[event.message_id] = orm.ActivityMessage( - id=message.id, - run_id="", - created_at=message.created_at, - content=message.content, - activity_type=message.activity_type, + message = self._input_messages.pop(event.message_id, None) + if message is None: + logger.error( + "event", + state="dropped", + reason="message does not exist", + message_id=event.message_id, ) - message = self._messages[event.message_id] + return None if not isinstance(message, orm.ActivityMessage): logger.error( @@ -496,19 +486,16 @@ def _apply_jsonpatch(document: dict[str, Any], patches: list[Any], *, logger: Fi ) return None - def extract(self) -> orm.Run: - delta_messages = list(self._messages.values()) + self._extract_messages() - for m in delta_messages: - m.run_id = self._run_agent_input.run_id + def extract(self, *, include_input_message_ids: Collection[str]) -> orm.Run: return orm.Run( id=self._run_agent_input.run_id, thread_id=self._run_agent_input.thread_id, parent_run_id=self._run_agent_input.parent_run_id, state=self._state, - messages=delta_messages, + messages=self._extract_messages(include_input_message_ids), ) - def _extract_messages(self) -> list[orm.Message]: + def _extract_messages(self, include_input_message_ids: Collection[str]) -> list[orm.Message]: grouped_tool_calls: dict[str, list[ToolCallData]] = {} for tcd in self._tool_call_data.values(): if not tcd.finished: @@ -533,7 +520,7 @@ def _extract_messages(self) -> list[orm.Message]: assistant_messages[tmd.message_id] = orm.AssistantMessage( uid=uuid.uuid4(), - run_id="", + run_id=self._run_agent_input.run_id, id=tmd.message_id, created_at=tmd.created_at, content="".join(tmd.content_deltas) or None, @@ -544,7 +531,7 @@ def _extract_messages(self) -> list[orm.Message]: if parent_message_id not in assistant_messages: assistant_messages[parent_message_id] = orm.AssistantMessage( uid=uuid.uuid4(), - run_id="", + run_id=self._run_agent_input.run_id, id=parent_message_id, created_at=min(tcd.created_at for tcd in tcds), content=None, @@ -566,9 +553,11 @@ def _extract_messages(self) -> list[orm.Message]: tool_calls[tool_call.id] = tool_call parent_msg.tool_calls.append(tool_call) - messages: list[orm.Message] = [] - - messages.extend(assistant_messages.values()) + messages: list[orm.Message] = [ + *(m for id, m in self._input_messages.items() if id in include_input_message_ids), + *self._messages.values(), + *assistant_messages.values(), + ] for rd in self._reasoning_data.values(): if not rd.finished: @@ -578,7 +567,7 @@ def _extract_messages(self) -> list[orm.Message]: messages.append( orm.ReasoningMessage( uid=uuid.uuid4(), - run_id="", + run_id=self._run_agent_input.run_id, id=rd.message_id, created_at=rd.created_at, content="".join(rd.content_deltas), @@ -601,7 +590,7 @@ def _extract_messages(self) -> list[orm.Message]: messages.append( orm.ToolMessage( uid=msg_uid, - run_id="", + run_id=self._run_agent_input.run_id, id=trd.message_id, created_at=trd.created_at, content=trd.content, diff --git a/tests/api/test_threads.py b/tests/api/test_threads.py index ef1d2c5..538fa07 100644 --- a/tests/api/test_threads.py +++ b/tests/api/test_threads.py @@ -415,4 +415,4 @@ def test_invalid_parent_run_id(self, app_client): parent_run_id=run_id, ).model_dump(mode="json", by_alias=True), ) as event_source: - assert event_source.response.status_code == 404 + assert event_source.response.status_code == status.HTTP_404_NOT_FOUND From 53459ec6d3425c3c474ceaa120ffcfdd3cf51434 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 14 May 2026 00:02:11 +0200 Subject: [PATCH 7/7] fix test --- tests/test_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_events.py b/tests/test_events.py index 0e5794b..bc16e4d 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -56,7 +56,7 @@ async def test_event_processing(self, test_case: test_events_cases.EventProcessi self.assert_equal(actual_event_stream, test_case.expected_event_stream) - actual_run = event_processor.extract() + actual_run = event_processor.extract(include_input_message_ids={m.id for m in test_case.messages}) actual_messages = pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_python( actual_run.messages, from_attributes=True )