diff --git a/elementary/messages/messaging_integrations/slack_web.py b/elementary/messages/messaging_integrations/slack_web.py index f83b2ae51..5b23f7572 100644 --- a/elementary/messages/messaging_integrations/slack_web.py +++ b/elementary/messages/messaging_integrations/slack_web.py @@ -1,7 +1,10 @@ import json +import re import ssl import time -from typing import Any, Dict, Iterator, Optional +from dataclasses import dataclass +from http import HTTPStatus +from typing import Any, Dict, Iterator, List, Optional, Tuple from ratelimit import limits, sleep_and_retry from slack_sdk import WebClient @@ -30,6 +33,32 @@ ONE_MINUTE = 60 ONE_SECOND = 1 +_CHANNEL_ID_PATTERN = re.compile(r"^[CGD][A-Z0-9]{8,}$") + + +def _is_channel_id(value: str) -> bool: + return bool(_CHANNEL_ID_PATTERN.match(value)) + + +def _normalize_channel_input(raw: str) -> str: + normalized = raw.strip() + if normalized.startswith("#"): + normalized = normalized[1:].strip() + return normalized + + +@dataclass +class ResolvedChannel: + name: str + id: str + + +@dataclass +class ChannelsResponse: + channels: list[ResolvedChannel] + retry_after: int | None + cursor: str | None + Channel: TypeAlias = str @@ -51,6 +80,7 @@ def __init__( self.client = client self.tracking = tracking self._email_to_user_id_cache: Dict[str, str] = {} + self._channel_cache: Dict[Tuple[str, bool], ResolvedChannel] = {} self.reply_broadcast = reply_broadcast @classmethod @@ -132,7 +162,7 @@ def _handle_send_err(self, err: SlackApiError, channel_name: str): logger.info( f'Elementary app is not in the channel "{channel_name}". Attempting to join.' ) - channel_id = self._get_channel_id(channel_name, only_public=True) + channel_id = self.resolve_channel(channel_name, only_public=True).id self._join_channel(channel_id=channel_id) logger.info(f"Joined channel {channel_name}") elif err_type == "channel_not_found": @@ -143,6 +173,19 @@ def _handle_send_err(self, err: SlackApiError, channel_name: str): f"Failed to send a message to channel - {channel_name}" ) + def _list_conversations( + self, cursor: Optional[str] = None + ) -> Tuple[List[dict], Optional[str]]: + response = self.client.conversations_list( + cursor=cursor, + types="public_channel,private_channel", + exclude_archived=True, + limit=1000, + ) + channels = response.get("channels", []) + cursor = response.get("response_metadata", {}).get("next_cursor") + return channels, cursor + @sleep_and_retry @limits(calls=20, period=ONE_MINUTE) def _iter_channels( @@ -155,29 +198,83 @@ def _iter_channels( raise MessagingIntegrationError("Channel iteration timed out") call_start = time.time() - response = self.client.conversations_list( - cursor=cursor, - types="public_channel" if only_public else "public_channel,private_channel", - exclude_archived=True, - limit=1000, - ) + channels, cursor = self._list_conversations(cursor) call_duration = time.time() - call_start - channels = response["channels"] yield from channels - response_metadata = response.get("response_metadata") or {} - next_cursor = response_metadata.get("next_cursor") - if next_cursor: - if not isinstance(next_cursor, str): - raise ValueError("Next cursor is not a string") + if cursor: timeout_left = timeout - call_duration - yield from self._iter_channels(next_cursor, only_public, timeout_left) + yield from self._iter_channels(cursor, only_public, timeout_left) + + @sleep_and_retry + @limits(calls=50, period=ONE_MINUTE) + def resolve_channel( + self, channel: str, only_public: bool = False + ) -> ResolvedChannel: + normalized = _normalize_channel_input(channel) + cache_key = (normalized, only_public) + if cache_key in self._channel_cache: + return self._channel_cache[cache_key] + + if _is_channel_id(normalized): + try: + response = self.client.conversations_info(channel=normalized) + except SlackApiError as e: + if self.tracking: + self.tracking.record_internal_exception(e) + raise MessagingIntegrationError( + f"Channel {normalized} not found" + ) from e + ch = response["channel"] + resolved = ResolvedChannel(name=ch["name"], id=ch["id"]) + else: + for ch in self._iter_channels(only_public=only_public): + if ch["name"] == normalized: + resolved = ResolvedChannel(name=ch["name"], id=ch["id"]) + break + else: + raise MessagingIntegrationError(f"Channel {normalized} not found") + + self._channel_cache[cache_key] = resolved + return resolved + + def get_channels( + self, + cursor: str | None = None, + timeout_seconds: int = 15, + ) -> ChannelsResponse: + channels_response = ChannelsResponse(channels=[], retry_after=None, cursor=None) + + start_time = time.time() + time_elapsed: float = 0 + while time_elapsed < timeout_seconds: + try: + channels, cursor = self._list_conversations(cursor) + time_elapsed = time.time() - start_time + logger.debug( + f"Got a batch of {len(channels)} channels! time elapsed: {time_elapsed} seconds" + ) + + channels_response.channels.extend( + [ + ResolvedChannel(name=chan["name"], id=chan["id"]) + for chan in channels + ] + ) + + if not cursor: + break + + except SlackApiError as err: + if err.response.status_code == HTTPStatus.TOO_MANY_REQUESTS: + channels_response.retry_after = int( + err.response.headers["Retry-After"] + ) + break + raise - def _get_channel_id(self, channel_name: str, only_public: bool = False) -> str: - for channel in self._iter_channels(only_public=only_public): - if channel["name"] == channel_name: - return channel["id"] - raise MessagingIntegrationError(f"Channel {channel_name} not found") + channels_response.cursor = cursor + return channels_response def _join_channel(self, channel_id: str) -> None: try: