diff --git a/binance/ws/keepalive_websocket.py b/binance/ws/keepalive_websocket.py index 3f1c3f2e..e122b179 100644 --- a/binance/ws/keepalive_websocket.py +++ b/binance/ws/keepalive_websocket.py @@ -45,6 +45,7 @@ async def __aexit__(self, *args, **kwargs): # Unregister the queue from ws_api before unsubscribing if hasattr(self._client, "ws_api") and self._client.ws_api: self._client.ws_api.unregister_subscription_queue(self._subscription_id) + self._client.ws_api.remove_reconnect_callback(self._handle_reconnect) await self._unsubscribe_from_user_data_stream() if self._uses_ws_api_subscription: # For ws_api subscriptions, we don't manage the connection @@ -72,6 +73,7 @@ async def _before_connect(self): self._client.ws_api.register_subscription_queue( self._subscription_id, self._queue ) + self._client.ws_api.add_reconnect_callback(self._handle_reconnect) self._path = f"user_subscription:{self._subscription_id}" return if self._keepalive_type == "margin": @@ -86,6 +88,7 @@ async def _before_connect(self): self._client.ws_api.register_subscription_queue( self._subscription_id, self._queue ) + self._client.ws_api.add_reconnect_callback(self._handle_reconnect) self._path = f"margin_subscription:{self._subscription_id}" return # Check if this is isolated margin (when keepalive_type is a symbol string) @@ -111,6 +114,7 @@ async def _before_connect(self): self._client.ws_api.register_subscription_queue( self._subscription_id, self._queue ) + self._client.ws_api.add_reconnect_callback(self._handle_reconnect) self._path = f"isolated_margin_subscription:{self._subscription_id}" return if not self._listen_key: @@ -209,6 +213,76 @@ async def _subscribe_to_isolated_margin_data_stream(self, symbol: str): ) return response.get("subscriptionId") + async def _handle_reconnect(self, ws, ws_api): + """Called by WebsocketAPI after reconnect to restore subscription.""" + if not self._uses_ws_api_subscription: + return + + old_subscription_id = self._subscription_id + try: + payload = await self._build_subscribe_payload() + await ws.send(ws_api.json_dumps(payload)) + res = await asyncio.wait_for(ws.recv(), timeout=self.TIMEOUT) + res = ws_api.json_loads(res) + + if "error" in res: + self._log.error(f"Re-subscribe failed: {res}") + return + + new_sub_id = res.get("result", {}).get("subscriptionId") + if new_sub_id is None: + self._log.error(f"Re-subscribe response missing subscriptionId: {res}") + return + + if old_subscription_id is not None: + ws_api.unregister_subscription_queue(old_subscription_id) + ws_api.register_subscription_queue(new_sub_id, self._queue) + self._subscription_id = new_sub_id + + if self._keepalive_type == "user": + self._path = f"user_subscription:{new_sub_id}" + elif self._keepalive_type == "margin": + self._path = f"margin_subscription:{new_sub_id}" + else: + self._path = f"isolated_margin_subscription:{new_sub_id}" + + self._log.info(f"Re-subscribed after reconnect: {old_subscription_id} -> {new_sub_id}") + except Exception as e: + self._log.error(f"Failed to re-subscribe after reconnect: {e}") + + async def _build_subscribe_payload(self): + """Build a payload for subscribing.""" + new_id = str(uuid.uuid4()) + if self._keepalive_type == "user": + params = self._client._sign_ws_params( + {}, self._client._generate_ws_api_signature + ) + return { + "id": new_id, + "method": "userDataStream.subscribe.signature", + "params": params, + } + if self._keepalive_type == "margin": + token_response = await self._client.margin_create_listen_token( + is_isolated=False + ) + self._listen_token_expiration = token_response.get("expirationTime") + return { + "id": new_id, + "method": "userDataStream.subscribe.listenToken", + "params": {"listenToken": token_response["token"]}, + } + # isolated margin + token_response = await self._client.margin_create_listen_token( + symbol=self._keepalive_type, is_isolated=True + ) + self._listen_token_expiration = token_response.get("expirationTime") + return { + "id": new_id, + "method": "userDataStream.subscribe.listenToken", + "params": {"listenToken": token_response["token"]}, + } + async def _unsubscribe_from_user_data_stream(self): """Unsubscribe from user data stream using WebSocket API""" if self._subscription_id is not None: diff --git a/binance/ws/websocket_api.py b/binance/ws/websocket_api.py index 3ef9ed13..437518b5 100644 --- a/binance/ws/websocket_api.py +++ b/binance/ws/websocket_api.py @@ -16,6 +16,7 @@ def __init__(self, url: str, tld: str = "com", testnet: bool = False, https_prox self._connection_lock: Optional[asyncio.Lock] = None # Subscription queues for routing user data stream events self._subscription_queues: Dict[str, asyncio.Queue] = {} + self._on_reconnect_callbacks = [] super().__init__(url=url, prefix="", path="", is_binary=False, https_proxy=https_proxy) def register_subscription_queue(self, subscription_id: str, queue: asyncio.Queue) -> None: @@ -26,6 +27,17 @@ def unregister_subscription_queue(self, subscription_id: str) -> None: """Unregister a subscription queue.""" self._subscription_queues.pop(subscription_id, None) + def add_reconnect_callback(self, callback) -> None: + self._on_reconnect_callbacks.append(callback) + + def remove_reconnect_callback(self, callback) -> None: + if callback in self._on_reconnect_callbacks: + self._on_reconnect_callbacks.remove(callback) + + async def _after_connect(self): + for cb in self._on_reconnect_callbacks: + await cb(self.ws, self) + @property def connection_lock(self) -> asyncio.Lock: if self._connection_lock is None: