1717# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
1818
1919import asyncio
20- from collections import OrderedDict
21- from contextlib import suppress
22- import functools
23- import inspect
2420import logging
2521import os
2622import platform
2723import re
2824import shutil
2925import sys
3026import time
27+ from collections import OrderedDict
3128from concurrent .futures .thread import ThreadPoolExecutor
29+ from contextlib import suppress
3230from datetime import datetime
3331from hashlib import sha256
3432from importlib import import_module
@@ -391,7 +389,11 @@ def __init__(
391389 )
392390 elif self .in_memory :
393391 self .storage = SQLiteStorage (self .name , workdir = self .workdir , in_memory = True )
394- elif isinstance (storage_engine , Storage ):
392+ elif storage_engine is not None :
393+ if not isinstance (storage_engine , Storage ):
394+ raise TypeError (
395+ f"storage_engine must be a Storage instance, got { type (storage_engine ).__name__ } "
396+ )
395397 self .storage = storage_engine
396398 else :
397399 self .storage = SQLiteStorage (self .name , workdir = self .workdir )
@@ -438,12 +440,14 @@ def __init__(
438440 self .last_update_time = datetime .now ()
439441 self ._last_update_monotonic = time .monotonic ()
440442
443+ self .listeners = {listener_type : [] for listener_type in pyrogram .enums .ListenerTypes }
444+
441445 if isinstance (loop , asyncio .AbstractEventLoop ):
442446 self .loop = loop
443447 else :
444448 self .loop = None
445449
446- self .__config : "raw.types.Config" = None
450+ self .__config : Optional [ "raw.types.Config" ] = None
447451
448452 @property
449453 def loop (self ) -> asyncio .AbstractEventLoop :
@@ -455,9 +459,6 @@ def loop(self) -> asyncio.AbstractEventLoop:
455459 def loop (self , value : asyncio .AbstractEventLoop ):
456460 self ._loop = value
457461
458- if not hasattr (self , "listeners" ):
459- self .listeners = {listener_type : [] for listener_type in pyrogram .enums .ListenerTypes }
460-
461462 def __enter__ (self ):
462463 return utils .get_event_loop ().run_until_complete (self .start ())
463464
@@ -598,7 +599,7 @@ async def authorize(self) -> User:
598599 self .phone_code = await ainput ("Enter confirmation code: " , loop = self .loop )
599600
600601 try :
601- signed_in = await self .sign_in (self .phone_number , sent_code .phone_code_hash , self .phone_code )
602+ sign_in_result = await self .sign_in (self .phone_number , sent_code .phone_code_hash , self .phone_code )
602603 except BadRequest as e :
603604 print (e .MESSAGE )
604605 self .phone_code = None
@@ -608,8 +609,10 @@ async def authorize(self) -> User:
608609 else :
609610 break
610611
611- if isinstance (signed_in , User ):
612- return signed_in
612+ if isinstance (sign_in_result , User ):
613+ return sign_in_result
614+
615+ terms_of_service = sign_in_result if isinstance (sign_in_result , TermsOfService ) else None
613616
614617 while True :
615618 first_name = await ainput ("Enter first name: " , loop = self .loop )
@@ -627,13 +630,13 @@ async def authorize(self) -> User:
627630 else :
628631 break
629632
630- if isinstance ( signed_in , TermsOfService ) :
631- print ("\n " + signed_in .text + "\n " )
632- await self .accept_terms_of_service (signed_in .id )
633+ if terms_of_service :
634+ print ("\n " + terms_of_service .text + "\n " )
635+ await self .accept_terms_of_service (terms_of_service .id )
633636
634637 return signed_up
635638
636- async def authorize_qr (self , except_ids : List [int ] = None ) -> "User" :
639+ async def authorize_qr (self , except_ids : Optional [ List [int ] ] = None ) -> "User" :
637640 from qrcode import QRCode
638641
639642 qr_login = QRLogin (self , except_ids or [])
@@ -770,10 +773,10 @@ async def fetch_peers(self, peers: List[Union[raw.types.User, raw.types.Chat, ra
770773 phone_number = peer .phone
771774 peer_type = "bot" if peer .bot else "user"
772775
773- if peer .username :
776+ if peer .usernames :
777+ usernames .extend (u .username .lower () for u in peer .usernames )
778+ elif peer .username :
774779 usernames .append (peer .username .lower ())
775- elif peer .usernames :
776- usernames .extend (username .username .lower () for username in peer .usernames )
777780 elif isinstance (peer , (raw .types .Chat , raw .types .ChatForbidden )):
778781 peer_id = - peer .id
779782 access_hash = 0
@@ -783,10 +786,10 @@ async def fetch_peers(self, peers: List[Union[raw.types.User, raw.types.Chat, ra
783786 access_hash = peer .access_hash
784787 peer_type = "direct" if peer .monoforum else "channel" if peer .broadcast else "forum" if peer .forum else "supergroup"
785788
786- if peer .username :
789+ if peer .usernames :
790+ usernames .extend (u .username .lower () for u in peer .usernames )
791+ elif peer .username :
787792 usernames .append (peer .username .lower ())
788- elif peer .usernames :
789- usernames .extend (username .username .lower () for username in peer .usernames )
790793 elif isinstance (peer , raw .types .ChannelForbidden ):
791794 peer_id = utils .get_channel_id (peer .id )
792795 access_hash = peer .access_hash
@@ -811,10 +814,23 @@ async def handle_updates(self, updates):
811814 self ._last_update_monotonic = time .monotonic ()
812815
813816 if isinstance (updates , (raw .types .Updates , raw .types .UpdatesCombined )):
814- is_min = any ((
815- await self .fetch_peers (updates .users ),
816- await self .fetch_peers (updates .chats ),
817- ))
817+ is_min_users = await self .fetch_peers (updates .users )
818+ is_min_chats = await self .fetch_peers (updates .chats )
819+ is_min = is_min_users or is_min_chats
820+
821+ if isinstance (updates , raw .types .UpdatesCombined ):
822+ seq_start = getattr (updates , "seq_start" , updates .seq )
823+ stored_states = await self .storage .update_state ()
824+ if stored_states :
825+ global_state = next ((s for s in stored_states if s [0 ] == 0 ), None )
826+ if global_state and global_state [4 ] is not None :
827+ stored_seq = global_state [4 ]
828+ if seq_start > stored_seq + 1 :
829+ log .warning (
830+ "Seq gap detected: stored=%s, seq_start=%s. Triggering recovery." ,
831+ stored_seq , seq_start
832+ )
833+ await self .recover_gaps ()
818834
819835 users = {u .id : u for u in updates .users }
820836 chats = {c .id : c for c in updates .chats }
@@ -848,7 +864,7 @@ async def handle_updates(self, updates):
848864 if isinstance (update , raw .types .UpdateNewChannelMessage ) and is_min :
849865 message = update .message
850866
851- if not isinstance (message , raw .types .MessageEmpty ):
867+ if not isinstance (message , raw .types .MessageEmpty ) and pts and pts_count :
852868 try :
853869 diff = await self .invoke (
854870 raw .functions .updates .GetChannelDifference (
@@ -1023,6 +1039,10 @@ def load_plugins(self):
10231039 pass
10241040 else :
10251041 for path , handlers in include :
1042+ if ".." in path .split ("." ):
1043+ log .warning ('[%s] [LOAD] Skipping suspicious include path "%s"' , self .name , path )
1044+ continue
1045+
10261046 module_path = root + "." + path
10271047 warn_non_existent_functions = True
10281048
@@ -1058,6 +1078,10 @@ def load_plugins(self):
10581078
10591079 if exclude :
10601080 for path , handlers in exclude :
1081+ if ".." in path .split ("." ):
1082+ log .warning ('[%s] [UNLOAD] Skipping suspicious exclude path "%s"' , self .name , path )
1083+ continue
1084+
10611085 module_path = root + "." + path
10621086 warn_non_existent_functions = True
10631087
@@ -1476,6 +1500,8 @@ async def get_session(
14761500 if cached_session is None :
14771501 sessions [dc_id ] = session
14781502 else :
1503+ with suppress (Exception ):
1504+ await session .stop ()
14791505 session = cached_session
14801506
14811507 pending_session = self ._session_futures .pop (session_key , None )
@@ -1597,6 +1623,9 @@ async def set_dc(
15971623 await self .storage .server_address (server_address )
15981624 await self .storage .port (port )
15991625
1626+ if self .session is None :
1627+ return
1628+
16001629 if self .session .server_address != server_address or self .session .port != port :
16011630 self .session .server_address = server_address
16021631 self .session .port = port
@@ -1626,22 +1655,34 @@ def guess_extension(self, mime_type: str) -> Optional[str]:
16261655
16271656
16281657class Cache :
1658+ """LRU cache backed by OrderedDict.
1659+
1660+ Note: ``__getitem__`` returns ``None`` for missing keys instead of raising
1661+ ``KeyError``. All existing call-sites rely on this behaviour.
1662+ """
1663+
16291664 def __init__ (self , capacity : int ):
16301665 self .capacity = capacity
16311666 self .store = OrderedDict ()
16321667
1633- def __getitem__ (self , key ):
1668+ def get (self , key , default = None ):
16341669 try :
16351670 self .store .move_to_end (key )
16361671 return self .store [key ]
16371672 except KeyError :
1638- return None
1673+ return default
1674+
1675+ def __getitem__ (self , key ):
1676+ return self .get (key )
16391677
16401678 def __setitem__ (self , key , value ):
1679+ if self .capacity == 0 :
1680+ return
1681+
16411682 try :
16421683 self .store .move_to_end (key )
16431684 except KeyError :
1644- if self . capacity > 0 and len (self .store ) >= self .capacity :
1685+ if len (self .store ) >= self .capacity :
16451686 self .store .popitem (last = False )
16461687 self .store [key ] = value
16471688
0 commit comments