diff --git a/src/bub/channels/runner.py b/src/bub/channels/runner.py index 9b8a810..45c50f7 100644 --- a/src/bub/channels/runner.py +++ b/src/bub/channels/runner.py @@ -48,7 +48,6 @@ async def process_message(self, channel: BaseChannel, message: Any) -> None: self._last_mentioned_at = None logger.info("session.receive ignored session_id={} message={}", self.session_id, prompt) return - self._prompts.append(prompt) if prompt.startswith(","): logger.info("session.receive.command session_id={} message={}", self.session_id, prompt) try: @@ -56,7 +55,9 @@ async def process_message(self, channel: BaseChannel, message: Any) -> None: await channel.process_output(self.session_id, result) except Exception: logger.exception("session.run.error session_id={}", self.session_id) - elif is_mentioned: + return + self._prompts.append(prompt) + if is_mentioned: # wait at most 1 second to reply to mentioned messages. self._last_mentioned_at = now logger.info("session.receive.mentioned session_id={} message={}", self.session_id, prompt) diff --git a/tests/test_session_runner.py b/tests/test_session_runner.py new file mode 100644 index 0000000..c166646 --- /dev/null +++ b/tests/test_session_runner.py @@ -0,0 +1,38 @@ +import pytest + +from bub.channels.runner import SessionRunner + + +class DummyChannel: + output_channel = "dummy" + + def __init__(self) -> None: + self.prompts: list[str] = [] + self.outputs: list[tuple[str, object]] = [] + + def is_mentioned(self, _message: object) -> bool: + return True + + async def get_session_prompt(self, message: object) -> tuple[str, str]: + assert isinstance(message, str) + return "session", message + + async def run_prompt(self, session_id: str, prompt: str) -> str: + self.prompts.append(prompt) + return f"result:{session_id}" + + async def process_output(self, session_id: str, output: object) -> None: + self.outputs.append((session_id, output)) + + +@pytest.mark.asyncio +async def test_command_prompt_is_not_buffered() -> None: + channel = DummyChannel() + runner = SessionRunner("session", debounce_seconds=1, message_delay_seconds=1, active_time_window_seconds=60) + + await runner.process_message(channel, ",help") + + assert channel.prompts == [",help"] + assert channel.outputs == [("session", "result:session")] + assert runner._prompts == [] + assert runner._running_task is None