Skip to content

Commit a9e42e8

Browse files
committed
refactor: streamline handler management and improve logging
1 parent 2b15630 commit a9e42e8

File tree

10 files changed

+156
-72
lines changed

10 files changed

+156
-72
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ classifiers = [
3535
"Topic :: Software Development :: Libraries :: Python Modules",
3636
"Topic :: Software Development :: Libraries :: Application Frameworks",
3737
]
38-
dependencies = ["pyaes", "python-socks[asyncio]", "aiosqlite"]
38+
dependencies = ["pyaes", "python-socks[asyncio]", "aiosqlite", "qrcode"]
3939
license = "LGPL-3.0-or-later"
4040
license-files = ["COPYING", "COPYING.lesser"]
4141
keywords = [

pyrogram/client.py

Lines changed: 71 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,16 @@
1717
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
1818

1919
import asyncio
20-
from collections import OrderedDict
21-
from contextlib import suppress
22-
import functools
23-
import inspect
2420
import logging
2521
import os
2622
import platform
2723
import re
2824
import shutil
2925
import sys
3026
import time
27+
from collections import OrderedDict
3128
from concurrent.futures.thread import ThreadPoolExecutor
29+
from contextlib import suppress
3230
from datetime import datetime
3331
from hashlib import sha256
3432
from importlib import import_module
@@ -391,7 +389,11 @@ def __init__(
391389
)
392390
elif self.in_memory:
393391
self.storage = SQLiteStorage(self.name, workdir=self.workdir, in_memory=True)
394-
elif isinstance(storage_engine, Storage):
392+
elif storage_engine is not None:
393+
if not isinstance(storage_engine, Storage):
394+
raise TypeError(
395+
f"storage_engine must be a Storage instance, got {type(storage_engine).__name__}"
396+
)
395397
self.storage = storage_engine
396398
else:
397399
self.storage = SQLiteStorage(self.name, workdir=self.workdir)
@@ -438,12 +440,14 @@ def __init__(
438440
self.last_update_time = datetime.now()
439441
self._last_update_monotonic = time.monotonic()
440442

443+
self.listeners = {listener_type: [] for listener_type in pyrogram.enums.ListenerTypes}
444+
441445
if isinstance(loop, asyncio.AbstractEventLoop):
442446
self.loop = loop
443447
else:
444448
self.loop = None
445449

446-
self.__config: "raw.types.Config" = None
450+
self.__config: Optional["raw.types.Config"] = None
447451

448452
@property
449453
def loop(self) -> asyncio.AbstractEventLoop:
@@ -455,9 +459,6 @@ def loop(self) -> asyncio.AbstractEventLoop:
455459
def loop(self, value: asyncio.AbstractEventLoop):
456460
self._loop = value
457461

458-
if not hasattr(self, "listeners"):
459-
self.listeners = {listener_type: [] for listener_type in pyrogram.enums.ListenerTypes}
460-
461462
def __enter__(self):
462463
return utils.get_event_loop().run_until_complete(self.start())
463464

@@ -598,7 +599,7 @@ async def authorize(self) -> User:
598599
self.phone_code = await ainput("Enter confirmation code: ", loop=self.loop)
599600

600601
try:
601-
signed_in = await self.sign_in(self.phone_number, sent_code.phone_code_hash, self.phone_code)
602+
sign_in_result = await self.sign_in(self.phone_number, sent_code.phone_code_hash, self.phone_code)
602603
except BadRequest as e:
603604
print(e.MESSAGE)
604605
self.phone_code = None
@@ -608,8 +609,10 @@ async def authorize(self) -> User:
608609
else:
609610
break
610611

611-
if isinstance(signed_in, User):
612-
return signed_in
612+
if isinstance(sign_in_result, User):
613+
return sign_in_result
614+
615+
terms_of_service = sign_in_result if isinstance(sign_in_result, TermsOfService) else None
613616

614617
while True:
615618
first_name = await ainput("Enter first name: ", loop=self.loop)
@@ -627,13 +630,13 @@ async def authorize(self) -> User:
627630
else:
628631
break
629632

630-
if isinstance(signed_in, TermsOfService):
631-
print("\n" + signed_in.text + "\n")
632-
await self.accept_terms_of_service(signed_in.id)
633+
if terms_of_service:
634+
print("\n" + terms_of_service.text + "\n")
635+
await self.accept_terms_of_service(terms_of_service.id)
633636

634637
return signed_up
635638

636-
async def authorize_qr(self, except_ids: List[int] = None) -> "User":
639+
async def authorize_qr(self, except_ids: Optional[List[int]] = None) -> "User":
637640
from qrcode import QRCode
638641

639642
qr_login = QRLogin(self, except_ids or [])
@@ -770,10 +773,10 @@ async def fetch_peers(self, peers: List[Union[raw.types.User, raw.types.Chat, ra
770773
phone_number = peer.phone
771774
peer_type = "bot" if peer.bot else "user"
772775

773-
if peer.username:
776+
if peer.usernames:
777+
usernames.extend(u.username.lower() for u in peer.usernames)
778+
elif peer.username:
774779
usernames.append(peer.username.lower())
775-
elif peer.usernames:
776-
usernames.extend(username.username.lower() for username in peer.usernames)
777780
elif isinstance(peer, (raw.types.Chat, raw.types.ChatForbidden)):
778781
peer_id = -peer.id
779782
access_hash = 0
@@ -783,10 +786,10 @@ async def fetch_peers(self, peers: List[Union[raw.types.User, raw.types.Chat, ra
783786
access_hash = peer.access_hash
784787
peer_type = "direct" if peer.monoforum else "channel" if peer.broadcast else "forum" if peer.forum else "supergroup"
785788

786-
if peer.username:
789+
if peer.usernames:
790+
usernames.extend(u.username.lower() for u in peer.usernames)
791+
elif peer.username:
787792
usernames.append(peer.username.lower())
788-
elif peer.usernames:
789-
usernames.extend(username.username.lower() for username in peer.usernames)
790793
elif isinstance(peer, raw.types.ChannelForbidden):
791794
peer_id = utils.get_channel_id(peer.id)
792795
access_hash = peer.access_hash
@@ -811,10 +814,23 @@ async def handle_updates(self, updates):
811814
self._last_update_monotonic = time.monotonic()
812815

813816
if isinstance(updates, (raw.types.Updates, raw.types.UpdatesCombined)):
814-
is_min = any((
815-
await self.fetch_peers(updates.users),
816-
await self.fetch_peers(updates.chats),
817-
))
817+
is_min_users = await self.fetch_peers(updates.users)
818+
is_min_chats = await self.fetch_peers(updates.chats)
819+
is_min = is_min_users or is_min_chats
820+
821+
if isinstance(updates, raw.types.UpdatesCombined):
822+
seq_start = getattr(updates, "seq_start", updates.seq)
823+
stored_states = await self.storage.update_state()
824+
if stored_states:
825+
global_state = next((s for s in stored_states if s[0] == 0), None)
826+
if global_state and global_state[4] is not None:
827+
stored_seq = global_state[4]
828+
if seq_start > stored_seq + 1:
829+
log.warning(
830+
"Seq gap detected: stored=%s, seq_start=%s. Triggering recovery.",
831+
stored_seq, seq_start
832+
)
833+
await self.recover_gaps()
818834

819835
users = {u.id: u for u in updates.users}
820836
chats = {c.id: c for c in updates.chats}
@@ -848,7 +864,7 @@ async def handle_updates(self, updates):
848864
if isinstance(update, raw.types.UpdateNewChannelMessage) and is_min:
849865
message = update.message
850866

851-
if not isinstance(message, raw.types.MessageEmpty):
867+
if not isinstance(message, raw.types.MessageEmpty) and pts and pts_count:
852868
try:
853869
diff = await self.invoke(
854870
raw.functions.updates.GetChannelDifference(
@@ -1023,6 +1039,10 @@ def load_plugins(self):
10231039
pass
10241040
else:
10251041
for path, handlers in include:
1042+
if ".." in path.split("."):
1043+
log.warning('[%s] [LOAD] Skipping suspicious include path "%s"', self.name, path)
1044+
continue
1045+
10261046
module_path = root + "." + path
10271047
warn_non_existent_functions = True
10281048

@@ -1058,6 +1078,10 @@ def load_plugins(self):
10581078

10591079
if exclude:
10601080
for path, handlers in exclude:
1081+
if ".." in path.split("."):
1082+
log.warning('[%s] [UNLOAD] Skipping suspicious exclude path "%s"', self.name, path)
1083+
continue
1084+
10611085
module_path = root + "." + path
10621086
warn_non_existent_functions = True
10631087

@@ -1476,6 +1500,8 @@ async def get_session(
14761500
if cached_session is None:
14771501
sessions[dc_id] = session
14781502
else:
1503+
with suppress(Exception):
1504+
await session.stop()
14791505
session = cached_session
14801506

14811507
pending_session = self._session_futures.pop(session_key, None)
@@ -1597,6 +1623,9 @@ async def set_dc(
15971623
await self.storage.server_address(server_address)
15981624
await self.storage.port(port)
15991625

1626+
if self.session is None:
1627+
return
1628+
16001629
if self.session.server_address != server_address or self.session.port != port:
16011630
self.session.server_address = server_address
16021631
self.session.port = port
@@ -1626,22 +1655,34 @@ def guess_extension(self, mime_type: str) -> Optional[str]:
16261655

16271656

16281657
class Cache:
1658+
"""LRU cache backed by OrderedDict.
1659+
1660+
Note: ``__getitem__`` returns ``None`` for missing keys instead of raising
1661+
``KeyError``. All existing call-sites rely on this behaviour.
1662+
"""
1663+
16291664
def __init__(self, capacity: int):
16301665
self.capacity = capacity
16311666
self.store = OrderedDict()
16321667

1633-
def __getitem__(self, key):
1668+
def get(self, key, default=None):
16341669
try:
16351670
self.store.move_to_end(key)
16361671
return self.store[key]
16371672
except KeyError:
1638-
return None
1673+
return default
1674+
1675+
def __getitem__(self, key):
1676+
return self.get(key)
16391677

16401678
def __setitem__(self, key, value):
1679+
if self.capacity == 0:
1680+
return
1681+
16411682
try:
16421683
self.store.move_to_end(key)
16431684
except KeyError:
1644-
if self.capacity > 0 and len(self.store) >= self.capacity:
1685+
if len(self.store) >= self.capacity:
16451686
self.store.popitem(last=False)
16461687
self.store[key] = value
16471688

pyrogram/methods/utilities/add_handler.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,21 @@
1616
# You should have received a copy of the GNU Lesser General Public License
1717
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
1818

19+
import logging
20+
1921
import pyrogram
2022
from pyrogram.handlers import StartHandler, StopHandler, ConnectHandler, DisconnectHandler
2123
from pyrogram.handlers.handler import Handler
2224

25+
log = logging.getLogger(__name__)
26+
27+
_LIFECYCLE_HANDLER_MAP = {
28+
StartHandler: "start_handler",
29+
StopHandler: "stop_handler",
30+
ConnectHandler: "connect_handler",
31+
DisconnectHandler: "disconnect_handler",
32+
}
33+
2334

2435
class AddHandler:
2536
def add_handler(
@@ -59,14 +70,17 @@ async def hello(client, message):
5970
6071
app.run()
6172
"""
62-
if isinstance(handler, StartHandler):
63-
self.start_handler = handler.callback
64-
elif isinstance(handler, StopHandler):
65-
self.stop_handler = handler.callback
66-
elif isinstance(handler, ConnectHandler):
67-
self.connect_handler = handler.callback
68-
elif isinstance(handler, DisconnectHandler):
69-
self.disconnect_handler = handler.callback
73+
attr = _LIFECYCLE_HANDLER_MAP.get(type(handler))
74+
75+
if attr is not None:
76+
if getattr(self, attr) is not None:
77+
log.warning(
78+
"Replacing existing %s (was %r, now %r)",
79+
type(handler).__name__,
80+
getattr(self, attr),
81+
handler.callback,
82+
)
83+
setattr(self, attr, handler.callback)
7084
else:
7185
self.dispatcher.add_handler(handler, group)
7286

pyrogram/methods/utilities/compose.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,14 @@
1717
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
1818

1919
import asyncio
20+
import logging
2021
from typing import List
2122

2223
import pyrogram
2324
from .idle import idle
2425

26+
log = logging.getLogger(__name__)
27+
2528

2629
async def compose(
2730
clients: List["pyrogram.Client"],
@@ -63,11 +66,23 @@ async def main():
6366
asyncio.run(main())
6467
6568
"""
66-
if sequential:
67-
for c in clients:
68-
await c.start()
69-
else:
70-
await asyncio.gather(*[c.start() for c in clients])
69+
started = []
70+
71+
try:
72+
if sequential:
73+
for c in clients:
74+
await c.start()
75+
started.append(c)
76+
else:
77+
await asyncio.gather(*[c.start() for c in clients])
78+
started = list(clients)
79+
except Exception:
80+
for c in reversed(started):
81+
try:
82+
await c.stop()
83+
except Exception as e:
84+
log.warning("Failed to stop client %s during cleanup: %s", c.name, e)
85+
raise
7186

7287
await idle()
7388

pyrogram/methods/utilities/idle.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,19 @@ def signal_handler(signum, __):
7575
log.info(f"Stop signal received ({signals[signum]}). Exiting...")
7676
task.cancel()
7777

78-
for s in (SIGINT, SIGTERM, SIGABRT):
79-
signal_fn(s, signal_handler)
80-
81-
while True:
82-
task = asyncio.create_task(asyncio.sleep(600))
78+
original_handlers = {}
8379

84-
try:
85-
await task
86-
except asyncio.CancelledError:
87-
break
80+
for s in (SIGINT, SIGTERM, SIGABRT):
81+
original_handlers[s] = signal_fn(s, signal_handler)
82+
83+
try:
84+
while True:
85+
task = asyncio.create_task(asyncio.sleep(600))
86+
87+
try:
88+
await task
89+
except asyncio.CancelledError:
90+
break
91+
finally:
92+
for s, handler in original_handlers.items():
93+
signal_fn(s, handler)

0 commit comments

Comments
 (0)