diff --git a/.gitignore b/.gitignore
index 24e00cf0b..2d6f8bfef 100644
--- a/.gitignore
+++ b/.gitignore
@@ -169,6 +169,7 @@ cython_debug/
 
 # PyRIT secrets file
 .env
+.pyrit_cache/
 
 # Cache for generating docs
 doc/generate_docs/cache/*
diff --git a/doc/api.rst b/doc/api.rst
index 1b9bdc775..cd812de72 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -500,6 +500,7 @@ API Reference
     PromptTarget
     RealtimeTarget
     TextTarget
+    WebSocketCopilotTarget
 
 :py:mod:`pyrit.score`
 =====================
diff --git a/pyrit/auth/copilot_authenticator.py b/pyrit/auth/copilot_authenticator.py
new file mode 100644
index 000000000..907916041
--- /dev/null
+++ b/pyrit/auth/copilot_authenticator.py
@@ -0,0 +1,365 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import asyncio
+import json
+import logging
+import os
+import re
+from datetime import datetime, timezone
+from typing import Optional
+
+from msal_extensions import FilePersistence, build_encrypted_persistence
+
+from pyrit.auth.authenticator import Authenticator
+from pyrit.common.path import PYRIT_CACHE_PATH
+
+logger = logging.getLogger(__name__)
+
+
+class CopilotAuthenticator(Authenticator):
+    """
+    Playwright-based authenticator for Microsoft Copilot. Used by WebSocketCopilotTarget.
+
+    This authenticator automates browser login to obtain and refresh access tokens that are necessary for accessing
+    Microsoft Copilot via WebSocket connections. It uses Playwright to simulate user interactions for authentication,
+    and msal-extensions for encrypted token persistence.
+
+    An access token acquired by this authenticator is usually valid for about 60 minutes.
+
+    Note:
+        To be able to use this authenticator, you must set the following environment variables:
+
+        - COPILOT_USERNAME: Your Microsoft account username (email).
+        - COPILOT_PASSWORD: Your Microsoft account password.
+
+        Additionally, you need to have playwright installed and set up:
+        ``pip install playwright && playwright install chromium``.
+    """
+
+    CACHE_FILE_NAME: str = "copilot_token_cache.bin"
+
+    # A cached token that expires within this many seconds is treated as already expired, so that
+    # callers are never handed a token that lapses in the middle of a request.
+    TOKEN_EXPIRY_BUFFER_SECONDS: int = 300
+
+    def __init__(
+        self,
+        *,
+        headless: bool = False,
+        maximized: bool = True,
+        timeout_for_elements: int = 10,
+        fallback_to_plaintext: bool = False,
+    ) -> None:
+        """
+        Initialize the CopilotAuthenticator.
+
+        Args:
+            headless (bool): Whether to run the browser in headless mode. Default is False.
+            maximized (bool): Whether to start the browser maximized. Default is True.
+            timeout_for_elements (int): Timeout used when waiting for page elements, in seconds. Default is 10.
+            fallback_to_plaintext (bool): Whether to fallback to plaintext storage if encryption is unavailable.
+                If set to False (default), an exception will be raised if encryption cannot be used.
+
+        Raises:
+            ValueError: If the required environment variables are not set.
+        """
+        super().__init__()
+
+        self._username = os.getenv("COPILOT_USERNAME")
+        self._password = os.getenv("COPILOT_PASSWORD")
+
+        # Fail fast, before touching the file system, when the credentials are missing.
+        if not self._username or not self._password:
+            raise ValueError("COPILOT_USERNAME and COPILOT_PASSWORD environment variables must be set.")
+
+        self._headless = headless
+        self._maximized = maximized
+        self._timeout = timeout_for_elements * 1000  # ms
+        self._fallback_to_plaintext = fallback_to_plaintext
+
+        self._cache_dir = PYRIT_CACHE_PATH
+        self._cache_dir.mkdir(parents=True, exist_ok=True)
+        self._cache_file = str(self._cache_dir / self.CACHE_FILE_NAME)
+
+        self._token_cache = self._create_persistent_cache(self._cache_file, self._fallback_to_plaintext)
+
+    @staticmethod
+    def _create_persistent_cache(cache_file: str, fallback_to_plaintext: bool = False):
+        """
+        Build the persistence object used to store the token cache on disk.
+
+        Args:
+            cache_file (str): Path of the token cache file.
+            fallback_to_plaintext (bool): Whether to fall back to unencrypted storage if encryption fails.
+
+        Returns:
+            The persistence object created by msal-extensions (encrypted, or plain text on fallback).
+
+        Raises:
+            Exception: Re-raised unchanged if encryption is unavailable and fallback_to_plaintext is False.
+        """
+        # https://github.com/AzureAD/microsoft-authentication-extensions-for-python
+
+        try:
+            logger.info(f"Using encrypted persistent token cache: {cache_file}")
+            return build_encrypted_persistence(cache_file)
+        except Exception as e:
+            if fallback_to_plaintext:
+                logger.warning(f"Encryption unavailable ({e}). Opting in to plain text.")
+                return FilePersistence(cache_file)
+            logger.error("Encryption unavailable and fallback_to_plaintext is False.")
+            raise
+
+    def _get_cached_token_if_available_and_valid(self) -> Optional[dict]:
+        """
+        Load the cached token and return it only if it exists and is not about to expire.
+
+        Returns:
+            Optional[dict]: The cached token data, or None if no usable token is available.
+        """
+        try:
+            cache_data = self._token_cache.load()
+            if not cache_data:
+                logger.info("No cached token data found.")
+                return None
+
+            token_data = json.loads(cache_data)
+            if not isinstance(token_data, dict) or "access_token" not in token_data:
+                logger.info("No access token in cache.")
+                return None
+
+            expires_at = token_data.get("expires_at")
+            if expires_at:
+                seconds_left = expires_at - datetime.now(timezone.utc).timestamp()
+
+                # Tokens inside the buffer window are discarded so they cannot lapse mid-request.
+                if seconds_left <= self.TOKEN_EXPIRY_BUFFER_SECONDS:
+                    logger.info("Cached token has expired or is about to expire.")
+                    return None
+
+                logger.info(f"Cached token is valid for another {seconds_left / 60:.2f} minutes")
+
+            return token_data
+
+        except Exception as e:
+            # A missing cache file is expected on first use; it is recognized by exception class name.
+            error_name = type(e).__name__
+            if "PersistenceNotFound" in error_name or "FileNotFoundError" in error_name:
+                logger.info("Cache file does not exist yet. Will be created on first token save.")
+            else:
+                logger.error(f"Failed to load cached token ({error_name}): {e}")
+            return None
+
+    def _save_token_to_cache(self, *, token: str, expires_in: Optional[int] = None) -> None:
+        """
+        Persist the token together with its expiry information, when known.
+
+        Args:
+            token (str): The access token to cache.
+            expires_in (Optional[int]): Token lifetime in seconds, if known.
+        """
+        # Take a single timestamp so that ``cached_at`` and ``expires_at`` are consistent with each other.
+        now = datetime.now(timezone.utc).timestamp()
+        token_data = {
+            "access_token": token,
+            "token_type": "Bearer",
+            "cached_at": now,
+        }
+
+        if expires_in:
+            token_data["expires_at"] = now + expires_in
+            token_data["expires_in"] = expires_in
+
+        try:
+            self._token_cache.save(json.dumps(token_data))
+            logger.info("Token successfully cached.")
+        except Exception as e:
+            logger.error(f"Failed to cache token: {e}")
+
+    def _clear_token_cache(self) -> None:
+        """Overwrite the cached token data with an empty JSON object."""
+        try:
+            self._token_cache.save(json.dumps({}))
+            logger.info("Token cache cleared.")
+        except Exception as e:
+            logger.error(f"Failed to clear cache: {e}")
+
+    async def refresh_token(self) -> str:
+        """
+        Refresh the authentication token asynchronously.
+
+        This will clear the existing token cache and fetch a new token with automated browser login.
+
+        Returns:
+            str: The refreshed authentication token.
+
+        Raises:
+            RuntimeError: If token refresh fails.
+        """
+        logger.info("Refreshing access token...")
+        self._clear_token_cache()
+        token = await self._fetch_access_token_with_playwright()
+
+        if not token:
+            raise RuntimeError("Failed to refresh access token.")
+
+        return token
+
+    async def get_token(self) -> str:
+        """
+        Get the current authentication token.
+
+        This will check the cache first and only launch the browser if no valid token is found.
+
+        Returns:
+            str: The current authentication token.
+
+        Raises:
+            RuntimeError: If token retrieval fails.
+        """
+        cached_token = self._get_cached_token_if_available_and_valid()
+        if cached_token and "access_token" in cached_token:
+            logger.info("Using cached access token.")
+            return cached_token["access_token"]
+
+        logger.info("No valid cached token found.")
+        return await self.refresh_token()
+
+    async def _fetch_access_token_with_playwright(self) -> Optional[str]:
+        """
+        Fetch access token using Playwright browser automation.
+
+        Raises:
+            RuntimeError: If Playwright is not installed.
+
+        Returns:
+            Optional[str]: The bearer token if successfully retrieved, else None.
+        """
+        try:
+            from playwright.async_api import async_playwright
+        except ImportError as e:
+            raise RuntimeError(
+                "Playwright is not installed. Please install it with 'pip install playwright'."
+            ) from e
+
+        bearer_token = None
+        token_expires_in = None
+
+        async with async_playwright() as playwright:
+            browser = None
+            context = None
+
+            try:
+                logger.info(f"Launching browser for authentication (headless={self._headless})...")
+                browser = await playwright.chromium.launch(
+                    headless=self._headless, args=["--start-maximized"] if self._maximized else []
+                )
+
+                context = await browser.new_context(no_viewport=True)
+                page = await context.new_page()
+
+                # response_handler >>>
+                async def response_handler(response):
+                    nonlocal bearer_token, token_expires_in
+
+                    try:
+                        url = response.url
+
+                        if "/oauth2/v2.0/token" in url:
+                            try:
+                                text = await response.text()
+
+                                if (
+                                    '"token_type":"Bearer"' in text or '"tokenType":"Bearer"' in text
+                                ) and "sydney" in text:
+                                    try:
+                                        data = json.loads(text)
+                                        if "access_token" in data:
+                                            bearer_token = data["access_token"]
+                                            token_expires_in = data.get("expires_in")
+
+                                    except json.JSONDecodeError:
+                                        logger.info("Response JSON decode failed, trying regex extraction...")
+
+                                        match = re.search(r'"access_token"\s*:\s*"([^"]+)"', text)
+                                        if match:
+                                            bearer_token = match.group(1)
+                                            logger.info("Captured bearer token using regex.")
+
+                                            expires_match = re.search(r'"expires_in"\s*:\s*(\d+)', text)
+                                            if expires_match:
+                                                token_expires_in = int(expires_match.group(1))
+                                        else:
+                                            logger.error("Failed to extract bearer token using regex.")
+
+                            except Exception as e:
+                                logger.error(f"Error reading response: {e}")
+
+                    except Exception as e:
+                        logger.error(f"Error handling response: {e}")
+
+                # ^^^ response_handler
+
+                page.on("response", response_handler)
+
+                logger.info("Navigating to Office.com for authentication...")
+                await page.goto("https://www.office.com/")
+
+                logger.info("Waiting for profile icon...")
+                await page.wait_for_selector("#mectrl_headerPicture", timeout=self._timeout)
+                await page.click("#mectrl_headerPicture")
+
+                logger.info("Waiting for email input...")
+                await page.wait_for_selector("#i0116", timeout=self._timeout)
+                await page.fill("#i0116", self._username)
+                await page.click("#idSIButton9")
+
+                logger.info("Waiting for password input...")
+                await page.wait_for_selector("#i0118", timeout=self._timeout)
+                await page.fill("#i0118", self._password)
+                await page.click("#idSIButton9")
+
+                logger.info("Waiting for 'Stay signed in?' prompt...")
+                await page.wait_for_selector("#idSIButton9", timeout=self._timeout)
+                logger.info("Clicking 'Yes' to stay signed in...")
+                await page.click("#idSIButton9")
+
+                logger.info("Successfully logged in.")
+                logger.info("Navigating to Copilot...")
+
+                logger.info("Waiting for Copilot button and clicking it...")
+                await page.wait_for_selector('div[aria-label="M365 Copilot"]', timeout=self._timeout)
+                await page.click('div[aria-label="M365 Copilot"]', timeout=self._timeout)
+
+                logger.info("Waiting 60 seconds for bearer token to be captured...")
+                for _ in range(60):
+                    if bearer_token:
+                        break
+                    await asyncio.sleep(1)
+
+                if bearer_token:
+                    logger.info(
+                        f"Bearer token successfully retrieved. Preview: {bearer_token[:16]}...{bearer_token[-16:]}"
+                    )
+                    self._save_token_to_cache(token=bearer_token, expires_in=token_expires_in)
+                else:
+                    logger.error("Failed to retrieve bearer token within 60 seconds.")
+
+                return bearer_token
+            except Exception as e:
+                logger.error("Failed to retrieve access token using Playwright.")
+
+                if str(e).startswith("BrowserType.launch"):
+                    logger.error("Playwright browser launch failed. Did you run 'playwright install chromium'?")
+                else:
+                    logger.error(f"Error details: {e}")
+
+                return None
+            finally:
+                logger.info("Gracefully closing Playwright browser instance...")
+
+                if context:
+                    await context.close()
+                if browser:
+                    await browser.close()
diff --git a/pyrit/common/path.py b/pyrit/common/path.py
index 40340f28a..14158d3ff 100644
--- a/pyrit/common/path.py
+++ b/pyrit/common/path.py
@@ -41,6 +41,10 @@ def in_git_repo() -> bool:
 DB_DATA_PATH = get_default_data_path("dbdata")
 DB_DATA_PATH.mkdir(parents=True, exist_ok=True)
 
+# Path to where cache files are stored, e.g. the token cache.
+PYRIT_CACHE_PATH = get_default_data_path(".pyrit_cache")
+PYRIT_CACHE_PATH.mkdir(parents=True, exist_ok=True)
+
 # Path to where the logs are located
 LOG_PATH = pathlib.Path(DB_DATA_PATH, "logs.txt").resolve()
 LOG_PATH.touch(exist_ok=True)
diff --git a/pyrit/prompt_target/__init__.py b/pyrit/prompt_target/__init__.py
index cdbbdb0ff..eee3ff6cb 100644
--- a/pyrit/prompt_target/__init__.py
+++ b/pyrit/prompt_target/__init__.py
@@ -37,6 +37,7 @@
 from pyrit.prompt_target.playwright_copilot_target import CopilotType, PlaywrightCopilotTarget
 from pyrit.prompt_target.prompt_shield_target import PromptShieldTarget
 from pyrit.prompt_target.text_target import TextTarget
+from pyrit.prompt_target.websocket_copilot_target import WebSocketCopilotTarget
 
 __all__ = [
     "AzureBlobStorageTarget",
@@ -66,4 +67,5 @@
     "PromptTarget",
     "RealtimeTarget",
     "TextTarget",
+    "WebSocketCopilotTarget",
 ]
diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py
new file mode 100644
index 000000000..58a380ecd
--- /dev/null
+++ b/pyrit/prompt_target/websocket_copilot_target.py
@@ -0,0 +1,394 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import asyncio
+import json
+import logging
+import os
+import uuid
+from enum import IntEnum
+from typing import Optional
+
+import websockets
+
+from pyrit.exceptions import (
+    EmptyResponseException,
+    pyrit_target_retry,
+)
+from pyrit.models import Message, construct_response_from_request
+from pyrit.prompt_target import PromptTarget, limit_requests_per_minute
+
+logger = logging.getLogger(__name__)
+
+# Useful links:
+# https://github.com/mbrg/power-pwn/blob/main/src/powerpwn/copilot/copilot_connector/copilot_connector.py
+# https://labs.zenity.io/p/access-copilot-m365-terminal
+
+
+class CopilotMessageType(IntEnum):
+    """Enumeration for Copilot WebSocket message types."""
+
+    UNKNOWN = -1
+    PARTIAL_RESPONSE = 1
+    FINAL_CONTENT = 2
+    STREAM_END = 3
+    USER_PROMPT = 4
+    PING = 6
+
+
+class WebSocketCopilotTarget(PromptTarget):
+    """
+    A WebSocket-based prompt target for Microsoft Copilot integration.
+
+    This target enables communication with Microsoft Copilot through a WebSocket connection.
+    Currently, authentication requires manually extracting a WebSocket URL from an active browser session.
+    In the future, more flexible authentication mechanisms will be added.
+
+    To obtain the WebSocket URL:
+        1. Ensure you are logged into Microsoft 365 with access to Copilot
+        2. Navigate to https://m365.cloud.microsoft/chat or open Copilot in https://teams.microsoft.com/v2
+        3. Open browser developer tools and switch to the Network tab
+        4. Begin typing or send a message to Copilot to establish the WebSocket connection
+        5. Search the network requests for "chathub", "conversation", or "access_token"
+        6. Identify the WebSocket connection (look for WS protocol) and copy its full URL
+
+    Warning:
+        All target instances using the same `WEBSOCKET_URL` will share a single conversation session.
+        Only works with licensed Microsoft 365 Copilot. The free Copilot version is not compatible.
+    """
+
+    # TODO: add more flexible auth, use puppeteer? https://github.com/mbrg/power-pwn/blob/main/src/powerpwn/copilot/copilot_connector/copilot_connector.py#L248
+
+    SUPPORTED_DATA_TYPES = {"text"}  # TODO: support more types?
+
+    RESPONSE_TIMEOUT_SECONDS: int = 60
+    CONNECTION_TIMEOUT_SECONDS: int = 30
+
+    def __init__(
+        self,
+        *,
+        verbose: bool = False,
+        max_requests_per_minute: Optional[int] = None,
+        model_name: str = "copilot",
+        response_timeout_seconds: int = RESPONSE_TIMEOUT_SECONDS,
+    ) -> None:
+        """
+        Initialize the WebSocketCopilotTarget.
+
+        Args:
+            verbose (bool): Enable verbose logging. Defaults to False.
+            max_requests_per_minute (int, Optional): Maximum number of requests per minute.
+            model_name (str): The model name. Defaults to "copilot".
+            response_timeout_seconds (int): Timeout for receiving responses in seconds. Defaults to 60s.
+
+        Raises:
+            ValueError: If WebSocket URL is not provided, is empty, or has invalid format.
+            ValueError: If required parameters are missing or empty in the WebSocket URL.
+        """
+        # Strip surrounding whitespace so that a padded value cannot leak into the URL or the parsed IDs.
+        self._websocket_url = (os.getenv("WEBSOCKET_URL") or "").strip()
+        if not self._websocket_url:
+            raise ValueError("WebSocket URL must be provided through the WEBSOCKET_URL environment variable")
+
+        if not self._websocket_url.startswith("wss://"):
+            raise ValueError(f"WebSocket URL must start with 'wss://'. Received: {self._websocket_url[:10]}")
+
+        if "ConversationId=" not in self._websocket_url:
+            raise ValueError("`ConversationId` parameter not found in WebSocket URL.")
+        self._conversation_id = self._websocket_url.split("ConversationId=")[1].split("&")[0]
+        if not self._conversation_id:
+            raise ValueError("`ConversationId` parameter is empty in WebSocket URL.")
+
+        if "X-SessionId=" not in self._websocket_url:
+            raise ValueError("`X-SessionId` parameter not found in WebSocket URL.")
+        self._session_id = self._websocket_url.split("X-SessionId=")[1].split("&")[0]
+        if not self._session_id:
+            raise ValueError("`X-SessionId` parameter is empty in WebSocket URL.")
+
+        super().__init__(
+            verbose=verbose,
+            max_requests_per_minute=max_requests_per_minute,
+            endpoint=self._websocket_url.split("?")[0],  # wss://substrate.office.com/m365Copilot/Chathub/...
+            model_name=model_name,
+        )
+
+        if response_timeout_seconds <= 0:
+            raise ValueError("response_timeout_seconds must be a positive integer.")
+        self._response_timeout_seconds = response_timeout_seconds
+
+        if self._verbose:
+            logger.info(f"WebSocketCopilotTarget initialized with conversation_id: {self._conversation_id}")
+            logger.info(f"Session ID: {self._session_id}")
+
+    @staticmethod
+    def _dict_to_websocket(data: dict) -> str:
+        # Produce the smallest possible JSON string, followed by record separator
+        return json.dumps(data, separators=(",", ":")) + "\x1e"
+
+    @staticmethod
+    def _parse_raw_message(message: str) -> list[tuple[CopilotMessageType, str]]:
+        """
+        Extract actionable content from a raw WebSocket message.
+        Returns more than one JSON message if multiple are found.
+
+        Malformed records (invalid JSON, JSON that is not an object, or an unrecognized ``type``) are
+        reported as ``CopilotMessageType.UNKNOWN`` instead of raising.
+
+        Args:
+            message (str): The raw WebSocket message string.
+
+        Returns:
+            list[tuple[CopilotMessageType, str]]: A list of tuples where each tuple contains
+                message type and extracted content.
+        """
+        results: list[tuple[CopilotMessageType, str]] = []
+
+        # https://github.com/dotnet/aspnetcore/blob/main/src/SignalR/docs/specs/HubProtocol.md#json-encoding
+        records = message.split("\x1e")  # record separator
+
+        for record in records:
+            if not record or not record.strip():
+                continue
+
+            try:
+                data = json.loads(record)
+            except json.JSONDecodeError as e:
+                logger.error(f"Failed to decode JSON message: {str(e)}")
+                results.append((CopilotMessageType.UNKNOWN, ""))
+                continue
+
+            if not isinstance(data, dict):
+                logger.error(f"Unexpected JSON payload of type {type(data).__name__}; expected an object.")
+                results.append((CopilotMessageType.UNKNOWN, ""))
+                continue
+
+            try:
+                msg_type = CopilotMessageType(data.get("type", -1))
+            except (ValueError, TypeError):
+                msg_type = CopilotMessageType.UNKNOWN
+
+            if msg_type != CopilotMessageType.FINAL_CONTENT:
+                # Every other message type carries no content that this target uses.
+                results.append((msg_type, ""))
+                continue
+
+            # Walk item -> result -> message defensively: any level may be missing or null.
+            item = data.get("item")
+            result = item.get("result") if isinstance(item, dict) else None
+            bot_text = result.get("message") if isinstance(result, dict) else None
+            if not isinstance(bot_text, str) or not bot_text:
+                # In this case, EmptyResponseException will be raised anyway
+                logger.warning("FINAL_CONTENT received but no parseable content found.")
+                logger.debug(f"Full raw message: {record}")
+                bot_text = ""
+            results.append((CopilotMessageType.FINAL_CONTENT, bot_text))
+
+        return results if results else [(CopilotMessageType.UNKNOWN, "")]
+
+    def _build_prompt_message(self, prompt: str) -> dict:
+        """Build the ``chat`` payload that carries the user prompt to Copilot."""
+        return {
+            "arguments": [
+                {
+                    "source": "officeweb",  # TODO: support 'teamshub' as well
+                    # TODO: not sure whether to uuid.uuid4() or use a static like it's done in power-pwn
+                    # https://github.com/mbrg/power-pwn/blob/main/src/powerpwn/copilot/copilot_connector/copilot_connector.py#L156
+                    "clientCorrelationId": str(uuid.uuid4()),
+                    "sessionId": self._session_id,
+                    "optionsSets": [
+                        "enterprise_flux_web",
+                        "enterprise_flux_work",
+                        "enable_request_response_interstitials",
+                        "enterprise_flux_image_v1",
+                        "enterprise_toolbox_with_skdsstore",
+                        "enterprise_toolbox_with_skdsstore_search_message_extensions",
+                        "enable_ME_auth_interstitial",
+                        "skdsstorethirdparty",
+                        "enable_confirmation_interstitial",
+                        "enable_plugin_auth_interstitial",
+                        "enable_response_action_processing",
+                        "enterprise_flux_work_gptv",
+                        "enterprise_flux_work_code_interpreter",
+                        "enable_batch_token_processing",
+                    ],
+                    "options": {},
+                    "allowedMessageTypes": [
+                        "Chat",
+                        "Suggestion",
+                        "InternalSearchQuery",
+                        "InternalSearchResult",
+                        "Disengaged",
+                        "InternalLoaderMessage",
+                        "RenderCardRequest",
+                        "AdsQuery",
+                        "SemanticSerp",
+                        "GenerateContentQuery",
+                        "SearchQuery",
+                        "ConfirmationCard",
+                        "AuthError",
+                        "DeveloperLogs",
+                    ],
+                    "sliceIds": [],
+                    # TODO: enable using agents https://github.com/mbrg/power-pwn/blob/main/src/powerpwn/copilot/copilot_connector/copilot_connector.py#L192
+                    "threadLevelGptId": {},
+                    "conversationId": self._conversation_id,
+                    "traceId": str(uuid.uuid4()).replace("-", ""),  # TODO: same case as clientCorrelationId
+                    "isStartOfSession": 0,
+                    "productThreadType": "Office",
+                    "clientInfo": {"clientPlatform": "web"},
+                    "message": {
+                        "author": "user",
+                        "inputMethod": "Keyboard",
+                        "text": prompt,
+                        "entityAnnotationTypes": ["People", "File", "Event", "Email", "TeamsMessage"],
+                        "requestId": str(uuid.uuid4()).replace("-", ""),
+                        "locationInfo": {"timeZoneOffset": 0, "timeZone": "UTC"},
+                        "locale": "en-US",
+                        "messageType": "Chat",
+                        "experienceType": "Default",
+                    },
+                    "plugins": [],  # TODO: support enabling some plugins?
+                }
+            ],
+            "invocationId": "0",  # TODO: should be dynamic?
+            "target": "chat",
+            "type": CopilotMessageType.USER_PROMPT,
+        }
+
+    async def _connect_and_send(self, prompt: str) -> str:
+        """
+        Open a WebSocket connection, send the handshake and the prompt, and wait for the final response.
+
+        Args:
+            prompt (str): The prompt text to send.
+
+        Returns:
+            str: The content of the FINAL_CONTENT message, or an empty string if none was received.
+
+        Raises:
+            TimeoutError: If no message arrives within the configured response timeout.
+            RuntimeError: If ``websocket.recv()`` returns None.
+        """
+        protocol_msg = {"protocol": "json", "version": 1}
+        prompt_dict = self._build_prompt_message(prompt)
+
+        inputs = [protocol_msg, prompt_dict]
+        last_response = ""
+
+        async with websockets.connect(
+            self._websocket_url,
+            open_timeout=self.CONNECTION_TIMEOUT_SECONDS,
+            close_timeout=self.CONNECTION_TIMEOUT_SECONDS,
+        ) as websocket:
+            for input_msg in inputs:
+                payload = self._dict_to_websocket(input_msg)
+                await websocket.send(payload)
+
+            stop_polling = False
+            while not stop_polling:
+                try:
+                    response = await asyncio.wait_for(
+                        websocket.recv(),
+                        timeout=self._response_timeout_seconds,
+                    )
+                except asyncio.TimeoutError:
+                    raise TimeoutError(
+                        f"Timed out waiting for Copilot response after {self._response_timeout_seconds} seconds."
+                    )
+
+                if response is None:
+                    raise RuntimeError(
+                        "WebSocket connection closed unexpectedly: received None from websocket.recv()"
+                    )
+
+                if isinstance(response, bytes):
+                    # Binary frames are delivered as bytes; the parser below works on text.
+                    response = response.decode("utf-8")
+
+                parsed_messages = self._parse_raw_message(response)
+
+                for msg_type, content in parsed_messages:
+                    if msg_type in (
+                        CopilotMessageType.UNKNOWN,
+                        CopilotMessageType.FINAL_CONTENT,
+                        CopilotMessageType.STREAM_END,
+                    ):
+                        stop_polling = True
+
+                        if msg_type == CopilotMessageType.FINAL_CONTENT:
+                            last_response = content
+                        elif msg_type == CopilotMessageType.UNKNOWN:
+                            logger.debug("Received unknown or empty message type.")
+
+        return last_response
+
+    def _validate_request(self, *, message: Message) -> None:
+        """
+        Ensure the message consists of exactly one piece with a supported data type.
+
+        Args:
+            message (Message): The message to validate.
+
+        Raises:
+            ValueError: If the message does not contain exactly one piece, or the piece is not text.
+        """
+        n_pieces = len(message.message_pieces)
+        if n_pieces != 1:
+            raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.")
+
+        piece_type = message.message_pieces[0].converted_value_data_type
+        if piece_type not in self.SUPPORTED_DATA_TYPES:
+            raise ValueError(f"This target only supports text prompt input. Received: {piece_type}.")
+
+    @limit_requests_per_minute
+    @pyrit_target_retry
+    async def send_prompt_async(self, *, message: Message) -> list[Message]:
+        """
+        Asynchronously send a message to Microsoft Copilot using WebSocket.
+
+        Args:
+            message (Message): A message to be sent to the target.
+
+        Returns:
+            list[Message]: A list containing the response from Copilot.
+
+        Raises:
+            EmptyResponseException: If the response from Copilot is empty.
+            InvalidStatus: If the WebSocket handshake fails with an HTTP status error.
+            RuntimeError: If any other error occurs during WebSocket communication.
+        """
+        self._validate_request(message=message)
+        request_piece = message.message_pieces[0]
+
+        logger.info(f"Sending the following prompt to WebSocketCopilotTarget: {request_piece}")
+
+        try:
+            prompt_text = request_piece.converted_value
+            response_text = await self._connect_and_send(prompt_text)
+
+            if not response_text or not response_text.strip():
+                logger.error("Empty response received from Copilot.")
+                raise EmptyResponseException(message="Copilot returned an empty response.")
+            logger.info(f"Received the following response from WebSocketCopilotTarget: {response_text[:100]}...")
+
+            response_entry = construct_response_from_request(
+                request=request_piece, response_text_pieces=[response_text]
+            )
+
+            return [response_entry]
+
+        except EmptyResponseException:
+            # Documented in "Raises": propagate unchanged rather than letting the generic handler below
+            # wrap it in RuntimeError (NOTE(review): @pyrit_target_retry presumably keys on this type - confirm).
+            raise
+
+        except websockets.exceptions.InvalidStatus as e:
+            logger.error(
+                f"WebSocket connection failed: {str(e)}\n"
+                "Ensure the WEBSOCKET_URL environment variable is correct and valid."
+                " For more details about authentication, refer to the class documentation."
+            )
+            raise
+
+        except Exception as e:
+            raise RuntimeError(f"An error occurred during WebSocket communication: {str(e)}") from e
diff --git a/tests/unit/target/test_websocket_copilot_target.py b/tests/unit/target/test_websocket_copilot_target.py
new file mode 100644
index 000000000..0b4e713e1
--- /dev/null
+++ b/tests/unit/target/test_websocket_copilot_target.py
@@ -0,0 +1,162 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import os
+from unittest.mock import patch
+
+import pytest
+
+from pyrit.prompt_target import WebSocketCopilotTarget
+from pyrit.prompt_target.websocket_copilot_target import CopilotMessageType
+
+
+VALID_WEBSOCKET_URL = (
+    "wss://substrate.office.com/m365Copilot/Chathub/test_object_id@test_tenant_id"
+    "?ClientRequestId=test_client_request_id"
+    "&X-SessionId=test_session_id&token=abc123"
+    "&ConversationId=test_conversation_id"
+    "&access_token=test_access_token"
+    # "&variants=feature.test_feature_one,feature.test_feature_two"
+    # "&agent=web"
+    # "&scenario=OfficeWebIncludedCopilot"
+)
+
+
+@pytest.fixture
+def mock_env_websocket_url():
+    with patch.dict(os.environ, {"WEBSOCKET_URL": VALID_WEBSOCKET_URL}):
+        yield
+
+
+@pytest.mark.usefixtures("patch_central_database")
+class TestWebSocketCopilotTargetInit:
+    def test_init_with_valid_wss_url(self, mock_env_websocket_url):
+        target = WebSocketCopilotTarget()
+
+        assert target._websocket_url == VALID_WEBSOCKET_URL
+        assert target._conversation_id == "test_conversation_id"
+        assert target._session_id == "test_session_id"
+        assert target._model_name == "copilot"
+
+    def test_init_with_missing_or_invalid_wss_url(self):
+        for env_vars in [{}, {"WEBSOCKET_URL": ""}, {"WEBSOCKET_URL": " "}]:
+            with patch.dict(os.environ, env_vars, clear=True):
+                with pytest.raises(ValueError, match="WebSocket URL must be provided"):
+                    WebSocketCopilotTarget()
+
+        for invalid_url in ["invalid_websocket_url", "ws://example.com", "https://example.com"]:
+            with patch.dict(os.environ, {"WEBSOCKET_URL": invalid_url}, clear=True):
+                with pytest.raises(ValueError, match="WebSocket URL must start with 'wss://'"):
+                    WebSocketCopilotTarget()
+
+    def test_init_with_missing_or_empty_required_params(self):
+        urls = [
+            ("wss://example.com/?X-SessionId=session123", "`ConversationId` parameter not found"),
+            ("wss://example.com/?ConversationId=conv123", "`X-SessionId` parameter not found"),
+            ("wss://example.com/?ConversationId=&X-SessionId=session123", "`ConversationId` parameter is empty"),
+            ("wss://example.com/?ConversationId=conv123&X-SessionId=", "`X-SessionId` parameter is empty"),
+        ]
+
+        for url, error_msg in urls:
+            with patch.dict(os.environ, {"WEBSOCKET_URL": url}, clear=True):
+                with pytest.raises(ValueError, match=error_msg):
+                    WebSocketCopilotTarget()
+
+    def test_init_sets_endpoint_correctly(self, mock_env_websocket_url):
+        target = WebSocketCopilotTarget()
+        assert target._endpoint == "wss://substrate.office.com/m365Copilot/Chathub/test_object_id@test_tenant_id"
+
+    def test_init_with_custom_response_timeout(self, mock_env_websocket_url):
+        target = WebSocketCopilotTarget(response_timeout_seconds=120)
+        assert target._response_timeout_seconds == 120
+
+        for invalid_timeout in [0, -10]:
+            with pytest.raises(ValueError, match="response_timeout_seconds must be a positive integer."):
+                WebSocketCopilotTarget(response_timeout_seconds=invalid_timeout)
+
+
+@pytest.mark.parametrize(
+    "data,expected",
+    [
+        ({"key": "value"}, '{"key":"value"}\x1e'),
+        ({"protocol": "json", "version": 1}, '{"protocol":"json","version":1}\x1e'),
+        ({"outer": {"inner": "value"}}, '{"outer":{"inner":"value"}}\x1e'),
+        ({"items": [1, 2, 3]}, '{"items":[1,2,3]}\x1e'),
+    ],
+)
+def test_dict_to_websocket_static_method(data, expected):
+    result = WebSocketCopilotTarget._dict_to_websocket(data)
+    assert result == expected
+
+
+class TestParseRawMessage:
+    @pytest.mark.parametrize(
+        "message,expected_types,expected_content",
+        [
+            ("", [CopilotMessageType.UNKNOWN], [""]),
+            (" \n\t ", [CopilotMessageType.UNKNOWN], [""]),
+            ("{}\x1e", [CopilotMessageType.UNKNOWN], [""]),
+            ('{"type":6}\x1e', [CopilotMessageType.PING], [""]),
+            (
+                '{"type":1,"target":"update","arguments":[{"messages":[{"text":"Partial","author":"bot"}]}]}\x1e',
+                [CopilotMessageType.PARTIAL_RESPONSE],
+                [""],
+            ),
+            (
+                '{"type":2,"item":{"result":{"message":"Final."}}}\x1e{"type":3,"invocationId":"0"}\x1e',
+                [CopilotMessageType.FINAL_CONTENT, CopilotMessageType.STREAM_END],
+                [
+                    "Final.",
+                    "",
+                ],
+            ),
+        ],
+    )
+    def test_parse_raw_message_with_valid_data(self, message, expected_types, expected_content):
+        result = WebSocketCopilotTarget._parse_raw_message(message)
+
+        assert len(result) == len(expected_types)
+        for i, expected_type in enumerate(expected_types):
+            assert result[i][0] == expected_type
+            assert result[i][1] == expected_content[i]
+
+    def test_parse_final_message_without_content(self):
+        with patch("pyrit.prompt_target.websocket_copilot_target.logger") as mock_logger:
+            message = '{"type":2,"invocationId":"0"}\x1e'
+            result = WebSocketCopilotTarget._parse_raw_message(message)
+
+            assert len(result) == 1
+            assert result[0][0] == CopilotMessageType.FINAL_CONTENT
+            assert result[0][1] == ""
+
+            mock_logger.warning.assert_called_with("FINAL_CONTENT received but no parseable content found.")
+            mock_logger.debug.assert_called_with(f"Full raw message: {message[:-1]}")
+
+    @pytest.mark.parametrize(
+        "message",
+        [
+            '{"type":2,"item":null}\x1e',
+            '{"type":2,"item":{"result":null}}\x1e',
+            '{"type":2,"item":{"result":{"message":null}}}\x1e',
+        ],
+    )
+    def test_parse_final_message_with_malformed_item(self, message):
+        result = WebSocketCopilotTarget._parse_raw_message(message)
+        assert result == [(CopilotMessageType.FINAL_CONTENT, "")]
+
+    @pytest.mark.parametrize(
+        "message",
+        [
+            '{"type":99,"data":"unknown"}\x1e',
+            '{"data":"no type field"}\x1e',
+            '{"invalid json structure\x1e',
+            '[1,2,3]\x1e',
+            '"just a string"\x1e',
+            '{"type":[1]}\x1e',
+        ],
+    )
+    def test_parse_unknown_or_invalid_messages(self, message):
+        result = WebSocketCopilotTarget._parse_raw_message(message)
+        assert len(result) == 1
+        assert result[0][0] == CopilotMessageType.UNKNOWN
+        assert result[0][1] == ""
diff --git a/websocket_copilot_simple_example.py b/websocket_copilot_simple_example.py
new file mode 100644
index 000000000..a1e13831a
--- /dev/null
+++ b/websocket_copilot_simple_example.py
@@ -0,0 +1,31 @@
+"""
+# TODO
+THIS WILL BE REMOVED after proper unit tests are in place :)
+"""
+
+import asyncio
+
+from pyrit.models import Message, MessagePiece
+from pyrit.prompt_target import WebSocketCopilotTarget
+from pyrit.setup import IN_MEMORY, initialize_pyrit_async
+
+
+async def main() -> None:
+    await initialize_pyrit_async(memory_db_type=IN_MEMORY)
+    target = WebSocketCopilotTarget()
+
+    message_piece = MessagePiece(
+        role="user",
+        original_value="say only one random word",
+        original_value_data_type="text",
+        converted_value_data_type="text",
+    )
+    message = Message(message_pieces=[message_piece])
+
+    responses = await target.send_prompt_async(message=message)
+    for response in responses:
+        print(response.get_value())
+
+
+if __name__ == "__main__":
+    asyncio.run(main())