diff --git a/src/vm-repair/HISTORY.rst b/src/vm-repair/HISTORY.rst index 1a331136141..132995054b2 100644 --- a/src/vm-repair/HISTORY.rst +++ b/src/vm-repair/HISTORY.rst @@ -2,6 +2,13 @@ Release History =============== +2.2.1 +++++++ +Migrating telemetry from Application Insights SDK to azure.cli.core telemetry pipeline +Adding PII scrubbing for error messages and stack traces +Fixing if/elif bug in command_helper_class.py telemetry dispatch +Removing unused opencensus dependency + 2.2.0 ++++++ Adding `--tags` parameter to `vm repair create` and `vm repair repair-and-restore` commands to allow users to tag the repair VM for organizational requirements diff --git a/src/vm-repair/azext_vm_repair/command_helper_class.py b/src/vm-repair/azext_vm_repair/command_helper_class.py index 05945c1fc9f..7354a96db7b 100644 --- a/src/vm-repair/azext_vm_repair/command_helper_class.py +++ b/src/vm-repair/azext_vm_repair/command_helper_class.py @@ -10,9 +10,7 @@ import inspect from knack.log import get_logger -from azure.cli.core.commands.client_factory import get_subscription_id - -from .telemetry import _track_command_telemetry, _track_run_command_telemetry, _track_command_telemetry_repair_and_restore +from .telemetry import _track_command_telemetry, _track_run_command_telemetry, _track_command_telemetry_repair_and_restore, _generate_user_hash from .repair_utils import _get_function_param_dict @@ -65,9 +63,18 @@ def __init__(self, logger, cmd, command_name): # Error stack trace self.error_stack_trace = '' + # Exception type for telemetry (e.g. 'SkuNotAvailableError') + self.exception_type = '' + # Return dict self.return_dict = {} + # Extra telemetry context (VM properties, feature flags, etc.) + self.telemetry_context = {} + + # Pseudonymous caller hash + self.telemetry_context['UserHash'] = _generate_user_hash(cmd) + # Verbose flag for command self.is_verbose = any(handler.level == logging.INFO for handler in get_logger().handlers) @@ -87,12 +94,17 @@ def __del__(self): self.cmd.cli_ctx.get_progress_controller().end() # Track telemetry data elapsed_time = timeit.default_timer() - self.start_time - if self.command_name == VM_REPAIR_RUN_COMMAND: - _track_run_command_telemetry(self.logger, self.command_name, self.command_params, self.status, self.message, self.error_message, self.error_stack_trace, elapsed_time, get_subscription_id(self.cmd.cli_ctx), self.return_dict, self.script.run_id, self.script.status, self.script.output, self.script.run_time) - if self.command_name == VM_REPAIR_AND_RESTORE_COMMAND: - _track_command_telemetry_repair_and_restore(self.logger, self.command_name, self.status, self.message, self.error_message, self.error_stack_trace, elapsed_time, get_subscription_id(self.cmd.cli_ctx)) - else: - _track_command_telemetry(self.logger, self.command_name, self.command_params, self.status, self.message, self.error_message, self.error_stack_trace, elapsed_time, get_subscription_id(self.cmd.cli_ctx), self.return_dict) + if self.exception_type: + self.telemetry_context['ExceptionType'] = self.exception_type + try: + if self.command_name == VM_REPAIR_RUN_COMMAND: + _track_run_command_telemetry(self.logger, self.command_name, self.command_params, self.status, self.message, self.error_message, self.error_stack_trace, elapsed_time, self.return_dict, self.script.run_id, self.script.status, self.script.output, self.script.run_time, context=self.telemetry_context) + elif self.command_name == VM_REPAIR_AND_RESTORE_COMMAND: + _track_command_telemetry_repair_and_restore(self.logger, self.command_name, self.status, self.message, self.error_message, self.error_stack_trace, elapsed_time, context=self.telemetry_context) + else: + _track_command_telemetry(self.logger, self.command_name, self.command_params, self.status, self.message, self.error_message, self.error_stack_trace, elapsed_time, self.return_dict, context=self.telemetry_context) + except Exception: # pylint: disable=broad-except + self.logger.debug('Failed to send telemetry for %s', self.command_name) def set_status_success(self): """ Set command status to success """ diff --git a/src/vm-repair/azext_vm_repair/custom.py b/src/vm-repair/azext_vm_repair/custom.py index c8147a5547b..2ab604ef756 100644 --- a/src/vm-repair/azext_vm_repair/custom.py +++ b/src/vm-repair/azext_vm_repair/custom.py @@ -144,6 +144,17 @@ def create(cmd, vm_name, resource_group_name, repair_password=None, repair_usern # Fetching the architecture of the source VM. architecture_type = _fetch_architecture(source_vm) + # Enrich telemetry with VM context + command.telemetry_context.update({ + 'OsType': 'Linux' if is_linux else 'Windows', + 'HyperVGeneration': str(vm_hypervgen), + 'Architecture': str(architecture_type), + 'IsManagedDisk': str(is_managed), + 'IsEncrypted': str(bool(encrypt_recovery_key)), + 'EnableNested': str(bool(enable_nested)), + 'AssociatePublicIp': str(bool(associate_public_ip)), + }) + # Checking if the source VM's OS is Linux and if it uses a managed disk. if is_linux and _uses_managed_disk(source_vm): # Setting the OS type to 'Linux'. @@ -422,34 +433,42 @@ def create(cmd, vm_name, resource_group_name, repair_password=None, repair_usern command.error_stack_trace = traceback.format_exc() command.error_message = "Command interrupted by user input." command.message = "Command interrupted by user input. Cleaning up resources." + command.exception_type = 'KeyboardInterrupt' except AzCommandError as azCommandError: command.error_stack_trace = traceback.format_exc() command.error_message = str(azCommandError) command.message = "Repair create failed. Cleaning up created resources." + command.exception_type = 'AzCommandError' except SkuDoesNotSupportHyperV as skuDoesNotSupportHyperV: command.error_stack_trace = traceback.format_exc() command.error_message = str(skuDoesNotSupportHyperV) command.message = "provided sku does not support nested VM in hyperv. Please run command without --enabled-nested or provide a valid --size parameter. Cleaning up created resources." + command.exception_type = 'SkuDoesNotSupportHyperV' except ScriptReturnsError as scriptReturnsError: command.error_stack_trace = traceback.format_exc() command.error_message = str(scriptReturnsError) command.message = "Error returned from script when enabling hyperv." + command.exception_type = 'ScriptReturnsError' except SkuNotAvailableError as skuNotAvailableError: command.error_stack_trace = traceback.format_exc() command.error_message = str(skuNotAvailableError) command.message = "Please check if the current subscription can create more VM resources. Cleaning up created resources." + command.exception_type = 'SkuNotAvailableError' except UnmanagedDiskCopyError as unmanagedDiskCopyError: command.error_stack_trace = traceback.format_exc() command.error_message = str(unmanagedDiskCopyError) command.message = "Repair create failed. Please try again at another time. Cleaning up created resources." + command.exception_type = 'UnmanagedDiskCopyError' except WindowsOsNotAvailableError: command.error_stack_trace = traceback.format_exc() command.error_message = 'Compatible Windows OS image not available.' command.message = 'A compatible Windows OS image is not available at this time, please check subscription.' + command.exception_type = 'WindowsOsNotAvailableError' except Exception as exception: command.error_stack_trace = traceback.format_exc() command.error_message = str(exception) command.message = 'An unexpected error occurred. Try running again with the --debug flag to debug.' + command.exception_type = type(exception).__name__ finally: if command.error_stack_trace: @@ -546,16 +565,19 @@ def restore(cmd, vm_name, resource_group_name, disk_name=None, repair_vm_id=None command.error_stack_trace = traceback.format_exc() command.error_message = "Command interrupted by user input." command.message = "Command interrupted by user input. If the restore command fails at retry, please rerun the repair process from \'az vm repair create\'." + command.exception_type = 'KeyboardInterrupt' except AzCommandError as azCommandError: # Capture the stack trace and set the error message if an Azure command error occurs command.error_stack_trace = traceback.format_exc() command.error_message = str(azCommandError) command.message = "Repair restore failed. If the restore command fails at retry, please rerun the repair process from \'az vm repair create\'." + command.exception_type = 'AzCommandError' except Exception as exception: # Capture the stack trace and set the error message if an unexpected error occurs command.error_stack_trace = traceback.format_exc() command.error_message = str(exception) command.message = 'An unexpected error occurred. Try running again with the --debug flag to debug.' + command.exception_type = type(exception).__name__ finally: # Log the stack trace if an error has occurred if command.error_stack_trace: @@ -707,22 +729,27 @@ def run(cmd, vm_name, resource_group_name, run_id=None, repair_vm_id=None, custo command.error_stack_trace = traceback.format_exc() command.error_message = "Command interrupted by user input." command.message = "Repair run failed. Command interrupted by user input." + command.exception_type = 'KeyboardInterrupt' except AzCommandError as azCommandError: command.error_stack_trace = traceback.format_exc() command.error_message = str(azCommandError) command.message = "Repair run failed." + command.exception_type = 'AzCommandError' except requests.exceptions.RequestException as exception: command.error_stack_trace = traceback.format_exc() command.error_message = str(exception) command.message = "Failed to fetch run script data from GitHub. Please check this repository is reachable: https://github.com/Azure/repair-script-library" + command.exception_type = 'RequestException' except RunScriptNotFoundForIdError as exception: command.error_stack_trace = traceback.format_exc() command.error_message = str(exception) command.message = "Repair run failed. Run ID not found." + command.exception_type = 'RunScriptNotFoundForIdError' except Exception as exception: command.error_stack_trace = traceback.format_exc() command.error_message = str(exception) command.message = 'An unexpected error occurred. Try running again with the --debug flag to debug.' + command.exception_type = type(exception).__name__ finally: if command.error_stack_trace: logger.debug(command.error_stack_trace) @@ -918,26 +945,31 @@ def reset_nic(cmd, vm_name, resource_group_name, yes=False): command.error_stack_trace = traceback.format_exc() command.error_message = "Command interrupted by user input." command.message = "Command interrupted by user input." + command.exception_type = 'KeyboardInterrupt' except AzCommandError as azCommandError: command.set_status_error() command.error_stack_trace = traceback.format_exc() command.error_message = str(azCommandError) command.message = "Reset NIC failed." + command.exception_type = 'AzCommandError' except SupportingResourceNotFoundError as resourceError: command.set_status_error() command.error_stack_trace = traceback.format_exc() command.error_message = str(resourceError) command.message = "Reset NIC could not be initiated." + command.exception_type = 'SupportingResourceNotFoundError' except CommandCanceledByUserError as canceledError: command.set_status_error() command.error_stack_trace = traceback.format_exc() command.error_message = str(canceledError) command.message = VM_OFF_MESSAGE + command.exception_type = 'CommandCanceledByUserError' except Exception as exception: command.set_status_error() command.error_stack_trace = traceback.format_exc() command.error_message = str(exception) command.message = 'An unexpected error occurred. Try running again with the --debug flag to debug.' + command.exception_type = type(exception).__name__ else: command.set_status_success() command.message = 'VM guest NIC reset complete. The VM is in running state.' @@ -1017,6 +1049,7 @@ def repair_and_restore(cmd, vm_name, resource_group_name, repair_password=None, command.error_stack_trace = traceback.format_exc() command.error_message = "Command failed when running fstab script." command.message = "Command failed when running fstab script." + command.exception_type = 'FstabScriptError' # If the resource group existed before, confirm before cleaning up resources # Otherwise, clean up resources without confirmation @@ -1117,6 +1150,7 @@ def repair_button(cmd, vm_name, resource_group_name, button_command, repair_pass command.error_stack_trace = traceback.format_exc() command.error_message = "Command failed when running script." command.message = "Command failed when running script." + command.exception_type = 'ButtonScriptError' if existing_rg: _clean_up_resources(repair_group_name, confirm=True) else: diff --git a/src/vm-repair/azext_vm_repair/telemetry.py b/src/vm-repair/azext_vm_repair/telemetry.py index b765863dbbc..c0d1aacf519 100644 --- a/src/vm-repair/azext_vm_repair/telemetry.py +++ b/src/vm-repair/azext_vm_repair/telemetry.py @@ -3,72 +3,133 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -# pylint: disable=line-too-long, import-error, broad-except +# pylint: disable=line-too-long +import hashlib import json -from applicationinsights import TelemetryClient -from .repair_utils import _get_current_vmrepair_version +import re -# For test releases and testing -TEST_KEY = 'a6bdff92-33b5-426f-9123-33875d0ae98b' -PROD_KEY = '3e7130f2-759b-41d4-afb8-f1405d1d7ed9' +from azure.cli.core import telemetry as telemetry_core -tc = TelemetryClient(PROD_KEY) -tc.context.application.ver = _get_current_vmrepair_version() +EXTENSION_NAME = 'vm-repair' +# Patterns for scrubbing PII from error messages and stack traces +_EMAIL_RE = re.compile(r'[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+') +_HOME_DIR_RE = re.compile(r'(?:/(?:home|Users)/|[A-Za-z]:\\Users\\)[^\s/\\]+', re.IGNORECASE) -def _track_command_telemetry(logger, command_name, parameters, status, message, error_message, error_stack_trace, duration, subscription_id, result_json): - try: - properties = { - 'command_name': command_name, - 'parameters': json.dumps(parameters), - 'command_status': status, - 'message': message, - 'error_message': error_message, - 'error_stack_trace': error_stack_trace, - 'subscription_id': subscription_id, - 'result_json': json.dumps(result_json) - } - measurements = {'command_duration': duration} - tc.track_event(command_name, properties, measurements) - tc.flush() - except Exception as exception: - logger.error('Unexpected error sending telemetry with exception: %s', str(exception)) - - -def _track_run_command_telemetry(logger, command_name, parameters, status, message, error_message, error_stack_trace, duration, subscription_id, result_json, script_run_id, script_status, script_output, script_duration): - try: - properties = { - 'command_name': command_name, - 'parameters': json.dumps(parameters), - 'command_status': status, - 'message': message, - 'error_message': error_message, - 'error_stack_trace': error_stack_trace, - 'subscription_id': subscription_id, - 'result_json': json.dumps(result_json), - 'script_run_id': script_run_id, - 'script_status': script_status, - 'script_output': script_output - } - measurements = {'command_duration': duration, 'script_duration': script_duration} - tc.track_event(command_name, properties, measurements) - tc.flush() - except Exception as exception: - logger.error('Unexpected error sending telemetry with exception: %s', str(exception)) - - -def _track_command_telemetry_repair_and_restore(logger, command_name, status, message, error_message, error_stack_trace, duration, subscription_id): + +def _scrub_pii(value): + """Remove emails and home directory usernames from a string.""" + if not isinstance(value, str) or not value: + return value + value = _EMAIL_RE.sub('[REDACTED_EMAIL]', value) + value = _HOME_DIR_RE.sub('[REDACTED_PATH]', value) + return value + + +def _hash_value(value): + """One-way hash a single value. + + Returns None and empty string unchanged. Any other value is converted to a + string, SHA-256 hashed, and truncated to a 16-character hexadecimal digest. + """ + if value is None or value == '': + return value + return hashlib.sha256(str(value).encode('utf-8')).hexdigest()[:16] + + +# Parameter keys whose values are Azure resource identifiers and should be +# hashed rather than sent in cleartext. Keys not in this set are kept as-is +# (booleans, enums, flags, etc.). +_RESOURCE_ID_KEYS = frozenset([ + 'vm_name', 'resource_group_name', 'subscription_id', + 'repair_vm_name', 'repair_group_name', 'copy_disk_name', + 'disk_name', 'repair_vm_id', +]) + + +def _hash_resource_params(parameters): + """Return a copy of *parameters* with resource identifiers hashed.""" + if not isinstance(parameters, dict): + return parameters + sanitized = {} + for key, value in parameters.items(): + if key in _RESOURCE_ID_KEYS and value and value != '********': + sanitized[key] = _hash_value(value) + else: + sanitized[key] = value + return sanitized + + +def _hash_result_json(result_json): + """Return a copy of *result_json* with known resource fields hashed.""" + if not isinstance(result_json, dict): + return result_json + _result_resource_keys = { + 'repair_vm_name', 'copied_disk_name', 'copied_disk_uri', + 'repair_resource_group', 'repair_vm_id', + } + sanitized = {} + for key, value in result_json.items(): + if key in _result_resource_keys and isinstance(value, str) and value: + sanitized[key] = _hash_value(value) + else: + sanitized[key] = value + return sanitized + + +def _generate_user_hash(cmd): + """Generate a one-way pseudonymous identifier for the current caller. + + Combines subscription_id + user principal with a namespace prefix, then + SHA-256 hashes it. The 16-hex-char result is deterministic (same + user = same hash across sessions) but irreversible. + """ try: - properties = { - 'command_name': command_name, - 'command_status': status, - 'message': message, - 'error_message': error_message, - 'error_stack_trace': error_stack_trace, - 'subscription_id': subscription_id - } - measurements = {'command_duration': duration} - tc.track_event(command_name, properties, measurements) - tc.flush() - except Exception as exception: - logger.error('Unexpected error sending telemetry with exception: %s', str(exception)) + from azure.cli.core._profile import Profile + profile = Profile(cli_ctx=cmd.cli_ctx) + account = profile.get_subscription() + user_identity = account.get('user', {}).get('name', '') + sub_id = account.get('id', '') + except Exception: # pylint: disable=broad-except + return 'unknown' + raw = f"vmrepair:{sub_id}:{user_identity}" + return hashlib.sha256(raw.encode('utf-8')).hexdigest()[:16] + + +def _build_base_properties(command_name, status, message, error_message, error_stack_trace, duration, context=None): + """Build the common property dict shared by all telemetry functions.""" + props = { + 'Context.Default.AzureCLI.VmRepairCommandName': command_name, + 'Context.Default.AzureCLI.VmRepairStatus': status, + 'Context.Default.AzureCLI.VmRepairMessage': _scrub_pii(message), + 'Context.Default.AzureCLI.VmRepairErrorMessage': _scrub_pii(error_message), + 'Context.Default.AzureCLI.VmRepairErrorStackTrace': _scrub_pii(error_stack_trace), + 'Context.Default.AzureCLI.VmRepairCommandDuration': duration + } + if context: + for key, value in context.items(): + props[f'Context.Default.AzureCLI.VmRepair{key}'] = value + return props + + +def _track_command_telemetry(logger, command_name, parameters, status, message, error_message, error_stack_trace, duration, result_json, context=None): # pylint: disable=unused-argument + properties = _build_base_properties(command_name, status, message, error_message, error_stack_trace, duration, context) + properties['Context.Default.AzureCLI.VmRepairParameters'] = json.dumps(_hash_resource_params(parameters)) + properties['Context.Default.AzureCLI.VmRepairResultJson'] = json.dumps(_hash_result_json(result_json)) + telemetry_core.add_extension_event(EXTENSION_NAME, properties) + + +def _track_run_command_telemetry(logger, command_name, parameters, status, message, error_message, error_stack_trace, duration, result_json, script_run_id, script_status, script_output, script_duration, context=None): # pylint: disable=unused-argument + properties = _build_base_properties(command_name, status, message, error_message, error_stack_trace, duration, context) + properties['Context.Default.AzureCLI.VmRepairParameters'] = json.dumps(_hash_resource_params(parameters)) + properties['Context.Default.AzureCLI.VmRepairResultJson'] = json.dumps(_hash_result_json(result_json)) + properties['Context.Default.AzureCLI.VmRepairScriptRunId'] = script_run_id + properties['Context.Default.AzureCLI.VmRepairScriptStatus'] = script_status + properties['Context.Default.AzureCLI.VmRepairScriptOutput'] = _scrub_pii(script_output) + properties['Context.Default.AzureCLI.VmRepairScriptDuration'] = script_duration + telemetry_core.add_extension_event(EXTENSION_NAME, properties) + + +def _track_command_telemetry_repair_and_restore(logger, command_name, status, message, error_message, error_stack_trace, duration, context=None): # pylint: disable=unused-argument + properties = _build_base_properties(command_name, status, message, error_message, error_stack_trace, duration, context) + telemetry_core.add_extension_event(EXTENSION_NAME, properties) diff --git a/src/vm-repair/azext_vm_repair/tests/latest/test_telemetry.py b/src/vm-repair/azext_vm_repair/tests/latest/test_telemetry.py new file mode 100644 index 00000000000..1c13df8c348 --- /dev/null +++ b/src/vm-repair/azext_vm_repair/tests/latest/test_telemetry.py @@ -0,0 +1,340 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import unittest +from unittest.mock import patch, MagicMock + +from azext_vm_repair.telemetry import ( + _scrub_pii, + _hash_value, + _hash_resource_params, + _hash_result_json, + _generate_user_hash, + _build_base_properties, + _track_command_telemetry, + _track_run_command_telemetry, + _track_command_telemetry_repair_and_restore, +) + + +class TestScrubPii(unittest.TestCase): + """Tests for the _scrub_pii helper in telemetry.py.""" + + def test_none_returns_none(self): + assert _scrub_pii(None) is None + + def test_empty_string_returns_empty(self): + assert _scrub_pii('') == '' + + def test_non_string_passthrough(self): + assert _scrub_pii(42) == 42 + assert _scrub_pii(3.14) == 3.14 + + def test_email_redacted(self): + result = _scrub_pii('Login failed for user@example.com') + assert '[REDACTED_EMAIL]' in result + assert 'user@example.com' not in result + + def test_multiple_emails_redacted(self): + result = _scrub_pii('From a@b.com to c@d.org') + assert result.count('[REDACTED_EMAIL]') == 2 + assert 'a@b.com' not in result + assert 'c@d.org' not in result + + def test_linux_home_path_redacted(self): + result = _scrub_pii('File not found: /home/johndoe/script.sh') + assert '[REDACTED_PATH]' in result + assert 'johndoe' not in result + + def test_windows_home_path_redacted(self): + result = _scrub_pii('File not found: C:\\Users\\johndoe') + assert '[REDACTED_PATH]' in result + assert 'johndoe' not in result + + def test_macos_home_path_redacted(self): + result = _scrub_pii('Error at /Users/alice/script.py') + assert '[REDACTED_PATH]' in result + assert 'alice' not in result + + def test_other_drive_windows_path_redacted(self): + result = _scrub_pii('Log at D:\\Users\\bob\\logs') + assert '[REDACTED_PATH]' in result + assert 'bob' not in result + + def test_lowercase_windows_path_redacted(self): + result = _scrub_pii('File c:\\users\\carol\\data.txt') + assert '[REDACTED_PATH]' in result + assert 'carol' not in result + + def test_no_pii_unchanged(self): + msg = 'Disk swap completed successfully' + assert _scrub_pii(msg) == msg + + def test_stack_trace_with_mixed_pii(self): + trace = ( + 'Traceback (most recent call last):\n' + ' File "/home/adminuser/repair.py", line 10\n' + 'AuthenticationError: user admin@contoso.com denied' + ) + result = _scrub_pii(trace) + assert 'adminuser' not in result + assert 'admin@contoso.com' not in result + assert '[REDACTED_PATH]' in result + assert '[REDACTED_EMAIL]' in result + + +class TestHashFunctions(unittest.TestCase): + """Tests for _hash_value, _hash_resource_params, and _hash_result_json.""" + + def test_hash_value_deterministic(self): + assert _hash_value('my-vm') == _hash_value('my-vm') + + def test_hash_value_irreversible(self): + result = _hash_value('my-secret-vm') + assert 'my-secret-vm' not in result + assert len(result) == 16 + + def test_hash_value_none_passthrough(self): + assert _hash_value(None) is None + assert _hash_value('') == '' + + def test_hash_resource_params_hashes_resource_keys(self): + params = {'vm_name': 'myvm', 'resource_group_name': 'myrg', 'yes': True, 'size': 'Standard_D2s_v3'} + result = _hash_resource_params(params) + assert result['vm_name'] != 'myvm' + assert result['resource_group_name'] != 'myrg' + assert result['yes'] is True + assert result['size'] == 'Standard_D2s_v3' + + def test_hash_resource_params_skips_redacted(self): + params = {'repair_password': '********'} + result = _hash_resource_params(params) + assert result['repair_password'] == '********' + + def test_hash_result_json_hashes_resource_fields(self): + result = {'repair_vm_name': 'rep-vm', 'status': 'SUCCESS', 'message': 'ok'} + hashed = _hash_result_json(result) + assert hashed['repair_vm_name'] != 'rep-vm' + assert len(hashed['repair_vm_name']) == 16 + assert hashed['status'] == 'SUCCESS' + assert hashed['message'] == 'ok' + + def test_hash_result_json_non_dict_passthrough(self): + assert _hash_result_json([1, 2]) == [1, 2] + assert _hash_result_json(None) is None + + +class TestTrackCommandTelemetry(unittest.TestCase): + """Tests that _track_command_telemetry sends properly structured events.""" + + @patch('azext_vm_repair.telemetry.telemetry_core') + def test_calls_add_extension_event(self, mock_telemetry_core): + logger = MagicMock() + _track_command_telemetry( + logger, 'vm repair create', {'verbose': True}, 'SUCCESS', + 'ok', '', '', 1.5, {'status': 'SUCCESS'} + ) + + mock_telemetry_core.add_extension_event.assert_called_once() + args = mock_telemetry_core.add_extension_event.call_args + assert args[0][0] == 'vm-repair' + props = args[0][1] + assert props['Context.Default.AzureCLI.VmRepairCommandName'] == 'vm repair create' + assert props['Context.Default.AzureCLI.VmRepairStatus'] == 'SUCCESS' + assert props['Context.Default.AzureCLI.VmRepairCommandDuration'] == 1.5 + + @patch('azext_vm_repair.telemetry.telemetry_core') + def test_scrubs_error_message(self, mock_telemetry_core): + logger = MagicMock() + _track_command_telemetry( + logger, 'vm repair create', {}, 'ERROR', + 'msg', 'auth failed for user@corp.com', 'at /home/alice/script.py', + 2.0, {} + ) + + props = mock_telemetry_core.add_extension_event.call_args[0][1] + assert 'user@corp.com' not in props['Context.Default.AzureCLI.VmRepairErrorMessage'] + assert 'alice' not in props['Context.Default.AzureCLI.VmRepairErrorStackTrace'] + + @patch('azext_vm_repair.telemetry.telemetry_core') + def test_hashes_resource_params_and_result(self, mock_telemetry_core): + logger = MagicMock() + params = {'vm_name': 'my-secret-vm', 'resource_group_name': 'my-rg', 'yes': True} + result = {'repair_vm_name': 'repair-abc', 'status': 'SUCCESS', 'message': 'ok'} + _track_command_telemetry( + logger, 'test', params, 'SUCCESS', '', '', '', 0, result + ) + + props = mock_telemetry_core.add_extension_event.call_args[0][1] + import json as _json + sent_params = _json.loads(props['Context.Default.AzureCLI.VmRepairParameters']) + sent_result = _json.loads(props['Context.Default.AzureCLI.VmRepairResultJson']) + # Resource names should be hashed (16-char hex) + assert sent_params['vm_name'] != 'my-secret-vm' + assert len(sent_params['vm_name']) == 16 + assert sent_params['resource_group_name'] != 'my-rg' + # Non-resource fields kept as-is + assert sent_params['yes'] is True + # Result resource fields hashed + assert sent_result['repair_vm_name'] != 'repair-abc' + assert len(sent_result['repair_vm_name']) == 16 + # Non-resource result fields kept + assert sent_result['status'] == 'SUCCESS' + + +class TestTrackRunCommandTelemetry(unittest.TestCase): + + @patch('azext_vm_repair.telemetry.telemetry_core') + def test_includes_script_fields(self, mock_telemetry_core): + logger = MagicMock() + _track_run_command_telemetry( + logger, 'vm repair run', {}, 'SUCCESS', 'ok', '', '', + 3.0, {}, 'run-123', 'Succeeded', 'script output', 1.5 + ) + + props = mock_telemetry_core.add_extension_event.call_args[0][1] + assert props['Context.Default.AzureCLI.VmRepairScriptRunId'] == 'run-123' + assert props['Context.Default.AzureCLI.VmRepairScriptStatus'] == 'Succeeded' + assert props['Context.Default.AzureCLI.VmRepairScriptDuration'] == 1.5 + + @patch('azext_vm_repair.telemetry.telemetry_core') + def test_scrubs_script_output(self, mock_telemetry_core): + logger = MagicMock() + _track_run_command_telemetry( + logger, 'vm repair run', {}, 'ERROR', '', '', '', + 1.0, {}, 'run-1', 'Failed', + 'Error for admin@company.com at /home/bob/run.sh', 0.5 + ) + + props = mock_telemetry_core.add_extension_event.call_args[0][1] + assert 'admin@company.com' not in props['Context.Default.AzureCLI.VmRepairScriptOutput'] + assert 'bob' not in props['Context.Default.AzureCLI.VmRepairScriptOutput'] + + +class TestTrackRepairAndRestore(unittest.TestCase): + + @patch('azext_vm_repair.telemetry.telemetry_core') + def test_minimal_properties(self, mock_telemetry_core): + logger = MagicMock() + _track_command_telemetry_repair_and_restore( + logger, 'vm repair repair-and-restore', 'SUCCESS', 'done', '', '', 4.2 + ) + + props = mock_telemetry_core.add_extension_event.call_args[0][1] + assert props['Context.Default.AzureCLI.VmRepairCommandName'] == 'vm repair repair-and-restore' + assert props['Context.Default.AzureCLI.VmRepairCommandDuration'] == 4.2 + # Should not have script-specific keys + assert 'Context.Default.AzureCLI.VmRepairScriptRunId' not in props + + +class TestGenerateUserHash(unittest.TestCase): + """Tests for the _generate_user_hash pseudonymous identifier.""" + + @patch('azure.cli.core._profile.Profile') + def test_returns_16_char_hex(self, mock_profile_cls): + cmd = MagicMock() + mock_profile_cls.return_value.get_subscription.return_value = { + 'id': 'sub-123', + 'user': {'name': 'user@example.com'} + } + result = _generate_user_hash(cmd) + assert len(result) == 16 + assert all(c in '0123456789abcdef' for c in result) + + @patch('azure.cli.core._profile.Profile') + def test_deterministic(self, mock_profile_cls): + mock_profile_cls.return_value.get_subscription.return_value = { + 'id': 'sub-123', + 'user': {'name': 'user@example.com'} + } + cmd = MagicMock() + assert _generate_user_hash(cmd) == _generate_user_hash(cmd) + + @patch('azure.cli.core._profile.Profile') + def test_different_users_different_hashes(self, mock_profile_cls): + cmd = MagicMock() + mock_profile_cls.return_value.get_subscription.return_value = { + 'id': 'sub-123', + 'user': {'name': 'alice@example.com'} + } + hash_alice = _generate_user_hash(cmd) + mock_profile_cls.return_value.get_subscription.return_value = { + 'id': 'sub-123', + 'user': {'name': 'bob@example.com'} + } + hash_bob = _generate_user_hash(cmd) + assert hash_alice != hash_bob + + @patch('azure.cli.core._profile.Profile') + def test_irreversible(self, mock_profile_cls): + mock_profile_cls.return_value.get_subscription.return_value = { + 'id': 'sub-123', + 'user': {'name': 'user@example.com'} + } + cmd = MagicMock() + result = _generate_user_hash(cmd) + assert 'user@example.com' not in result + assert 'sub-123' not in result + + def test_returns_unknown_on_error(self): + cmd = MagicMock() + # Profile import will work but get_subscription will fail + with patch('azure.cli.core._profile.Profile', side_effect=Exception('no profile')): + result = _generate_user_hash(cmd) + assert result == 'unknown' + + +class TestBuildBaseProperties(unittest.TestCase): + """Tests for _build_base_properties helper.""" + + def test_without_context(self): + props = _build_base_properties('cmd', 'SUCCESS', 'msg', '', '', 1.0) + assert props['Context.Default.AzureCLI.VmRepairCommandName'] == 'cmd' + assert props['Context.Default.AzureCLI.VmRepairStatus'] == 'SUCCESS' + assert 'Context.Default.AzureCLI.VmRepairUserHash' not in props + + def test_with_context(self): + ctx = {'OsType': 'Linux', 'UserHash': 'abc123', 'ExceptionType': 'SkuNotAvailableError'} + props = _build_base_properties('cmd', 'ERROR', '', 'err', 'trace', 2.0, context=ctx) + assert props['Context.Default.AzureCLI.VmRepairOsType'] == 'Linux' + assert props['Context.Default.AzureCLI.VmRepairUserHash'] == 'abc123' + assert props['Context.Default.AzureCLI.VmRepairExceptionType'] == 'SkuNotAvailableError' + + +class TestContextPassthrough(unittest.TestCase): + """Tests that context dict is forwarded through telemetry functions.""" + + @patch('azext_vm_repair.telemetry.telemetry_core') + def test_command_telemetry_with_context(self, mock_telemetry_core): + ctx = {'OsType': 'Linux', 'UserHash': 'deadbeef12345678'} + _track_command_telemetry( + MagicMock(), 'create', {}, 'SUCCESS', '', '', '', 1.0, {}, context=ctx + ) + props = mock_telemetry_core.add_extension_event.call_args[0][1] + assert props['Context.Default.AzureCLI.VmRepairOsType'] == 'Linux' + assert props['Context.Default.AzureCLI.VmRepairUserHash'] == 'deadbeef12345678' + + @patch('azext_vm_repair.telemetry.telemetry_core') + def test_run_telemetry_with_context(self, mock_telemetry_core): + ctx = {'ExceptionType': 'AzCommandError'} + _track_run_command_telemetry( + MagicMock(), 'run', {}, 'ERROR', '', 'err', '', 1.0, {}, + 'run-1', 'Failed', 'output', 0.5, context=ctx + ) + props = mock_telemetry_core.add_extension_event.call_args[0][1] + assert props['Context.Default.AzureCLI.VmRepairExceptionType'] == 'AzCommandError' + + @patch('azext_vm_repair.telemetry.telemetry_core') + def test_repair_and_restore_with_context(self, mock_telemetry_core): + ctx = {'UserHash': 'abc123'} + _track_command_telemetry_repair_and_restore( + MagicMock(), 'repair-and-restore', 'SUCCESS', '', '', '', 1.0, context=ctx + ) + props = mock_telemetry_core.add_extension_event.call_args[0][1] + assert props['Context.Default.AzureCLI.VmRepairUserHash'] == 'abc123' + + +if __name__ == '__main__': + unittest.main() diff --git a/src/vm-repair/conftest.py b/src/vm-repair/conftest.py new file mode 100644 index 00000000000..818a3cf7553 --- /dev/null +++ b/src/vm-repair/conftest.py @@ -0,0 +1,25 @@ +import sys +from unittest.mock import MagicMock + +# Pre-mock azure.cli.core and other dependencies so azext_vm_repair can be imported +# without the full CLI installed. This conftest runs before pytest collects any +# test modules in this directory tree. +_mocks = {} +for mod in [ + 'azure', + 'azure.cli', + 'azure.cli.core', + 'azure.cli.core.commands', + 'azure.cli.core.commands.client_factory', + 'azure.cli.core._profile', + 'knack', + 'knack.log', + 'knack.help_files', +]: + if mod not in sys.modules: + _mocks[mod] = MagicMock() + sys.modules[mod] = _mocks[mod] + +# Ensure the telemetry sub-attribute is a stable mock +sys.modules['azure.cli.core'].telemetry = MagicMock() +sys.modules['azure.cli.core'].AzCommandsLoader = MagicMock() diff --git a/src/vm-repair/setup.py b/src/vm-repair/setup.py index 8fa357755d5..dc1ee53b5b7 100644 --- a/src/vm-repair/setup.py +++ b/src/vm-repair/setup.py @@ -8,7 +8,7 @@ from codecs import open from setuptools import setup, find_packages -VERSION = "2.2.0" +VERSION = "2.2.1" CLASSIFIERS = [ 'Development Status :: 4 - Beta', @@ -25,7 +25,7 @@ 'License :: OSI Approved :: MIT License', ] -DEPENDENCIES = ['opencensus~=0.11.4'] +DEPENDENCIES = [] with open('HISTORY.rst', 'r', encoding='utf-8') as f: HISTORY = f.read()