Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions src/bub/app/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from bub.core import AgentLoop, InputRouter, LoopResult, ModelRunner
from bub.integrations.republic_client import build_llm, build_tape_store, read_workspace_agents_prompt
from bub.skills.loader import SkillMetadata, discover_skills
from bub.tape import TapeService
from bub.tape import TapeService, default_tape_context
from bub.tools import ProgressiveToolView, ToolRegistry
from bub.tools.builtin import register_builtin_tools

Expand All @@ -47,7 +47,9 @@ class SessionRuntime:
tool_view: ProgressiveToolView

async def handle_input(self, text: str) -> LoopResult:
return await self.loop.handle_input(text)
with self.tape.fork_tape() as tape:
tape.context = default_tape_context({"session_id": self.session_id})
return await self.loop.handle_input(text)

def reset_context(self) -> None:
"""Clear volatile in-memory context while keeping the same session identity."""
Expand Down Expand Up @@ -108,13 +110,7 @@ def get_session(self, session_id: str) -> SessionRuntime:
tape.ensure_bootstrap_anchor()

registry = ToolRegistry(self._allowed_tools)
register_builtin_tools(
registry,
workspace=self.workspace,
tape=tape,
runtime=self,
session_id=session_id,
)
register_builtin_tools(registry, workspace=self.workspace, tape=tape, runtime=self)
tool_view = ProgressiveToolView(registry)
router = InputRouter(registry, tool_view, tape, self.workspace)
runner = ModelRunner(
Expand Down
49 changes: 24 additions & 25 deletions src/bub/core/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,36 +29,35 @@ def __init__(self, *, router: InputRouter, model_runner: ModelRunner, tape: Tape
self._tape = tape

async def handle_input(self, raw: str) -> LoopResult:
with self._tape.fork_tape():
route = await self._router.route_user(raw)
if route.exit_requested:
return LoopResult(
immediate_output=route.immediate_output,
assistant_output="",
exit_requested=True,
steps=0,
error=None,
)

if not route.enter_model:
return LoopResult(
immediate_output=route.immediate_output,
assistant_output="",
exit_requested=False,
steps=0,
error=None,
)
route = await self._router.route_user(raw)
if route.exit_requested:
return LoopResult(
immediate_output=route.immediate_output,
assistant_output="",
exit_requested=True,
steps=0,
error=None,
)

model_result = await self._model_runner.run(route.model_prompt)
self._record_result(model_result)
if not route.enter_model:
return LoopResult(
immediate_output=route.immediate_output,
assistant_output=model_result.visible_text,
exit_requested=model_result.exit_requested,
steps=model_result.steps,
error=model_result.error,
assistant_output="",
exit_requested=False,
steps=0,
error=None,
)

model_result = await self._model_runner.run(route.model_prompt)
self._record_result(model_result)
return LoopResult(
immediate_output=route.immediate_output,
assistant_output=model_result.visible_text,
exit_requested=model_result.exit_requested,
steps=model_result.steps,
error=model_result.error,
)

def _record_result(self, result: ModelTurnResult) -> None:
self._tape.append_event(
"loop.result",
Expand Down
3 changes: 2 additions & 1 deletion src/bub/tape/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Tape helpers."""

from bub.tape.anchors import AnchorSummary
from bub.tape.context import default_tape_context
from bub.tape.service import TapeService
from bub.tape.store import FileTapeStore

__all__ = ["AnchorSummary", "FileTapeStore", "TapeService"]
__all__ = ["AnchorSummary", "FileTapeStore", "TapeService", "default_tape_context"]
8 changes: 4 additions & 4 deletions src/bub/tape/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@
from __future__ import annotations

import json
from collections.abc import Sequence
from collections.abc import Iterable
from typing import Any

from republic import TapeContext, TapeEntry


def default_tape_context() -> TapeContext:
def default_tape_context(state: dict[str, Any] | None = None) -> TapeContext:
"""Return the default context selection for Bub."""

return TapeContext(select=_select_messages)
return TapeContext(select=_select_messages, state=state or {})


def _select_messages(entries: Sequence[TapeEntry], _context: TapeContext) -> list[dict[str, Any]]:
def _select_messages(entries: Iterable[TapeEntry], _context: TapeContext) -> list[dict[str, Any]]:
messages: list[dict[str, Any]] = []
pending_calls: list[dict[str, Any]] = []

Expand Down
23 changes: 10 additions & 13 deletions src/bub/tape/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,11 @@ def fork_tape(self) -> Generator[Tape, None, None]:
logger.info("Merged forked tape '{}' back into '{}'", fork_name, self._tape.name)

def ensure_bootstrap_anchor(self) -> None:
anchors = [entry for entry in self.read_entries() if entry.kind == "anchor"]
anchors = list(self._tape.query.kinds("anchor").all())
if anchors:
return
self.handoff("session/start", state={"owner": "human"})

def read_entries(self) -> list[TapeEntry]:
return cast(list[TapeEntry], self.tape.read_entries())

def handoff(self, name: str, *, state: dict[str, Any] | None = None) -> list[TapeEntry]:
return cast(list[TapeEntry], self.tape.handoff(name, state=state))

Expand All @@ -88,7 +85,7 @@ def append_system(self, content: str) -> None:
self.tape.append(TapeEntry.system(content))

def info(self) -> TapeInfo:
entries = self._tape.read_entries()
entries = list(self._tape.query.all())
anchors = [entry for entry in entries if entry.kind == "anchor"]
last_anchor = anchors[-1].payload.get("name") if anchors else None
if last_anchor is not None:
Expand All @@ -115,7 +112,7 @@ def reset(self, *, archive: bool = False) -> str:
return f"Archived: {archive_path}" if archive_path else "ok"

def anchors(self, *, limit: int = 20) -> list[AnchorSummary]:
entries = [entry for entry in self._tape.read_entries() if entry.kind == "anchor"]
entries = list(self._tape.query.kinds("anchor").all())
results: list[AnchorSummary] = []
for entry in entries[-limit:]:
name = str(entry.payload.get("name", "-"))
Expand All @@ -125,22 +122,22 @@ def anchors(self, *, limit: int = 20) -> list[AnchorSummary]:
return results

def between_anchors(self, start: str, end: str, *, kinds: tuple[str, ...] = ()) -> list[TapeEntry]:
query = self.tape.query().between_anchors(start, end)
query = self.tape.query.between_anchors(start, end)
if kinds:
query = query.kinds(*kinds)
return cast(list[TapeEntry], query.all())
return list(query.all())

def after_anchor(self, anchor: str, *, kinds: tuple[str, ...] = ()) -> list[TapeEntry]:
query = self.tape.query().after_anchor(anchor)
query = self.tape.query.after_anchor(anchor)
if kinds:
query = query.kinds(*kinds)
return cast(list[TapeEntry], query.all())
return list(query.all())

def from_last_anchor(self, *, kinds: tuple[str, ...] = ()) -> list[TapeEntry]:
query = self.tape.query().last_anchor()
query = self.tape.query.last_anchor()
if kinds:
query = query.kinds(*kinds)
return cast(list[TapeEntry], query.all())
return list(query.all())

def search(self, query: str, *, limit: int = 20, all_tapes: bool = False) -> list[TapeEntry]:
normalized_query = query.strip().lower()
Expand All @@ -153,7 +150,7 @@ def search(self, query: str, *, limit: int = 20, all_tapes: bool = False) -> lis

for tape in tapes:
count = 0
for entry in reversed(tape.read_entries()):
for entry in reversed(list(tape.query.all())):
payload_text = json.dumps(entry.payload, ensure_ascii=False)
entry_meta = getattr(entry, "meta", {})
meta_text = json.dumps(entry_meta, ensure_ascii=False)
Expand Down
4 changes: 2 additions & 2 deletions src/bub/tape/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from typing import cast
from urllib.parse import quote, unquote

from republic.tape import TapeEntry
from republic.tape import InMemoryQueryMixin, TapeEntry

TAPE_FILE_SUFFIX = ".jsonl"

Expand Down Expand Up @@ -151,7 +151,7 @@ def archive(self) -> Path | None:
return archive_file


class FileTapeStore:
class FileTapeStore(InMemoryQueryMixin):
"""Append-only JSONL tape store compatible with Republic TapeStore protocol."""

def __init__(self, home: Path, workspace_path: Path) -> None:
Expand Down
25 changes: 14 additions & 11 deletions src/bub/tools/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,15 +171,14 @@ def register_builtin_tools(
workspace: Path,
tape: TapeService,
runtime: AppRuntime,
session_id: str,
) -> None:
"""Register built-in tools and internal commands."""
from bub.tools.schedule import run_scheduled_reminder

register = registry.register

@register(name="bash", short_description="Run shell command", model=BashInput)
async def run_bash(params: BashInput) -> str:
@register(name="bash", short_description="Run shell command", model=BashInput, context=True)
async def run_bash(params: BashInput, context: ToolContext) -> str:
"""Execute bash in workspace. Non-zero exit raises an error.
IMPORTANT: please DO NOT use sleep to delay execution, use schedule.add tool instead.
"""
Expand All @@ -191,7 +190,7 @@ async def run_bash(params: BashInput) -> str:
workspace_env = workspace / ".env"
if workspace_env.is_file():
env.update((k, v) for k, v in dotenv.dotenv_values(workspace_env).items() if v is not None)
env[SESSION_ID_ENV_VAR] = session_id
env[SESSION_ID_ENV_VAR] = context.state.get("session_id", "")
completed = await asyncio.create_subprocess_exec(
executable,
"-lc",
Expand Down Expand Up @@ -299,7 +298,11 @@ def schedule_add(params: ScheduleAddInput, context: ToolContext) -> str:
run_scheduled_reminder,
trigger=trigger,
id=job_id,
kwargs={"message": params.message, "session_id": session_id, "workspace": str(runtime.workspace)},
kwargs={
"message": params.message,
"session_id": context.state.get("session_id", ""),
"workspace": str(runtime.workspace),
},
coalesce=True,
max_instances=1,
)
Expand All @@ -320,8 +323,8 @@ def schedule_remove(params: ScheduleRemoveInput) -> str:
raise RuntimeError(f"job not found: {params.job_id}") from exc
return f"removed: {params.job_id}"

@register(name="schedule.list", short_description="List scheduled jobs", model=EmptyInput)
def schedule_list(_params: EmptyInput) -> str:
@register(name="schedule.list", short_description="List scheduled jobs", model=EmptyInput, context=True)
def schedule_list(_params: EmptyInput, context: ToolContext) -> str:
"""List scheduled jobs for current workspace."""
jobs = runtime.scheduler.get_jobs()
rows: list[str] = []
Expand All @@ -331,7 +334,7 @@ def schedule_list(_params: EmptyInput) -> str:
next_run = job.next_run_time.isoformat()
message = str(job.kwargs.get("message", ""))
job_session = job.kwargs.get("session_id")
if job_session and job_session != session_id:
if job_session and job_session != context.state.get("session_id", ""):
continue
rows.append(f"{job.id} next={next_run} msg={message}")

Expand Down Expand Up @@ -477,11 +480,11 @@ def tape_search(params: TapeSearchInput) -> str:
return "(no matches)"
return "\n".join(f"#{entry.id} {entry.kind} {entry.payload}" for entry in entries)

@register(name="tape.reset", short_description="Reset tape", model=TapeResetInput)
def tape_reset(params: TapeResetInput) -> str:
@register(name="tape.reset", short_description="Reset tape", model=TapeResetInput, context=True)
def tape_reset(params: TapeResetInput, context: ToolContext) -> str:
"""Reset current tape; can archive before clearing."""
result = tape.reset(archive=params.archive)
runtime.reset_session_context(session_id)
runtime.reset_session_context(context.state.get("session_id", ""))
return result

@register(name="skills.list", short_description="List skills", model=EmptyInput)
Expand Down
11 changes: 8 additions & 3 deletions tests/test_tape_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ class FakeEntry:


class FakeTape:
class _Query:
def __init__(self, tape: "FakeTape") -> None:
self._tape = tape

def all(self) -> list[FakeEntry]:
return list(self._tape.entries)

def __init__(self) -> None:
self.name = "fake"
self.entries: list[FakeEntry] = [
Expand All @@ -23,9 +30,7 @@ def __init__(self) -> None:
)
]
self.reset_calls = 0

def read_entries(self) -> list[FakeEntry]:
return list(self.entries)
self.query = self._Query(self)

def handoff(self, name: str, state: dict[str, object] | None = None) -> list[FakeEntry]:
entry = FakeEntry(
Expand Down
17 changes: 11 additions & 6 deletions tests/test_tools_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import pytest
from apscheduler.schedulers.background import BackgroundScheduler
from republic import ToolContext

from bub.config.settings import Settings
from bub.tools.builtin import register_builtin_tools
Expand Down Expand Up @@ -67,15 +68,21 @@ def _build_registry(workspace: Path, settings: Settings, scheduler: BackgroundSc
workspace=workspace,
tape=_DummyTape(), # type: ignore[arg-type]
runtime=runtime, # type: ignore[arg-type]
session_id="cli:test",
)
return registry


def _execute_tool(registry: ToolRegistry, name: str, *, kwargs: dict[str, Any]) -> Any:
def _execute_tool(
registry: ToolRegistry,
name: str,
*,
kwargs: dict[str, Any],
session_id: str = "cli:test",
) -> Any:
descriptor = registry.get(name)
context = ToolContext(tape="test", run_id="test-run", state={"session_id": session_id})
if descriptor is not None and descriptor.tool.context:
result = descriptor.tool.run(context=None, **kwargs)
result = descriptor.tool.run(context=context, **kwargs)
else:
result = registry.execute(name, kwargs=kwargs)
if inspect.isawaitable(result):
Expand Down Expand Up @@ -356,7 +363,6 @@ def discover_skills(self) -> list[_Skill]:
workspace=tmp_path,
tape=_DummyTape(), # type: ignore[arg-type]
runtime=runtime, # type: ignore[arg-type]
session_id="cli:test",
)

assert _execute_tool(registry, "skills.list", kwargs={}) == "alpha: first"
Expand Down Expand Up @@ -428,9 +434,8 @@ def test_tape_reset_also_clears_session_runtime_context(tmp_path: Path, schedule
workspace=tmp_path,
tape=_DummyTape(), # type: ignore[arg-type]
runtime=runtime, # type: ignore[arg-type]
session_id="telegram:123",
)

result = _execute_tool(registry, "tape.reset", kwargs={"archive": True})
result = _execute_tool(registry, "tape.reset", kwargs={"archive": True}, session_id="telegram:123")
assert result == "reset"
assert runtime.reset_calls == ["telegram:123"]
6 changes: 3 additions & 3 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading