Skip to content

Commit a7aff57

Browse files
Refine file history provider locking
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 5f37eac commit a7aff57

2 files changed

Lines changed: 86 additions & 36 deletions

File tree

python/packages/core/agent_framework/_sessions.py

Lines changed: 62 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import json
1919
import threading
2020
import uuid
21+
import weakref
2122
from abc import abstractmethod
2223
from base64 import urlsafe_b64encode
2324
from collections.abc import Awaitable, Callable, Mapping, Sequence
@@ -880,9 +881,11 @@ class FileHistoryProvider(HistoryProvider):
880881
DEFAULT_SOURCE_ID: ClassVar[str] = "file_history"
881882
DEFAULT_SESSION_FILE_STEM: ClassVar[str] = "default"
882883
FILE_EXTENSION: ClassVar[str] = ".jsonl"
884+
_FILE_LOCK_STRIPE_COUNT: ClassVar[int] = 64
883885
_ENCODED_SESSION_PREFIX: ClassVar[str] = "~session-"
884-
_FILE_WRITE_LOCKS: ClassVar[dict[Path, threading.Lock]] = {}
885-
_FILE_WRITE_LOCKS_GUARD: ClassVar[threading.Lock] = threading.Lock()
886+
_FILE_WRITE_LOCKS: ClassVar[tuple[threading.Lock, ...]] = tuple(
887+
threading.Lock() for _ in range(_FILE_LOCK_STRIPE_COUNT)
888+
)
886889
_WINDOWS_RESERVED_FILE_STEMS: ClassVar[frozenset[str]] = frozenset({
887890
"CON",
888891
"PRN",
@@ -955,6 +958,10 @@ def __init__(
955958
self.skip_excluded = skip_excluded
956959
self.dumps = dumps or _default_json_dumps
957960
self.loads = loads or _default_json_loads
961+
self._async_write_locks_by_loop: weakref.WeakKeyDictionary[
962+
asyncio.AbstractEventLoop,
963+
tuple[asyncio.Lock, ...],
964+
] = weakref.WeakKeyDictionary()
958965

959966
async def get_messages(
960967
self,
@@ -966,38 +973,42 @@ async def get_messages(
966973
"""Retrieve messages from the session's JSON Lines file."""
967974
del state, kwargs
968975
file_path = self._session_file_path(session_id)
976+
async_lock = self._session_async_write_lock(file_path)
977+
thread_lock = self._session_write_lock(file_path)
969978

970979
def _read_messages() -> list[Message]:
971-
if not file_path.exists():
972-
return []
973-
974-
messages: list[Message] = []
975-
with file_path.open(encoding="utf-8") as file_handle:
976-
for line_number, line in enumerate(file_handle, start=1):
977-
serialized = line.strip()
978-
if not serialized:
979-
continue
980-
try:
981-
payload = self.loads(serialized)
982-
except (TypeError, ValueError) as exc:
983-
raise ValueError(
984-
f"Failed to deserialize history line {line_number} from '{file_path}'."
985-
) from exc
986-
if not isinstance(payload, Mapping):
987-
raise ValueError(
988-
f"History line {line_number} in '{file_path}' did not deserialize to a mapping."
989-
)
990-
991-
try:
992-
message = Message.from_dict(dict(cast(Mapping[str, Any], payload)))
993-
except ValueError as exc:
994-
raise ValueError(
995-
f"History line {line_number} in '{file_path}' is not a valid Message payload."
996-
) from exc
997-
messages.append(message)
998-
return messages
999-
1000-
messages = await asyncio.to_thread(_read_messages)
980+
with thread_lock:
981+
if not file_path.exists():
982+
return []
983+
984+
messages: list[Message] = []
985+
with file_path.open(encoding="utf-8") as file_handle:
986+
for line_number, line in enumerate(file_handle, start=1):
987+
serialized = line.strip()
988+
if not serialized:
989+
continue
990+
try:
991+
payload = self.loads(serialized)
992+
except (TypeError, ValueError) as exc:
993+
raise ValueError(
994+
f"Failed to deserialize history line {line_number} from '{file_path}'."
995+
) from exc
996+
if not isinstance(payload, Mapping):
997+
raise ValueError(
998+
f"History line {line_number} in '{file_path}' did not deserialize to a mapping."
999+
)
1000+
1001+
try:
1002+
message = Message.from_dict(dict(cast(Mapping[str, Any], payload)))
1003+
except ValueError as exc:
1004+
raise ValueError(
1005+
f"History line {line_number} in '{file_path}' is not a valid Message payload."
1006+
) from exc
1007+
messages.append(message)
1008+
return messages
1009+
1010+
async with async_lock:
1011+
messages = await asyncio.to_thread(_read_messages)
10011012
if self.skip_excluded:
10021013
messages = [m for m in messages if not m.additional_properties.get("_excluded", False)]
10031014
return messages
@@ -1016,14 +1027,16 @@ async def save_messages(
10161027
return
10171028

10181029
file_path = self._session_file_path(session_id)
1030+
async_lock = self._session_async_write_lock(file_path)
10191031
file_lock = self._session_write_lock(file_path)
10201032

10211033
def _append_messages() -> None:
10221034
with file_lock, file_path.open("a", encoding="utf-8") as file_handle:
10231035
for message in messages:
10241036
file_handle.write(f"{self._serialize_message(message)}\n")
10251037

1026-
await asyncio.to_thread(_append_messages)
1038+
async with async_lock:
1039+
await asyncio.to_thread(_append_messages)
10271040

10281041
def _serialize_message(self, message: Message) -> str:
10291042
"""Serialize a message payload to a single JSON Lines record."""
@@ -1055,11 +1068,24 @@ def _session_file_stem(self, session_id: str | None) -> str:
10551068
encoded_session_id = urlsafe_b64encode(raw_session_id.encode("utf-8")).decode("ascii").rstrip("=")
10561069
return f"{self._ENCODED_SESSION_PREFIX}{encoded_session_id or self.DEFAULT_SESSION_FILE_STEM}"
10571070

1071+
def _session_async_write_lock(self, file_path: Path) -> asyncio.Lock:
1072+
"""Return the event-loop-local async lock for a session history file."""
1073+
loop = asyncio.get_running_loop()
1074+
locks = self._async_write_locks_by_loop.get(loop)
1075+
if locks is None:
1076+
locks = tuple(asyncio.Lock() for _ in range(self._FILE_LOCK_STRIPE_COUNT))
1077+
self._async_write_locks_by_loop[loop] = locks
1078+
return locks[self._lock_index(file_path)]
1079+
10581080
@classmethod
10591081
def _session_write_lock(cls, file_path: Path) -> threading.Lock:
1060-
"""Return the process-local append lock for a session history file."""
1061-
with cls._FILE_WRITE_LOCKS_GUARD:
1062-
return cls._FILE_WRITE_LOCKS.setdefault(file_path, threading.Lock())
1082+
"""Return the process-local thread lock for a session history file."""
1083+
return cls._FILE_WRITE_LOCKS[cls._lock_index(file_path)]
1084+
1085+
@classmethod
1086+
def _lock_index(cls, file_path: Path) -> int:
1087+
"""Map a session history file to a bounded lock stripe."""
1088+
return hash(file_path) % cls._FILE_LOCK_STRIPE_COUNT
10631089

10641090
@classmethod
10651091
def _is_literal_session_file_stem_safe(cls, session_id: str) -> bool:

python/packages/core/tests/core/test_sessions.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,30 @@ async def test_invalid_jsonl_line_raises(self, tmp_path: Path) -> None:
620620
with pytest.raises(ValueError, match="Failed to deserialize history line 1"):
621621
await provider.get_messages("broken")
622622

623+
async def test_missing_session_file_returns_empty_messages(self, tmp_path: Path) -> None:
624+
provider = FileHistoryProvider(tmp_path)
625+
626+
loaded = await provider.get_messages("missing")
627+
628+
assert loaded == []
629+
630+
async def test_none_session_id_uses_default_jsonl_file(self, tmp_path: Path) -> None:
631+
provider = FileHistoryProvider(tmp_path)
632+
633+
await provider.save_messages(None, [Message(role="user", contents=["hello"])])
634+
635+
session_file = provider._session_file_path(None)
636+
assert session_file.name == "default.jsonl"
637+
loaded = await provider.get_messages(None)
638+
assert [message.text for message in loaded] == ["hello"]
639+
640+
async def test_non_mapping_jsonl_line_raises(self, tmp_path: Path) -> None:
641+
provider = FileHistoryProvider(tmp_path)
642+
await asyncio.to_thread(provider._session_file_path("non-mapping").write_text, "[1, 2, 3]\n", encoding="utf-8")
643+
644+
with pytest.raises(ValueError, match="did not deserialize to a mapping"):
645+
await provider.get_messages("non-mapping")
646+
623647
async def test_skip_excluded_omits_excluded_messages(self, tmp_path: Path) -> None:
624648
provider = FileHistoryProvider(tmp_path, skip_excluded=True)
625649

0 commit comments

Comments
 (0)