1818import json
1919import threading
2020import uuid
21+ import weakref
2122from abc import abstractmethod
2223from base64 import urlsafe_b64encode
2324from collections .abc import Awaitable , Callable , Mapping , Sequence
@@ -880,9 +881,11 @@ class FileHistoryProvider(HistoryProvider):
880881 DEFAULT_SOURCE_ID : ClassVar [str ] = "file_history"
881882 DEFAULT_SESSION_FILE_STEM : ClassVar [str ] = "default"
882883 FILE_EXTENSION : ClassVar [str ] = ".jsonl"
884+ _FILE_LOCK_STRIPE_COUNT : ClassVar [int ] = 64
883885 _ENCODED_SESSION_PREFIX : ClassVar [str ] = "~session-"
884- _FILE_WRITE_LOCKS : ClassVar [dict [Path , threading .Lock ]] = {}
885- _FILE_WRITE_LOCKS_GUARD : ClassVar [threading .Lock ] = threading .Lock ()
886+ _FILE_WRITE_LOCKS : ClassVar [tuple [threading .Lock , ...]] = tuple (
887+ threading .Lock () for _ in range (_FILE_LOCK_STRIPE_COUNT )
888+ )
886889 _WINDOWS_RESERVED_FILE_STEMS : ClassVar [frozenset [str ]] = frozenset ({
887890 "CON" ,
888891 "PRN" ,
@@ -955,6 +958,10 @@ def __init__(
955958 self .skip_excluded = skip_excluded
956959 self .dumps = dumps or _default_json_dumps
957960 self .loads = loads or _default_json_loads
961+ self ._async_write_locks_by_loop : weakref .WeakKeyDictionary [
962+ asyncio .AbstractEventLoop ,
963+ tuple [asyncio .Lock , ...],
964+ ] = weakref .WeakKeyDictionary ()
958965
959966 async def get_messages (
960967 self ,
@@ -966,38 +973,42 @@ async def get_messages(
966973 """Retrieve messages from the session's JSON Lines file."""
967974 del state , kwargs
968975 file_path = self ._session_file_path (session_id )
976+ async_lock = self ._session_async_write_lock (file_path )
977+ thread_lock = self ._session_write_lock (file_path )
969978
970979 def _read_messages () -> list [Message ]:
971- if not file_path .exists ():
972- return []
973-
974- messages : list [Message ] = []
975- with file_path .open (encoding = "utf-8" ) as file_handle :
976- for line_number , line in enumerate (file_handle , start = 1 ):
977- serialized = line .strip ()
978- if not serialized :
979- continue
980- try :
981- payload = self .loads (serialized )
982- except (TypeError , ValueError ) as exc :
983- raise ValueError (
984- f"Failed to deserialize history line { line_number } from '{ file_path } '."
985- ) from exc
986- if not isinstance (payload , Mapping ):
987- raise ValueError (
988- f"History line { line_number } in '{ file_path } ' did not deserialize to a mapping."
989- )
990-
991- try :
992- message = Message .from_dict (dict (cast (Mapping [str , Any ], payload )))
993- except ValueError as exc :
994- raise ValueError (
995- f"History line { line_number } in '{ file_path } ' is not a valid Message payload."
996- ) from exc
997- messages .append (message )
998- return messages
999-
1000- messages = await asyncio .to_thread (_read_messages )
980+ with thread_lock :
981+ if not file_path .exists ():
982+ return []
983+
984+ messages : list [Message ] = []
985+ with file_path .open (encoding = "utf-8" ) as file_handle :
986+ for line_number , line in enumerate (file_handle , start = 1 ):
987+ serialized = line .strip ()
988+ if not serialized :
989+ continue
990+ try :
991+ payload = self .loads (serialized )
992+ except (TypeError , ValueError ) as exc :
993+ raise ValueError (
994+ f"Failed to deserialize history line { line_number } from '{ file_path } '."
995+ ) from exc
996+ if not isinstance (payload , Mapping ):
997+ raise ValueError (
998+ f"History line { line_number } in '{ file_path } ' did not deserialize to a mapping."
999+ )
1000+
1001+ try :
1002+ message = Message .from_dict (dict (cast (Mapping [str , Any ], payload )))
1003+ except ValueError as exc :
1004+ raise ValueError (
1005+ f"History line { line_number } in '{ file_path } ' is not a valid Message payload."
1006+ ) from exc
1007+ messages .append (message )
1008+ return messages
1009+
1010+ async with async_lock :
1011+ messages = await asyncio .to_thread (_read_messages )
10011012 if self .skip_excluded :
10021013 messages = [m for m in messages if not m .additional_properties .get ("_excluded" , False )]
10031014 return messages
@@ -1016,14 +1027,16 @@ async def save_messages(
10161027 return
10171028
10181029 file_path = self ._session_file_path (session_id )
1030+ async_lock = self ._session_async_write_lock (file_path )
10191031 file_lock = self ._session_write_lock (file_path )
10201032
10211033 def _append_messages () -> None :
10221034 with file_lock , file_path .open ("a" , encoding = "utf-8" ) as file_handle :
10231035 for message in messages :
10241036 file_handle .write (f"{ self ._serialize_message (message )} \n " )
10251037
1026- await asyncio .to_thread (_append_messages )
1038+ async with async_lock :
1039+ await asyncio .to_thread (_append_messages )
10271040
10281041 def _serialize_message (self , message : Message ) -> str :
10291042 """Serialize a message payload to a single JSON Lines record."""
@@ -1055,11 +1068,24 @@ def _session_file_stem(self, session_id: str | None) -> str:
10551068 encoded_session_id = urlsafe_b64encode (raw_session_id .encode ("utf-8" )).decode ("ascii" ).rstrip ("=" )
10561069 return f"{ self ._ENCODED_SESSION_PREFIX } { encoded_session_id or self .DEFAULT_SESSION_FILE_STEM } "
10571070
1071+ def _session_async_write_lock (self , file_path : Path ) -> asyncio .Lock :
1072+ """Return the event-loop-local async lock for a session history file."""
1073+ loop = asyncio .get_running_loop ()
1074+ locks = self ._async_write_locks_by_loop .get (loop )
1075+ if locks is None :
1076+ locks = tuple (asyncio .Lock () for _ in range (self ._FILE_LOCK_STRIPE_COUNT ))
1077+ self ._async_write_locks_by_loop [loop ] = locks
1078+ return locks [self ._lock_index (file_path )]
1079+
10581080 @classmethod
10591081 def _session_write_lock (cls , file_path : Path ) -> threading .Lock :
1060- """Return the process-local append lock for a session history file."""
1061- with cls ._FILE_WRITE_LOCKS_GUARD :
1062- return cls ._FILE_WRITE_LOCKS .setdefault (file_path , threading .Lock ())
1082+ """Return the process-local thread lock for a session history file."""
1083+ return cls ._FILE_WRITE_LOCKS [cls ._lock_index (file_path )]
1084+
1085+ @classmethod
1086+ def _lock_index (cls , file_path : Path ) -> int :
1087+ """Map a session history file to a bounded lock stripe."""
1088+ return hash (file_path ) % cls ._FILE_LOCK_STRIPE_COUNT
10631089
10641090 @classmethod
10651091 def _is_literal_session_file_stem_safe (cls , session_id : str ) -> bool :
0 commit comments