Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 48 additions & 19 deletions src/_ravnar/api/threads.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from __future__ import annotations

import base64
import uuid
from collections.abc import Callable
from typing import TYPE_CHECKING, Annotated, Any

import ag_ui.core
import ag_ui.encoder
import fastsse
import pydantic
from fastapi import Depends, HTTPException, Path, Query, status
Expand All @@ -15,7 +13,7 @@
from _ravnar import schema
from _ravnar.file_storage import FileHandler, WrappedMetadata
from _ravnar.observability import traced
from _ravnar.utils import as_awaitable, now
from _ravnar.utils import as_awaitable

tracer = trace.get_tracer(__name__)

Expand All @@ -26,8 +24,10 @@
from . import AgentHandler

ThreadsSortBy = str
RunsSortBy = str
else:
ThreadsSortBy = schema.create_str_literal("created_at", "updated_at", default="created_at")
ThreadsSortBy = schema.create_str_literal("created_at", default="created_at")
RunsSortBy = schema.create_str_literal("created_at", default="created_at")


def make_router(
Expand Down Expand Up @@ -70,46 +70,75 @@ async def get_thread(

@router.get("/{threadId}/messages")
async def get_thread_messages(
id: Annotated[str, Path(alias="threadId")],
thread_id: Annotated[str, Path(alias="threadId")],
user: schema.User = Depends(authenticated_user), # noqa: B008
) -> list[schema.AugmentedMessage]:
thread = await database.get_thread(user_id=user.id, id=id, with_messages=True)
return pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_python(
thread.messages, from_attributes=True
_, _, messages = await database.get_thread_history(user_id=user.id, thread_id=thread_id, run_id=None)
return pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_python(messages, from_attributes=True)

@router.get("/{threadId}/runs")
async def get_runs(
    *,
    user: schema.User = Depends(authenticated_user),  # noqa: B008
    thread_id: Annotated[str, Path(alias="threadId")],
    pagination: Annotated[schema.Pagination[RunsSortBy], Query()],
) -> schema.Page[schema.Run]:
    # List the runs of one thread, one page at a time.
    # Documented with comments instead of a docstring so the generated OpenAPI
    # description of the endpoint stays untouched.
    runs_page = await database.get_runs(user_id=user.id, thread_id=thread_id, pagination=pagination)
    # The database hands back ORM objects; validate them into the API schema by attribute access.
    return schema.Page[schema.Run].model_validate(runs_page, from_attributes=True)

@router.sse("/{threadId}/run", methods=["POST"], response_model=schema.Event, tags=["Runs"])
@router.get("/{threadId}/runs/{runId}")
async def get_run(
    *,
    user: schema.User = Depends(authenticated_user),  # noqa: B008
    thread_id: Annotated[str, Path(alias="threadId")],
    run_id: Annotated[str, Path(alias="runId")],
) -> schema.Run:
    # Return a single run of a thread owned by the authenticated user.
    # Raises a 404 HTTPException when the run does not exist for this user or
    # does not belong to the thread named in the URL.
    # Documented with comments instead of a docstring so the generated OpenAPI
    # description of the endpoint stays untouched.
    run = await database.get_run(id=run_id, user_id=user.id)
    # The database lookup is keyed on run id and owning user only. Without this
    # check a run of a *different* thread of the same user would be served
    # under this thread's URL, so treat the mismatch as "not found".
    if run.thread_id != thread_id:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Run not found")
    return schema.Run.model_validate(run, from_attributes=True)

@router.get("/{threadId}/runs/{runId}/messages")
async def get_run_messages(
    *,
    user: schema.User = Depends(authenticated_user),  # noqa: B008
    thread_id: Annotated[str, Path(alias="threadId")],
    run_id: Annotated[str, Path(alias="runId")],
) -> list[schema.AugmentedMessage]:
    # Return the message history as seen from the given run.
    # Documented with comments instead of a docstring so the generated OpenAPI
    # description of the endpoint stays untouched.
    # The history lookup yields (thread, run, messages); only the messages are needed here.
    _thread, _run, messages = await database.get_thread_history(
        user_id=user.id, thread_id=thread_id, run_id=run_id
    )
    message_list_adapter = pydantic.TypeAdapter(list[schema.AugmentedMessage])
    return message_list_adapter.validate_python(messages, from_attributes=True)

@router.sse("/{threadId}/runs", methods=["POST"], response_model=schema.Event, tags=["Runs"])
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only intended backward-compatibility (BC) break:

  • old: POST {threadId}/run
  • new: POST {threadId}/runs

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the impetus for the change? IMHO a POST endpoint should be a verb, not a noun.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Runs are now a proper object, so this becomes a regular CRUD endpoint alongside:

  • GET /api/threads/{threadId}/runs
  • GET /api/threads/{threadId}/runs/{runId}

async def create_run(
*,
user: schema.User = Depends(authenticated_user), # noqa: B008
thread_id: Annotated[str, Path(alias="threadId")],
data: schema.CreateRunData,
) -> fastsse.Response:
thread = await database.get_thread(user_id=user.id, id=thread_id, with_messages=True)
thread, parent_run, parent_messages = await database.get_thread_history(
user_id=user.id, thread_id=thread_id, run_id=data.parent_run_id
)

messages = pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_python(
thread.messages, from_attributes=True
parent_messages, from_attributes=True
)
messages.extend(data.messages)

await hydrate_files(messages, user=user, file_handler=file_handler)

run_agent_input = ag_ui.core.RunAgentInput(
thread_id=thread.id,
run_id=str(uuid.uuid4()),
parent_run_id=None,
state=thread.state,
thread_id=thread_id,
run_id=data.id,
parent_run_id=data.parent_run_id,
state=parent_run.state if parent_run is not None else None,
messages=[pydantic.TypeAdapter(ag_ui.core.Message).validate_python(m.model_dump()) for m in messages],
tools=data.tools,
context=data.context,
forwarded_props=data.forwarded_props,
)

async def callback(event_processor: EventProcessor) -> None:
with tracer.start_as_current_span("persist_run"):
thread.state, thread.messages = event_processor.extract()
thread.updated_at = now()
await database.update_thread(thread)
run = event_processor.extract(include_input_message_ids={m.id for m in data.messages})
await database.create_run(run)

return await agent_handler.run(thread.agent_id, run_agent_input, callback=callback)

Expand Down
8 changes: 1 addition & 7 deletions src/_ravnar/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,7 @@ async def run(
self.assert_available(agent_id)
agent = self._agents[agent_id]

event_processor = EventProcessor(
thread_id=run_agent_input.thread_id,
run_id=run_agent_input.run_id,
parent_run_id=run_agent_input.parent_run_id,
state=run_agent_input.state,
messages=run_agent_input.messages,
)
event_processor = EventProcessor(run_agent_input=run_agent_input)

span = tracer.start_span("AgentHandler.run")
span.set_attribute("agent_id", agent_id)
Expand Down
127 changes: 106 additions & 21 deletions src/_ravnar/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@

from fastapi import HTTPException, status
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
from sqlalchemy import Engine, Select, asc, create_engine, desc, func, inspect, select
from sqlalchemy import Engine, Select, asc, create_engine, desc, func, inspect, literal_column, select
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import Session, selectinload, sessionmaker
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm.interfaces import ORMOption
from starlette.concurrency import run_in_threadpool
from typing_extensions import TypedDict
Expand Down Expand Up @@ -166,16 +166,13 @@ async def create_thread(self, *, user_id: str, id: str, name: str | None, agent_
if thread is not None:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Thread exists")

created_at = updated_at = now()
thread = orm.Thread(
id=id,
user_id=user_id,
agent_id=agent_id,
name=name,
created_at=created_at,
updated_at=updated_at,
state=None,
messages=[],
created_at=now(),
runs=[],
)
session.add(thread)
return thread
Expand All @@ -202,38 +199,126 @@ async def get_threads(self, *, user_id: str, pagination: schema.Pagination) -> o
async with self._get_session() as session:
return await self._get_threads(session, user_id=user_id, pagination=pagination)

async def _get_thread(self, session: AsyncSession, *, user_id: str, id: str, with_messages: bool) -> orm.Thread:
async def _get_thread(self, session: AsyncSession, *, user_id: str, id: str) -> orm.Thread:
query = select(orm.Thread).where((orm.Thread.id == id) & (orm.Thread.user_id == user_id))
if with_messages:
query = query.options(selectinload(orm.Thread.messages))
result = await session.execute(query)
thread = result.scalar_one_or_none()
if thread is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Thread not found")
return thread

@traced
async def get_thread(self, *, user_id: str, id: str, with_messages: bool = False) -> orm.Thread:
async def get_thread(self, *, user_id: str, id: str) -> orm.Thread:
async with self._get_session() as session:
return await self._get_thread(session, user_id=user_id, id=id, with_messages=with_messages)

@traced
async def append_messages_to_thread(self, *, user_id: str, id: str, messages: list[orm.Message]) -> None:
async with self._get_session() as session:
thread = await self._get_thread(session, user_id=user_id, id=id, with_messages=False)
thread.messages.extend(messages)
return await self._get_thread(session, user_id=user_id, id=id)

@traced
async def rename_thread(self, *, user_id: str, id: str, name: str) -> orm.Thread:
async with self._get_session() as session:
thread = await self._get_thread(session, user_id=user_id, id=id, with_messages=False)
thread = await self._get_thread(session, user_id=user_id, id=id)
thread.name = name
return thread

@traced
async def update_thread(self, thread: orm.Thread) -> None:
async def create_run(self, run: orm.Run) -> None:
async with self._get_session() as session:
session.add(run)

async def _get_run(self, session: AsyncSession, *, id: str, user_id: str) -> orm.Run:
    """Load one run by id, restricted to threads owned by ``user_id``.

    Raises:
        HTTPException: 404 ("Run not found") when no matching run exists.
    """
    # Ownership is recorded on the thread, so join through it to scope the run to the user.
    owned_run_query = (
        select(orm.Run)
        .join(orm.Thread, orm.Run.thread_id == orm.Thread.id)
        .where(orm.Run.id == id, orm.Thread.user_id == user_id)
    )
    found = (await session.execute(owned_run_query)).scalar_one_or_none()
    if found is not None:
        return found
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Run not found")

@traced
async def get_run(self, *, id: str, user_id: str) -> orm.Run:
    """Fetch a single run of ``user_id`` inside a fresh session.

    Delegates to ``_get_run``, which raises a 404 ``HTTPException`` when the
    run cannot be found for this user.
    """
    async with self._get_session() as db_session:
        run = await self._get_run(db_session, id=id, user_id=user_id)
        return run

@traced
async def get_runs(
    self,
    *,
    user_id: str,
    thread_id: str,
    pagination: schema.Pagination | None = None,
) -> orm.Page[orm.Run]:
    """Return a page of the runs of ``thread_id`` that belong to ``user_id``.

    The paging itself is delegated to ``_get_page``; this method only
    supplies the filter that narrows the selection.
    """

    def only_runs_of_owned_thread(query: Select) -> Select:
        # Join through the thread so that runs of other users' threads are never matched.
        joined = query.join(orm.Thread, orm.Run.thread_id == orm.Thread.id)
        return joined.where(orm.Run.thread_id == thread_id, orm.Thread.user_id == user_id)

    async with self._get_session() as session:
        return await self._get_page(
            session,
            orm_type=orm.Run,
            select_qualifier=only_runs_of_owned_thread,
            pagination=pagination,
        )

async def _get_thread_messages(self, session: AsyncSession, *, run_id: str) -> list[orm.Message]:
    """Collect the messages of a run together with those of its ancestor runs.

    Starting at ``run_id`` the ``parent_run_id`` links are followed upwards.
    When the same message ``id`` occurs in several runs of that chain, only
    the copy from the run closest to ``run_id`` is kept. The result is
    ordered by ``created_at`` and then by message ``id``.
    """
    # Anchor member of the recursive CTE: the requested run itself, at depth 0.
    chain_start = (
        select(
            orm.Run.id.label("run_id"),
            orm.Run.parent_run_id.label("parent_run_id"),
            literal_column("0").label("depth"),
        )
        .where(orm.Run.id == run_id)
        .cte(name="run_chain", recursive=True)
    )
    # Recursive member: step from every run already in the chain to its parent,
    # one level deeper each time.
    parent_step = (
        select(orm.Run.id, orm.Run.parent_run_id, (chain_start.c.depth + 1).label("depth"))
        .select_from(orm.Run)
        .join(chain_start, orm.Run.id == chain_start.c.parent_run_id)
    )
    run_chain = chain_start.union_all(parent_step)

    # Number the copies of each message id, nearest run (smallest depth) first.
    nearest_first = (
        func.row_number()
        .over(
            partition_by=orm.Message.id,
            order_by=run_chain.c.depth.asc(),
        )
        .label("rn")
    )
    ranked = (
        select(orm.Message.uid, nearest_first)
        .join(run_chain, orm.Message.run_id == run_chain.c.run_id)
        .subquery()
    )

    # Keep only the top-ranked copy of every message and return full ORM rows.
    query = (
        select(orm.Message)
        .join(ranked, orm.Message.uid == ranked.c.uid)
        .where(ranked.c.rn == 1)
        .order_by(orm.Message.created_at.asc(), orm.Message.id.asc())
    )

    result = await session.execute(query)
    message_rows = result.unique().scalars()
    return list(message_rows.all())

@traced
async def get_thread_history(
self, *, user_id: str, thread_id: str, run_id: str | None
) -> tuple[orm.Thread, orm.Run | None, list[orm.Message]]:
async with self._get_session() as session:
await session.merge(thread)
thread = await self._get_thread(session, user_id=user_id, id=thread_id)

if run_id is None:
if not thread.runs:
return thread, None, []
run = thread.runs[-1]
else:
try:
run = next(r for r in thread.runs if r.id == run_id)
except StopIteration:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Parent run not found") from None

messages = await self._get_thread_messages(session, run_id=run.id)
return thread, run, messages

@traced
async def delete_threads(self, *, user_id: str, ids: list[str]) -> None:
Expand Down
Loading
Loading