-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
94 lines (82 loc) · 2.9 KB
/
main.py
File metadata and controls
94 lines (82 loc) · 2.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import asyncio
import contextlib
import sys
import threading
import time
from itertools import cycle
from loguru import logger
from core.config import configure_logging, load_app_config
from core.conversation_manager import ConversationManager
def run_spinner(message: str, stop_event: threading.Event, *, interval: float = 0.1) -> None:
    """Draw an animated spinner with elapsed time on stderr until ``stop_event`` is set.

    Each frame rewrites the current terminal line (leading carriage return) with
    ``"<message> <glyph> <elapsed>s"``. Once ``stop_event`` is set, the last drawn
    line is blanked and the cursor is returned to column 0. If the event is already
    set on entry, nothing is drawn.

    Args:
        message: Text shown before the spinner glyph.
        stop_event: Set (typically from another thread) to stop the spinner.
        interval: Seconds between frames; keyword-only, defaults to 0.1.
    """
    start = time.monotonic()
    # Length of the most recently drawn line, so the final clear blanks exactly that.
    last_len = 0
    for ch in cycle("-\\|/"):
        if stop_event.is_set():
            break
        elapsed = time.monotonic() - start
        line = f"{message} {ch} {elapsed:0.1f}s"
        sys.stderr.write(f"\r{line}")
        sys.stderr.flush()
        last_len = len(line)
        # Event.wait (not time.sleep) returns as soon as a stop is requested,
        # so callers joining this thread are not delayed by up to one interval.
        stop_event.wait(interval)
    # The line grows as the elapsed counter gains digits, so a fixed clear width
    # could leave stale characters behind; blank precisely what was last written.
    sys.stderr.write(f"\r{' ' * last_len}\r")
    sys.stderr.flush()
async def main() -> None:
    """Run the interactive chat loop.

    Loads and applies the app configuration, constructs a ``ConversationManager``
    behind a "Loading model" spinner, and prints its ``first_message`` if it has
    one. It then repeatedly reads a line from the user and passes it to
    ``ask_question``, streaming chunks to stdout behind a "Thinking" spinner.

    The loop ends cleanly on EOF (Ctrl+D or closed stdin). A ``KeyboardInterrupt``
    raised inside the loop prints a newline and is re-raised to the caller.
    """
    app_config = load_app_config()
    configure_logging(app_config)

    def start_spinner(message):
        """Run ``run_spinner(message, ...)`` in a daemon thread.

        Returns an idempotent ``stop()`` function that signals the spinner and
        waits for it to finish clearing its line.
        """
        stop_event = threading.Event()
        thread = threading.Thread(
            target=run_spinner,
            args=(message, stop_event),
            daemon=True,
        )
        thread.start()
        stopped = False

        def stop() -> None:
            nonlocal stopped
            if stopped:
                return
            stop_event.set()
            thread.join()
            stopped = True

        return stop

    stop_loading = start_spinner("Loading model")
    try:
        conversation_manager = ConversationManager()
    finally:
        # Stop the spinner even if construction fails, so it does not keep
        # drawing over the traceback.
        stop_loading()

    if conversation_manager.first_message:
        print(f"{conversation_manager.first_message}")  # noqa: T201
        print()  # noqa: T201

    try:
        while True:
            try:
                query = (await asyncio.to_thread(input, "User: ")).strip()
            except EOFError:
                # stdin was closed (e.g. Ctrl+D): end the session instead of crashing.
                print()  # noqa: T201
                break
            if not query:
                continue

            stop_thinking = start_spinner("Thinking")

            def stream_callback(chunk: str) -> None:
                # The first chunk stops the spinner; later calls are no-ops for stop().
                stop_thinking()
                print(chunk, flush=True, end="")  # noqa: T201

            try:
                await conversation_manager.ask_question(query, stream_callback=stream_callback)
            finally:
                # Also covers errors/cancellation and answers that produced no chunks.
                stop_thinking()
            print()  # noqa: T201
    except KeyboardInterrupt:
        print()  # noqa: T201
        raise
    finally:
        del conversation_manager
def _cli() -> None:
    """Script entry point: run ``main()`` and treat Ctrl+C as a normal exit.

    A ``KeyboardInterrupt`` escaping the event loop is logged at debug level
    instead of producing a traceback.
    """
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.debug("Caught keyboard interrupt. Exiting...")


if __name__ == "__main__":
    _cli()