diff --git a/backend/compact-connect/common_constructs/user_pool.py b/backend/compact-connect/common_constructs/user_pool.py index 1bb903bc8..2fa9c8921 100644 --- a/backend/compact-connect/common_constructs/user_pool.py +++ b/backend/compact-connect/common_constructs/user_pool.py @@ -144,10 +144,10 @@ def __init__( # pylint: disable=too-many-arguments ) def add_custom_app_client_domain( - self, - hosted_zone: IHostedZone, - scope: Construct, - app_client_domain_prefix: str, + self, + hosted_zone: IHostedZone, + scope: Construct, + app_client_domain_prefix: str, ): """ Creates a custom subdomain for the cognito app client in the form of: @@ -159,17 +159,11 @@ def add_custom_app_client_domain( domain_name = f'{domain_prefix}.{hosted_zone.zone_name}' cert_id = f'{app_client_domain_prefix}AuthCert' cert = Certificate( - scope, - cert_id, - domain_name=domain_name, - validation=CertificateValidation.from_dns(hosted_zone=hosted_zone) + scope, cert_id, domain_name=domain_name, validation=CertificateValidation.from_dns(hosted_zone=hosted_zone) ) domain = self.add_domain( f'{app_client_domain_prefix}UserPoolDomain', - custom_domain=CustomDomainOptions( - certificate=cert, - domain_name=domain_name - ), + custom_domain=CustomDomainOptions(certificate=cert, domain_name=domain_name), managed_login_version=ManagedLoginVersion.NEWER_MANAGED_LOGIN, ) @@ -195,7 +189,7 @@ def add_custom_app_client_domain( 'id': 'AwsSolutions-IAM5', 'appliesTo': ['Resource::*'], 'reason': 'This is an AWS-managed custom resource Lambda that requires wildcard permissions' - 'to describe CloudFront distributions.', + 'to describe CloudFront distributions.', } ], ) @@ -211,7 +205,7 @@ def add_custom_app_client_domain( 'appliesTo': [ 'Policy::arn::iam::aws:policy/service-role/AWSLambdaBasicExecutionRole' ], - 'reason': 'This is an AWS-managed custom resource Lambda that uses the standard execution role.' 
+ 'reason': 'This is an AWS-managed custom resource Lambda that uses the standard execution role.', } ], ) @@ -223,12 +217,12 @@ def add_custom_app_client_domain( { 'id': 'HIPAA.Security-LambdaDLQ', 'reason': 'This is an AWS-managed custom resource Lambda used only during deployment.' - 'A DLQ is not necessary.', + 'A DLQ is not necessary.', }, { 'id': 'HIPAA.Security-LambdaInsideVPC', 'reason': 'This is an AWS-managed custom resource Lambda that needs internet access to' - 'describe CloudFront distributions.', + 'describe CloudFront distributions.', }, ], ) @@ -236,8 +230,8 @@ def add_custom_app_client_domain( self.app_client_custom_domain = domain def add_default_app_client_domain( - self, - non_custom_domain_prefix: str, + self, + non_custom_domain_prefix: str, ): """ Creates a cognito based sub domain in the form of: diff --git a/backend/compact-connect/disaster_recovery/FULL_TABLE_RECOVERY.md b/backend/compact-connect/disaster_recovery/FULL_TABLE_RECOVERY.md new file mode 100644 index 000000000..3ab83a5b6 --- /dev/null +++ b/backend/compact-connect/disaster_recovery/FULL_TABLE_RECOVERY.md @@ -0,0 +1,230 @@ +## Overview + +The Full Table Disaster Recovery (DR) system provides automated recovery capabilities for critical DynamoDB tables in the CompactConnect system. This system allows administrators to perform Point-in-Time Recovery (PITR) operations when tables become corrupted or require rollback to a previous state. + +**⚠️ WARNING: This system performs a HARD RESET of the target table, permanently deleting all current data before restoring from the specified timestamp.** + +## When to Use + +This Disaster Recovery process should only be run in the event that the system experiences an event that causes +system-wide failures, such as the following scenarios: + +1. **Data Corruption**: When a table contains corrupted or invalid data that cannot be fixed through normal operations +2. 
**Accidental Data Loss**: When critical data has been accidentally deleted or modified +3. **Failed Deployments**: When a deployment has caused data integrity issues +4. **Security Incidents**: When unauthorized modifications require rolling back to a clean state +5. **System-wide Issues**: When multiple tables need to be restored to a consistent point in time + +## Architecture + +### Two-Phase Recovery Process +DynamoDB PITR cannot directly restore data into your production database. Instead, it creates a new table with data matching the exact values you had in your production database at the specified timestamp. You as the owner of the database must decide what to do with that data from that point in time. For the purposes of disaster recovery rollback, we have determined to get the data into the production table by performing a 'hard reset', meaning **all the current data in the production table is deleted**, then we copy over the data from the temporary table into the production table. This process includes the following step functions. + +1. **RestoreDynamoDbTable Step Function** (Parent) + - Creates a backup of the current table for post-incident analysis + - Restores a temporary table from the specified PITR timestamp + - Invokes the SyncTableData Step Function + +2. **SyncTableData Step Function** (Child) + - **Delete Phase**: Removes all records from the production table + - **Copy Phase**: Copies all records from the temporary table to the production table + +Once this process is complete, the data in the target table will be restored with the data from the specified point in time. 
+ +### Per-Table Isolation + +Each DynamoDB table has its own dedicated pair of Step Functions: + +- `DRRestoreDynamoDbTable{TableName}StateMachine` +- `{TableName}DRSyncTableDataStateMachine` + +This design allows for: +- **Targeted Recovery**: Restore only the affected table(s) +- **Granular Permissions**: Each Step Function has minimal, table-specific permissions + +## Supported Tables + +The following tables are configured for disaster recovery: + +| Table Name | Step Function Prefix | Purpose | Recovery Notes | +|------------|---------------------|---------|----------------| +| TransactionHistoryTable | `TransactionHistoryTable` | transaction data from authorize.net | Can be rolled back independently. After DR rollback, run the Transaction History Processing Workflow Step Function for each compact for every day where data was lost to restore all transaction data from Authorize.net accounts. The Transaction History Processing Workflow step functions are idempotent. They can be run multiple times without producing duplicate transaction items in the table. | +| ProviderTable | `ProviderTable` | Provider information and GSIs | **Dependent on SSN table** - Can be rolled back without updating SSN table since SSN table does not have a dependency on the provider table. **⚠️ WARNING**: If SSN table needs rollback, the provider table will likely need to be restored to same point in time as SSN table. Otherwise new provider IDs may be generated for existing SSNs causing data inconsistency/orphaned providers that won't receive license updates. After DR rollback, consider that the transaction history table will have a list of all privileges purchased as recorded in Authorize.net, and can be used as a data source for repopulating any privilege records that may have been lost as a result of the rollback.| +| CompactConfigurationTable | `CompactConfigurationTable` | System configuration data | Can be rolled back independently of other tables. 
Contains configuration set by compact and state admins. Admins may need to reset configurations that were lost as a result of the rollback. | +| DataEventTable | `DataEventTable` | License data events | Used for downstream processing events triggered by Event Bridge event bus. In the event of recovery, many of these events can likely be restored by replaying events placed on the event bus. See https://docs.aws.amazon.com/eventbridge/latest/userguide/eb-archive.html | +| UsersTable | `UsersTable` | Staff user permissions and account data | Can be rolled back independently. Contains staff user permissions and account information. Admins may need to re-invite new users or reset permissions that were lost as a result of the rollback. | + +> **Note**: The SSN table is excluded due to additional security requirements and will be handled in a future implementation. + +## Running the Disaster Recovery Workflow + +## Pre-Execution Checklist + +1. ✅ **Verify Impact**: Confirm which applications/users will be affected +2. ✅ **Communication**: Notify stakeholders of the planned recovery +3. ✅ **Timestamp Selection**: Determine the UTC timestamp to restore to (must be within 35 days) +4. ✅ **Access Verification**: Confirm you have necessary permissions (Currently only AWS account admins can trigger a DR) + +### Step 1: Start Recovery Mode + +Before executing the DR Step Function, you must throttle all Lambda functions to prevent other data operations from occurring while attempting to roll any databases back. 
There is a script provided to perform this action: + +```bash +# Navigate to the disaster_recovery directory +cd backend/compact-connect/disaster_recovery + +# Start recovery mode for the environment (replace "Prod" with your target environment) +python start_recovery_mode.py --environment Prod +``` + +This will put the system into recovery mode by: +- Setting reserved concurrency to 0 for all environment Lambda functions, so they can't be invoked +- Leaving Disaster Recovery functions operational +- **Important**: If any functions failed to throttle, you may rerun the script or manually check their reserved concurrency settings if needed. The script is idempotent and can be run multiple times. + +### Step 2: Execute Disaster Recovery Step Function For Specific Tables +#### Prerequisites +- Identify the exact table name from the DynamoDB console (needed for `tableNameRecoveryConfirmation`) +- Verify the PITR timestamp is correct +- Create a unique incident ID for tracking (see [Execution Request Parameter Details](#execution-request-parameter-details)) + +When you are ready to perform a rollback, find the step function for the specific table you need to rollback (`DRRestoreDynamoDbTable{TableName}StateMachine`) and start an execution with the following input (replace placeholders with your values) + +```json +{ + "incidentId": "", + "pitrBackupTime": "", + "tableNameRecoveryConfirmation": "" +} +``` + +#### Execution Request Parameter Details + +- **`incidentId`** (required) + - Purpose: Unique identifier for tracking this recovery operation + - Format: String (80 chars or less, allows alphanumeric and hyphens) + - Example: `"incident-2025-001"`, `"corruption-fix-20250115"` + - Used in: Backup names, restored table names, execution tracking + +- **`pitrBackupTime`** (required) + - Purpose: The timestamp to restore the table to + - Format: UTC datetime string + - Example: `"2030-01-15T12:39:46Z"` + - Constraints: Must be within the PITR retention window (35 days) + 
+- **`tableNameRecoveryConfirmation`** (required) + - Purpose: Security guard rail to prevent accidental execution + - Format: Exact table name being recovered (you can copy this from the DynamoDB console) + - Example: `"Prod-PersistentStack-DataEventTable00A96798-C6VX9JVDOYGN"` + - Validation: Must match the actual destination table name + +example: +```json +{ + "incidentId": "transaction-corruption-20250115", + "pitrBackupTime": "2025-01-15T09:00:00Z", + "tableNameRecoveryConfirmation": "Prod-PersistentStack-TransactionHistoryTable00A96798-C6VX9JVDOYGN" +} +``` + +#### Running Step Functions from AWS Console + +1. Navigate to Step Functions in the AWS Console +2. Find the appropriate Step Function(s) for the table(s) you need to recover (e.g., `DRRestoreDynamoDbTableTransactionHistoryTableStateMachine`) +3. For each step function you need to run, Click "Start Execution" +4. Enter the JSON payload in the input field +5. Click "Start Execution" and wait for completion (multiple Step functions can be run concurrently if you are restoring multiple tables) + +### Step 3: End Recovery Mode + +**⚠️CRITICAL**: Only proceed after ALL recovery Step Functions you have run have completed successfully. + +After the DR Step Function completes successfully for each table you need to restore, end the recovery mode to restore normal operations: + +```bash +# End recovery mode for the environment +python end_recovery_mode.py --environment Prod +``` + +This will: +- Remove reserved concurrency throttling from all Lambda functions +- Restore normal application operations +- Complete the disaster recovery process +- **Important**: If any functions failed to unthrottle, you may rerun the script or manually check their reserved concurrency settings if needed. The script is idempotent and can be run multiple times. + +### Post-Execution + +1. **Verify Recovery**: Confirm data integrity and completeness +2. **Application Testing**: Test critical application functions +3. 
**Documentation**: Update incident documentation with recovery details +4. **Cleanup Review**: Cleanup temporary resources after post-incident analysis. + +### Operational Constraints + +- **Data Loss**: All data newer than the PITR timestamp will be permanently lost. The backup snapshot may be restored post-recovery to determine which records can potentially be recovered. +- **Dependencies**: Related tables may need coordinated restoration for consistency. + +## Monitoring and Troubleshooting +### Common Issues and Solutions + +#### Invalid table name +- **Cause**: `tableNameRecoveryConfirmation` doesn't match actual table name (this parameter is used to prevent accidental recovery on a database) +- **Solution**: Copy exact table name from DynamoDB console + +#### Restore timestamp out of range +- **Cause**: PITR timestamp is outside the 35-day retention window +- **Solution**: Choose a more recent timestamp within the retention period + +## Complete Table Deletion Recovery (Manual Backup Restoration) + +**⚠️ CRITICAL**: This section applies ONLY when a DynamoDB table has been completely deleted and PITR is not available. This requires manual intervention and cannot use the automated Step Functions. + +### Recovery Steps +Depending on how the table was deleted, there may be a latest 'snapshot' backup in the DynamoDB console that you can recover from. If that snapshot is not available, the system performs daily backups of our tables and stores them in the AWS Backup service that you can recover from. + +#### Step 1: Locate the Latest Backup + +##### Option A: DynamoDB Console +1. Navigate to DynamoDB Console → Backups +2. Find the most recent backup for the deleted table +3. Note the backup name and creation time + +##### Option B: AWS Backup Console +1. Navigate to AWS Backup Console → Backup Vaults +2. Find the most recent recovery point for the deleted table +3. 
**CRITICAL**: Note the "Original table name" from the recovery point details + +#### Step 2: Restore Table from Backup + +1. **From DynamoDB Console**: + - Go to DynamoDB → Backups + - Select the backup → "Restore" + - **CRITICAL Configuration**: + - **Table Name**: Must match EXACTLY the original deleted table name + - **Encryption**: Select "Customer managed key" + - **KMS Key**: Choose `-PersistentStack-shared-encryption-key` for non-ssn tables, `ssn-key` for the SSN table + - Example: `Prod-PersistentStack-shared-encryption-key` + - **Global Secondary Indexes (GSIs)**: Ensure ALL original GSIs are included in the restore by selecting 'Restore the entire table' + - Select 'Restore' + +2. **From AWS Backup Console**: + - Navigate to Recovery Points → Select the backup + - Click "Restore" + - **CRITICAL Configuration**: + - **New Table Name**: Use the EXACT "Original table name" from the recovery point + - **Encryption**: Choose an AWS KMS key -> `-PersistentStack-shared-encryption-key` for non-ssn tables, `ssn-key` for the SSN table + - **GSIs**: Verify all original GSIs are restored + - Select 'Restore Backup' + +#### Step 3: Verify Restoration + +1. **Table Configuration**: + - ✅ Table name matches exactly (including environment prefix and suffix) + - ✅ All Global Secondary Indexes are present + - ✅ Encryption is set to the correct KMS key + - ✅ Table status is "ACTIVE" + +2. 
**Data Verification**: + - Spot-check critical records + - Verify record counts are reasonable + - Verify application functionality with the restored table diff --git a/backend/compact-connect/disaster_recovery/LICENSE_UPLOAD_ROLLBACK.md b/backend/compact-connect/disaster_recovery/LICENSE_UPLOAD_ROLLBACK.md new file mode 100644 index 000000000..365bdfdc3 --- /dev/null +++ b/backend/compact-connect/disaster_recovery/LICENSE_UPLOAD_ROLLBACK.md @@ -0,0 +1,193 @@ +# License Upload Rollback Guide + +## Overview + +The License Upload Rollback system allows AWS account administrators to automatically revert invalid or corrupted license data that was uploaded by a specific jurisdiction within a defined time window. + +The system will automatically determine which providers had their license records modified as a result of uploads during the time window, and confirm which license updates can be safely rolled back. A provider is eligible for automatic rollback if only license upload-related changes happened since the window. If any other updates have occurred since the start of the time window, the provider will be skipped and manual review will be required to determine which action should be taken for that individual. The rollback process will generate a full JSON report showing which providers had their licenses rolled back and which were skipped and require manual review. + +## Step-by-Step Execution Guide + +### Prerequisites + +Before starting the rollback: + +1. ✅ **Verify the Problem**: Confirm which jurisdiction uploaded bad data for which compact(s) +2. ✅ **Disable automated access for Jurisdiction**: If jurisdiction has API credentials for automated uploads, disable those credentials to prevent further data changes until system has been recovered. To do this, determine which Cognito app client(s) the jurisdiction is using for the compact(s) and delete the appropriate app client(s) from the State Auth Cognito user pool. +3. 
✅ **Determine Time Window**: Identify the exact start and end times (UTC) of the problematic uploads +4. ✅ **Determine When Rollback Should be Performed**: Depending on the severity of the issue and scale of records that need to be rolled back, determine if the rollback needs to be performed as soon as possible or if it can be performed outside of peak traffic hours. When possible, it is recommended to perform rollbacks during periods of low traffic. While the risk is low, there is a narrow race condition (.2 second window based on load testing) where a license record may be modified by another part of the system after the rollback system checked for updates and the modification could be removed by the rollback. Running the rollback when traffic is low reduces this risk even further. +5. ✅ **Stakeholder Notification**: Coordinate with relevant state administrators and other stakeholders. Ensure jurisdiction is aware they should not attempt to upload any more license data until the rollback has been completed. + +### Step 1: Gather Required Information + +You'll need the following information for the execution: + +| Parameter | Description | Example | +|-----------|----------------------------------------------------------|---------| +| `compact` | The compact abbreviation (lowercase) | `"aslp"`, `"octp"`, `"counseling"` | +| `jurisdiction` | The state/jurisdiction code (lowercase) | `"oh"`, `"ky"`, `"ne"` | +| `startDateTime` | UTC timestamp when problematic uploads began (inclusive) | `"2020-01-15T08:00:00Z"` | +| `endDateTime` | UTC timestamp when problematic uploads ended (inclusive) | `"2020-01-15T17:59:59Z"` | +| `rollbackReason` | Description for audit trail | `"Invalid license data uploaded by OH staff"` | + +**Important Notes:** +- All timestamps must be in UTC +- Time window cannot exceed 7 days (604,800 seconds) + +### Step 2: Locate the Step Function + +1. Navigate to the AWS Console → Step Functions +2. 
Find the Step Function with the name prefix: **`LicenseUploadRollbackLicenseUploadRollbackStateMachine`** + +### Step 3: Execute the Step Function + +1. Click **"Start Execution"** +2. Enter a descriptive execution name (this will be used for the S3 results folder): + ``` + rollback-aslp-oh-2020-01-15 + ``` + +3. Paste the following JSON input (replace values with your specific parameters): + +```json +{ + "compact": "aslp", + "jurisdiction": "oh", + "startDateTime": "2020-01-15T08:00:00Z", + "endDateTime": "2020-01-15T17:59:59Z", + "rollbackReason": "Invalid license data uploaded - incorrect expiration dates" +} +``` + +4. Click **"Start Execution"** + +### Step 4: Monitor Execution Progress + +The Step Function will process providers in batches. Monitor the step function execution until it completes and verify the execution was successful. + +### Step 5: Review Results + +Once the execution completes, comprehensive results are stored in S3. The S3 key is returned as output from the lambda step of the step function. Check the Step Function execution output/logs to get the S3 key. + +#### Accessing the Results File + +1. Navigate to S3 in the AWS Console +2. Find the bucket with `disasterrecoveryrollbackresults` in the name. +3. Navigate to the folder matching your execution name: `rollback-aslp-oh-2020-01-15/` +4. Download the file: `results.json` + +#### Understanding the Results Structure + +The results file contains three main sections: + +##### 1. 
Reverted Provider Summaries + +Providers that were successfully rolled back (example): + +```json +{ + "revertedProviderSummaries": [ + { + "providerId": "01234567-89ab-cdef-0123-456789abcdef", + "licensesReverted": [ + { + "jurisdiction": "oh", + "licenseType": "audiologist", + "revisionId": "98765432-10ab-cdef-0123-456789abcdef", + "action": "REVERT" + } + ], + "privilegesReverted": [ + { + "jurisdiction": "ky", + "licenseType": "audiologist", + "revisionId": "11111111-2222-3333-4444-555555555555", + "action": "REACTIVATED" + } + ], + "updatesDeleted": [ + + ] + } + ] +} +``` + +**Actions Explained:** +- `"REVERT"`: License data was restored to its pre-upload state +- `"DELETE"`: License was newly created during the upload and has been removed +- `"REACTIVATED"`: Privilege was deactivated due to the upload and has been reactivated + +##### 2. Skipped Provider Details + +Providers that require manual review (example): + +```json +{ + "skippedProviderDetails": [ + { + "providerId": "12345678-90ab-cdef-0123-456789abcdef", + "reason": "Provider has updates that are either unrelated to license upload or occurred after rollback end time. Manual review required.", + "ineligibleUpdates": [ + { + "recordType": "licenseUpdate", + "typeOfUpdate": "encumbrance", + "updateTime": "2025-01-16T10:30:00Z", + "licenseType": "audiologist", + "reason": "License was updated with a change unrelated to license upload or the update occurred after rollback end time. Manual review required." + } + ] + } + ] +} +``` + +##### 3. Failed Provider Details + +Providers that encountered errors: + +```json +{ + "failedProviderDetails": [ + { + "providerId": "23456789-01ab-cdef-0123-456789abcdef", + "error": "Failed to rollback updates for provider. Manual review required: ConditionalCheckFailedException" + } + ] +} +``` + +These require technical investigation to determine the cause. 
+ +#### Options for Skipped or Failed Providers + +For providers requiring manual review, you have three options: + +1. **Do Nothing**: If the subsequent updates are valid, the provider's current state is correct +2. **Manual Database Edit**: For complex cases, coordinate with stakeholders to manually adjust records and document manual edits made. +3. **Re-upload Data**: Have the state re-upload correct data for these specific providers through the normal upload process (often the simplest option) + +## Technical Details + +### How the System Identifies Affected Providers + +The system uses the `licenseUploadDateGSI` Global Secondary Index to efficiently query for all license records uploaded during the specified time window. This index is structured as: + +- **Partition Key**: `C#{compact}#J#{jurisdiction}#D#{year-month}` +- **Sort Key**: `TIME#{epoch}#LT#{license_type}#PID#{provider_id}` + +The system queries each month in the time range and collects unique provider IDs. + +### Event Publishing + +For each successfully reverted provider, the system publishes events to the EventBridge event bus: + +- `license.reverted` events for each reverted license +- `privilege.reverted` events for each reactivated privilege + +These events include: +- The rollback reason +- Time window information +- Revision IDs for tracking + +These events are purely for auditing purposes. They are not currently referenced by any downstream processes. diff --git a/backend/compact-connect/disaster_recovery/README.md b/backend/compact-connect/disaster_recovery/README.md index 99e7db4a7..e09dfbea0 100644 --- a/backend/compact-connect/disaster_recovery/README.md +++ b/backend/compact-connect/disaster_recovery/README.md @@ -1,232 +1,16 @@ # DynamoDB Disaster Recovery System -## Overview +## 🚨 IMPORTANT: Choose the Right Recovery Tool -The Disaster Recovery (DR) system provides automated recovery capabilities for critical DynamoDB tables in the CompactConnect system. 
This system allows administrators to perform Point-in-Time Recovery (PITR) operations when tables become corrupted or require rollback to a previous state. +This repository contains TWO DIFFERENT recovery systems for different scenarios: -**⚠️ WARNING: This system performs a HARD RESET of the target table, permanently deleting all current data before restoring from the specified timestamp.** +### 1. **License Upload Rollback** +Use when you need to revert **specific license uploads** from **one jurisdiction** within a **time window**. -## When to Use +See: [LICENSE_UPLOAD_ROLLBACK.md](./LICENSE_UPLOAD_ROLLBACK.md) -This Disaster Recovery process should only be run in the event that the system experiences an event that causes -system-wide failures, such as the following scenarios: +### 2. **Full System Disaster Recovery** +Use when you need to recover **entire DynamoDB tables** affecting **ALL compacts and jurisdictions**. -1. **Data Corruption**: When a table contains corrupted or invalid data that cannot be fixed through normal operations -2. **Accidental Data Loss**: When critical data has been accidentally deleted or modified -3. **Failed Deployments**: When a deployment has caused data integrity issues -4. **Security Incidents**: When unauthorized modifications require rolling back to a clean state -5. **System-wide Issues**: When multiple tables need to be restored to a consistent point in time +See: [FULL_TABLE_RECOVERY.md](./FULL_TABLE_RECOVERY.md) -## Architecture - -### Two-Phase Recovery Process -DynamoDB PITR cannot directly restore data into your production database. Instead, it creates a new table with data matching the exact values you had in your production database at the specified timestamp. You as the owner of the database must decide what to do with that data from that point in time. 
For the purposes of disaster recovery rollback, we have determined to get the data into the production table by performing a 'hard reset', meaning **all the current data in the production table is deleted**, then we copy over the data from the temporary table into the production table. This process includes the following step functions. - -1. **RestoreDynamoDbTable Step Function** (Parent) - - Creates a backup of the current table for post-incident analysis - - Restores a temporary table from the specified PITR timestamp - - Invokes the SyncTableData Step Function - -2. **SyncTableData Step Function** (Child) - - **Delete Phase**: Removes all records from the production table - - **Copy Phase**: Copies all records from the temporary table to the production table - -Once this process is complete, the data in the target table will be restored with the data from the specified point in time. - -### Per-Table Isolation - -Each DynamoDB table has its own dedicated pair of Step Functions: - -- `DRRestoreDynamoDbTable{TableName}StateMachine` -- `{TableName}DRSyncTableDataStateMachine` - -This design allows for: -- **Targeted Recovery**: Restore only the affected table(s) -- **Granular Permissions**: Each Step Function has minimal, table-specific permissions - -## Supported Tables - -The following tables are configured for disaster recovery: - -| Table Name | Step Function Prefix | Purpose | Recovery Notes | -|------------|---------------------|---------|----------------| -| TransactionHistoryTable | `TransactionHistoryTable` | transaction data from authorize.net | Can be rolled back independently. After DR rollback, run the Transaction History Processing Workflow Step Function for each compact for every day where data was lost to restore all transaction data from Authorize.net accounts. The Transaction History Processing Workflow step functions are idempotent. They can be run multiple times without producing duplicate transaction items in the table. 
| -| ProviderTable | `ProviderTable` | Provider information and GSIs | **Dependent on SSN table** - Can be rolled back without updating SSN table since SSN table does not have a dependency on the provider table. **⚠️ WARNING**: If SSN table needs rollback, the provider table will likely need to be restored to same point in time as SSN table. Otherwise new provider IDs may be generated for existing SSNs causing data inconsistency/orphaned providers that won't receive license updates. After DR rollback, consider that the transaction history table will have a list of all privileges purchased as recorded in Authorize.net, and can be used as a data source for repopulating any privilege records that may have been lost as a result of the rollback.| -| CompactConfigurationTable | `CompactConfigurationTable` | System configuration data | Can be rolled back independently of other tables. Contains configuration set by compact and state admins. Admins may need to reset configurations that were lost as a result of the rollback. | -| DataEventTable | `DataEventTable` | License data events | Used for downstream processing events triggered by Event Bridge event bus. In the event of recovery, many of these events can likely be restored by replaying events placed on the event bus. See https://docs.aws.amazon.com/eventbridge/latest/userguide/eb-archive.html | -| UsersTable | `UsersTable` | Staff user permissions and account data | Can be rolled back independently. Contains staff user permissions and account information. Admins may need to re-invite new users or reset permissions that were lost as a result of the rollback. | - -> **Note**: The SSN table is excluded due to additional security requirements and will be handled in a future implementation. - -## Running the Disaster Recovery Workflow - -## Pre-Execution Checklist - -1. ✅ **Verify Impact**: Confirm which applications/users will be affected -2. ✅ **Communication**: Notify stakeholders of the planned recovery -3. 
✅ **Timestamp Selection**: Determine the UTC timestamp to restore to (must be within 35 days) -4. ✅ **Access Verification**: Confirm you have necessary permissions (Currently only AWS account admins can trigger a DR) - -### Step 1: Start Recovery Mode - -Before executing the DR Step Function, you must throttle all Lambda functions to prevent other data operations from occurring while attempting to roll any databases back. There is a script provided to perform this action: - -```bash -# Navigate to the disaster_recovery directory -cd backend/compact-connect/disaster_recovery - -# Start recovery mode for the environment (replace "Prod" with your target environment) -python start_recovery_mode.py --environment Prod -``` - -This will put the system into recovery mode by: -- Setting reserved concurrency to 0 for all environment Lambda functions, so they can't be invoked -- Leaving Disaster Recovery functions operational -- **Important**: If any functions failed to throttle, you may rerun the script or manually check their reserved concurrency settings if needed. The script is idempotent and can be run multiple times. - -### Step 2: Execute Disaster Recovery Step Function For Specific Tables -#### Prerequisites -- Identify the exact table name from the DynamoDB console (needed for `tableNameRecoveryConfirmation`) -- Verify the PITR timestamp is correct -- Create a unique incident ID for tracking (see [Execution Request Parameter Details](#execution-request-parameter-details)) - -When you are ready to perform a rollback, find the step function for the specific table you need to rollback (`DRRestoreDynamoDbTable{TableName}StateMachine`) and start an execution with the following input (replace placeholders with your values) - -```json -{ - "incidentId": "", - "pitrBackupTime": "", - "tableNameRecoveryConfirmation": "
" -} -``` - -#### Execution Request Parameter Details - -- **`incidentId`** (required) - - Purpose: Unique identifier for tracking this recovery operation - - Format: String (80 chars or less, allows alphanumeric and hyphens) - - Example: `"incident-2025-001"`, `"corruption-fix-20250115"` - - Used in: Backup names, restored table names, execution tracking - -- **`pitrBackupTime`** (required) - - Purpose: The timestamp to restore the table to - - Format: UTC datetime string - - Example: `"2030-01-15T12:39:46Z"` - - Constraints: Must be within the PITR retention window (35 days) - -- **`tableNameRecoveryConfirmation`** (required) - - Purpose: Security guard rail to prevent accidental execution - - Format: Exact table name being recovered (you can copy this from the DynamoDB console) - - Example: `"Prod-PersistentStack-DataEventTable00A96798-C6VX9JVDOYGN"` - - Validation: Must match the actual destination table name - -example: -```json -{ - "incidentId": "transaction-corruption-20250115", - "pitrBackupTime": "2025-01-15T09:00:00Z", - "tableNameRecoveryConfirmation": "Prod-PersistentStack-TransactionHistoryTable00A96798-C6VX9JVDOYGN" -} -``` - -#### Running Step Functions from AWS Console - -1. Navigate to Step Functions in the AWS Console -2. Find the appropriate Step Function(s) for the table(s) you need to recover (e.g., `DRRestoreDynamoDbTableTransactionHistoryTableStateMachine`) -3. For each step function you need to run, Click "Start Execution" -4. Enter the JSON payload in the input field -5. Click "Start Execution" and wait for completion (multiple Step functions can be run concurrently if you are restoring multiple tables) - -### Step 3: End Recovery Mode - -**⚠️CRITICAL**: Only proceed after ALL recovery Step Functions you have run have completed successfully. 
- -After the DR Step Function completes successfully for each table you need to restore, end the recovery mode to restore normal operations: - -```bash -# End recovery mode for the environment -python end_recovery_mode.py --environment Prod -``` - -This will: -- Remove reserved concurrency throttling from all Lambda functions -- Restore normal application operations -- Complete the disaster recovery process -- **Important**: If any functions failed to unthrottle, you may rerun the script or manually check their reserved concurrency settings if needed. The script is idempotent and can be run multiple times. - -### Post-Execution - -1. **Verify Recovery**: Confirm data integrity and completeness -2. **Application Testing**: Test critical application functions -3. **Documentation**: Update incident documentation with recovery details -4. **Cleanup Review**: Cleanup temporary resources after post-incident analysis. - -### Operational Constraints - -- **Data Loss**: All data newer than the PITR timestamp will be permanently lost. The backup snapshot may be restored post-recovery to determine which records can potentially be recovered. -- **Dependencies**: Related tables may need coordinated restoration for consistency. - -## Monitoring and Troubleshooting -### Common Issues and Solutions - -#### Invalid table name -- **Cause**: `tableNameRecoveryConfirmation` doesn't match actual table name (this parameter is used to prevent accidental recovery on a database) -- **Solution**: Copy exact table name from DynamoDB console - -#### Restore timestamp out of range -- **Cause**: PITR timestamp is outside the 35-day retention window -- **Solution**: Choose a more recent timestamp within the retention period - -## Complete Table Deletion Recovery (Manual Backup Restoration) - -**⚠️ CRITICAL**: This section applies ONLY when a DynamoDB table has been completely deleted and PITR is not available. This requires manual intervention and cannot use the automated Step Functions. 
- -### Recovery Steps -Depending on how the table was deleted, there may be a latest 'snapshot' backup in the DynamoDB console that you can recover from. If that snapshot is not available, the system performs daily backups of our tables and stores them in the AWS Backup service, from which you can recover. - -#### Step 1: Locate the Latest Backup - -##### Option A: DynamoDB Console -1. Navigate to DynamoDB Console → Backups -2. Find the most recent backup for the deleted table -3. Note the backup name and creation time - -##### Option B: AWS Backup Console -1. Navigate to AWS Backup Console → Backup Vaults -2. Find the most recent recovery point for the deleted table -3. **CRITICAL**: Note the "Original table name" from the recovery point details - -#### Step 2: Restore Table from Backup - -1. **From DynamoDB Console**: - - Go to DynamoDB → Backups - - Select the backup → "Restore" - - **CRITICAL Configuration**: - - **Table Name**: Must match EXACTLY the original deleted table name - - **Encryption**: Select "Customer managed key" - - **KMS Key**: Choose `-PersistentStack-shared-encryption-key` for non-ssn tables, `ssn-key` for the SSN table - - Example: `Prod-PersistentStack-shared-encryption-key` - - **Global Secondary Indexes (GSIs)**: Ensure ALL original GSIs are included in the restore by selecting 'Restore the entire table' - - Select 'Restore' - -2. **From AWS Backup Console**: - - Navigate to Recovery Points → Select the backup - - Click "Restore" - - **CRITICAL Configuration**: - - **New Table Name**: Use the EXACT "Original table name" from the recovery point - - **Encryption**: Choose an AWS KMS key -> `-PersistentStack-shared-encryption-key` for non-ssn tables, `ssn-key` for the SSN table - - **GSIs**: Verify all original GSIs are restored - - Select 'Restore Backup' - -#### Step 3: Verify Restoration - -1. 
**Table Configuration**: - - ✅ Table name matches exactly (including environment prefix and suffix) - - ✅ All Global Secondary Indexes are present - - ✅ Encryption is set to the correct KMS key - - ✅ Table status is "ACTIVE" - -2. **Data Verification**: - - Spot-check critical records - - Verify record counts are reasonable - - Verify application functionality with the restored table diff --git a/backend/compact-connect/lambdas/python/common/cc_common/config.py b/backend/compact-connect/lambdas/python/common/cc_common/config.py index 4a790faca..79fd69b75 100644 --- a/backend/compact-connect/lambdas/python/common/cc_common/config.py +++ b/backend/compact-connect/lambdas/python/common/cc_common/config.py @@ -187,6 +187,10 @@ def bulk_bucket_name(self): def provider_user_bucket_name(self): return os.environ['PROVIDER_USER_BUCKET_NAME'] + @property + def disaster_recovery_results_bucket_name(self): + return os.environ['DISASTER_RECOVERY_RESULTS_BUCKET_NAME'] + @property def user_pool_id(self): """ @@ -213,6 +217,10 @@ def users_table_name(self): def fam_giv_index_name(self): return os.environ['FAM_GIV_INDEX_NAME'] + @property + def license_upload_date_index_name(self): + return os.environ['LICENSE_UPLOAD_DATE_INDEX_NAME'] + @property def expiration_resolution_timezone(self): return timezone(offset=timedelta(hours=-4)) diff --git a/backend/compact-connect/lambdas/python/common/cc_common/data_model/data_client.py b/backend/compact-connect/lambdas/python/common/cc_common/data_model/data_client.py index 4cb052a8e..3f5188ef7 100644 --- a/backend/compact-connect/lambdas/python/common/cc_common/data_model/data_client.py +++ b/backend/compact-connect/lambdas/python/common/cc_common/data_model/data_client.py @@ -41,6 +41,7 @@ from cc_common.data_model.schema.privilege import PrivilegeData, PrivilegeUpdateData from cc_common.data_model.schema.privilege.record import PrivilegeUpdateRecordSchema from cc_common.data_model.schema.provider import ProviderData, ProviderUpdateData 
+from cc_common.data_model.update_tier_enum import UpdateTierEnum from cc_common.exceptions import ( CCAwsServiceException, CCInternalException, @@ -48,7 +49,7 @@ CCNotFoundException, ) from cc_common.license_util import LicenseUtility -from cc_common.utils import load_records_into_schemas, logger_inject_kwargs +from cc_common.utils import logger_inject_kwargs class DataClient: @@ -179,9 +180,22 @@ def get_provider_user_records( compact: str, provider_id: UUID, consistent_read: bool = True, + include_update_tier: UpdateTierEnum | None = None, ) -> ProviderUserRecords: logger.info('Getting provider') + # Determine SK condition based on include_update_tier parameter + # When include_update_tier=None, use begins_with to get only main records (provider, licenses, privileges) + # When include_update_tier is set, use lt (less than) to get main records plus updates up to that tier + if include_update_tier is None: + # Get only main records: {compact}#PROVIDER prefix + sk_condition = Key('sk').begins_with(f'{compact}#PROVIDER') + else: + # Get main records and updates up to specified tier using lt (less than) + # This fetches all SKs less than {compact}#UPDATE#{next_tier} + next_tier = int(include_update_tier) + 1 + sk_condition = Key('sk').lt(f'{compact}#UPDATE#{next_tier}') + resp = {'Items': []} last_evaluated_key = None @@ -190,8 +204,7 @@ def get_provider_user_records( query_resp = self.config.provider_table.query( Select='ALL_ATTRIBUTES', - KeyConditionExpression=Key('pk').eq(f'{compact}#PROVIDER#{provider_id}') - & Key('sk').begins_with(f'{compact}#PROVIDER'), + KeyConditionExpression=Key('pk').eq(f'{compact}#PROVIDER#{provider_id}') & sk_condition, ConsistentRead=consistent_read, **pagination, ) @@ -912,6 +925,7 @@ def process_registration_values( } # Create provider update record to show registration event and fields that were updated + now = config.current_standard_datetime provider_update_record = ProviderUpdateData.create_new( { 'type': 
ProviderRecordType.PROVIDER_UPDATE, @@ -919,6 +933,7 @@ def process_registration_values( 'providerId': matched_license_record.providerId, 'compact': matched_license_record.compact, 'previous': current_provider_record.to_dict(), + 'createDate': now, 'updatedValues': {**registration_values}, } ) @@ -953,8 +968,7 @@ def process_registration_values( ] ) - @logger_inject_kwargs(logger, 'compact', 'provider_id', 'detail', 'jurisdiction', 'license_type') - def get_privilege_data( + def _get_privilege_record_directly( self, *, compact: str, @@ -962,61 +976,148 @@ def get_privilege_data( jurisdiction: str, license_type_abbr: str, consistent_read: bool = False, - detail: bool = False, - ) -> list[dict]: + ) -> PrivilegeData: """ - Get a privilege for a provider in a jurisdiction of the license type + Query for a single privilege record directly from DynamoDB. + + This should be used when it is undesirable to get all provider records and + filter for the specific privilege record. :param str compact: The compact of the privilege :param str provider_id: The provider of the privilege :param str jurisdiction: The jurisdiction of the privilege :param str license_type_abbr: The license type abbreviation of the privilege - :param bool detail: Boolean determining whether we include associated records or just privilege record itself + :param bool consistent_read: If true, performs a consistent read of the record :raises CCNotFoundException: If the privilege record is not found - :return If detail = False list of length one containing privilege item, if detail = True list containing, - privilege record, privilege update records and privilege adverse action records + :return: The privilege record as PrivilegeData """ - # Get the privilege record - if detail: - sk_condition = Key('sk').begins_with(f'{compact}#PROVIDER#privilege/{jurisdiction}/{license_type_abbr}#') - else: - sk_condition = Key('sk').eq(f'{compact}#PROVIDER#privilege/{jurisdiction}/{license_type_abbr}#') + pk = 
f'{compact}#PROVIDER#{provider_id}' + sk = f'{compact}#PROVIDER#privilege/{jurisdiction}/{license_type_abbr}#' - resp = self.config.provider_table.query( - Select='ALL_ATTRIBUTES', - KeyConditionExpression=Key('pk').eq(f'{compact}#PROVIDER#{provider_id}') & sk_condition, - ConsistentRead=consistent_read, - ) - if not resp['Items'] or not len(resp['Items']): - raise CCNotFoundException('Privilege not found') + try: + response = self.config.provider_table.get_item( + Key={'pk': pk, 'sk': sk}, + ConsistentRead=consistent_read, + ) + if 'Item' not in response: + raise CCNotFoundException('Privilege not found') - return load_records_into_schemas(resp['Items']) + return PrivilegeData.from_database_record(response['Item']) + except KeyError as e: + raise CCNotFoundException('Privilege not found') from e - @logger_inject_kwargs(logger, 'compact', 'provider_id', 'jurisdiction', 'license_type') - def get_privilege(self, *, compact: str, provider_id: str, jurisdiction: str, license_type_abbr: str) -> dict: + def _get_privilege_update_records_directly( + self, + *, + compact: str, + provider_id: str, + jurisdiction: str, + license_type_abbr: str, + consistent_read: bool = False, + ) -> list[PrivilegeUpdateData]: """ - Get a privilege for a provider in a jurisdiction of the license type + Query for all privilege update records for a specific privilege directly from DynamoDB. + + This should be used when it is undesirable to get all provider update records and + filter for the specific privilege update records. + + During migration period, this method queries both the new and old SK patterns to ensure + no records are missed. 
:param str compact: The compact of the privilege :param str provider_id: The provider of the privilege :param str jurisdiction: The jurisdiction of the privilege :param str license_type_abbr: The license type abbreviation of the privilege + :param bool consistent_read: If true, performs a consistent read of the records + :return: List of privilege update records + """ + pk = f'{compact}#PROVIDER#{provider_id}' + + # SK prefixes to query (new pattern and old pattern for migration support) + # TODO - remove old pattern once migration is complete # noqa: FIX002 + sk_prefixes = [ + # New pattern + f'{compact}#UPDATE#{UpdateTierEnum.TIER_ONE}#privilege/{jurisdiction}/{license_type_abbr}/', + # Old pattern + f'{compact}#PROVIDER#privilege/{jurisdiction}/{license_type_abbr}#UPDATE', + ] + + response_items = [] + + # Query for records using each SK prefix pattern + for sk_prefix in sk_prefixes: + last_evaluated_key = None + while True: + pagination = {'ExclusiveStartKey': last_evaluated_key} if last_evaluated_key else {} + + query_resp = self.config.provider_table.query( + Select='ALL_ATTRIBUTES', + KeyConditionExpression=Key('pk').eq(pk) & Key('sk').begins_with(sk_prefix), + ConsistentRead=consistent_read, + **pagination, + ) + + response_items.extend(query_resp.get('Items', [])) + + last_evaluated_key = query_resp.get('LastEvaluatedKey') + if not last_evaluated_key: + break + + return [PrivilegeUpdateData.from_database_record(item) for item in response_items] + + @logger_inject_kwargs(logger, 'compact', 'provider_id', 'detail', 'jurisdiction', 'license_type_abbr') + def get_privilege_data( + self, + *, + compact: str, + provider_id: str, + jurisdiction: str, + license_type_abbr: str, + consistent_read: bool = False, + detail: bool = False, + ) -> list[dict]: + """ + Get a privilege for a provider in a jurisdiction of the license type. 
+ + This should be used when it is undesirable to pull all provider records and + filter for the specific privilege record and associated update records. + + :param str compact: The compact of the privilege + :param str provider_id: The provider of the privilege + :param str jurisdiction: The jurisdiction of the privilege + :param str license_type_abbr: The license type abbreviation of the privilege + :param bool consistent_read: If true, performs a consistent read of the records + :param bool detail: Boolean determining whether we include associated records or just privilege record itself :raises CCNotFoundException: If the privilege record is not found + :return If detail = False list of length one containing privilege item, if detail = True list containing, + privilege record and privilege update records """ - # Get the privilege record - try: - privilege_record = self.config.provider_table.get_item( - Key={ - 'pk': f'{compact}#PROVIDER#{provider_id}', - 'sk': f'{compact}#PROVIDER#privilege/{jurisdiction}/{license_type_abbr}#', - }, - )['Item'] - except KeyError as e: - raise CCNotFoundException(f'Privilege not found for jurisdiction {jurisdiction}') from e + # Query directly for the privilege record + privilege = self._get_privilege_record_directly( + compact=compact, + provider_id=provider_id, + jurisdiction=jurisdiction, + license_type_abbr=license_type_abbr, + consistent_read=consistent_read, + ) - return privilege_record + # Build return list in the same format as before + result = [privilege.to_dict()] - @logger_inject_kwargs(logger, 'compact', 'provider_id', 'jurisdiction', 'license_type') + if detail: + # Query directly for privilege update records + privilege_updates = self._get_privilege_update_records_directly( + compact=compact, + provider_id=provider_id, + jurisdiction=jurisdiction, + license_type_abbr=license_type_abbr, + consistent_read=consistent_read, + ) + result.extend([update.to_dict() for update in privilege_updates]) + + return result + + 
@logger_inject_kwargs(logger, 'compact', 'provider_id', 'jurisdiction', 'license_type_abbr') def deactivate_privilege( self, *, compact: str, provider_id: str, jurisdiction: str, license_type_abbr: str, deactivation_details: dict ) -> None: @@ -1052,7 +1153,7 @@ def deactivate_privilege( privilege_update_record = PrivilegeUpdateRecordSchema().dump( { 'type': ProviderRecordType.PRIVILEGE_UPDATE, - 'updateType': 'deactivation', + 'updateType': UpdateCategory.DEACTIVATION, 'providerId': provider_id, 'compact': compact, 'jurisdiction': jurisdiction, @@ -2062,7 +2163,7 @@ def lift_license_encumbrance( and potentially updating the license record's encumbered status. :param str compact: The compact name - :param str provider_id: The provider ID + :param UUID provider_id: The provider ID :param str jurisdiction: The jurisdiction :param str license_type_abbreviation: The license type abbreviation :param UUID adverse_action_id: The adverse action ID to lift @@ -2533,6 +2634,7 @@ def _get_provider_record_transaction_items_for_jurisdiction_with_no_known_licens ) # Create the provider update record + now = config.current_standard_datetime provider_update_record = ProviderUpdateData.create_new( { 'type': ProviderRecordType.PROVIDER_UPDATE, @@ -2540,6 +2642,7 @@ def _get_provider_record_transaction_items_for_jurisdiction_with_no_known_licens 'providerId': provider_id, 'compact': compact, 'previous': provider_record.to_dict(), + 'createDate': now, 'updatedValues': { 'currentHomeJurisdiction': selected_jurisdiction, }, @@ -2693,6 +2796,7 @@ def _get_provider_record_transaction_items_for_jurisdiction_change_with_license( ) # Create the provider update record + now = config.current_standard_datetime provider_update_record = ProviderUpdateData.create_new( { 'type': ProviderRecordType.PROVIDER_UPDATE, @@ -2700,6 +2804,7 @@ def _get_provider_record_transaction_items_for_jurisdiction_change_with_license( 'providerId': provider_id, 'compact': compact, 'previous': 
provider_records.get_provider_record().to_dict(), + 'createDate': now, 'updatedValues': { 'licenseJurisdiction': new_license_record.jurisdiction, # we explicitly set this to align with what was passed in as the selected jurisdiction @@ -3386,6 +3491,7 @@ def complete_provider_email_update( current_provider_record = self.get_provider_top_level_record(compact=compact, provider_id=provider_id) # Create provider update record to track the email change + now = config.current_standard_datetime provider_update_record = ProviderUpdateData.create_new( { 'type': ProviderRecordType.PROVIDER_UPDATE, @@ -3393,6 +3499,7 @@ def complete_provider_email_update( 'providerId': provider_id, 'compact': compact, 'previous': current_provider_record.to_dict(), + 'createDate': now, 'updatedValues': { 'compactConnectRegisteredEmailAddress': new_email_address, }, diff --git a/backend/compact-connect/lambdas/python/common/cc_common/data_model/provider_record_util.py b/backend/compact-connect/lambdas/python/common/cc_common/data_model/provider_record_util.py index 6cf5b5c72..998ce5397 100644 --- a/backend/compact-connect/lambdas/python/common/cc_common/data_model/provider_record_util.py +++ b/backend/compact-connect/lambdas/python/common/cc_common/data_model/provider_record_util.py @@ -23,7 +23,7 @@ from cc_common.data_model.schema.privilege import PrivilegeData, PrivilegeUpdateData from cc_common.data_model.schema.privilege.api import PrivilegeHistoryResponseSchema from cc_common.data_model.schema.provider import ProviderData, ProviderUpdateData -from cc_common.exceptions import CCInternalException +from cc_common.exceptions import CCInternalException, CCNotFoundException class ProviderRecordType(StrEnum): @@ -500,6 +500,27 @@ def get_privilege_records( """ return [record for record in self._privilege_records if filter_condition is None or filter_condition(record)] + def get_privileges_associated_with_license( + self, + license_jurisdiction: str, + license_type_abbreviation: str, + 
filter_condition: Callable[[PrivilegeData], bool] | None = None, + ) -> list[PrivilegeData]: + """ + Get all privileges associated with a given license. + :param license_jurisdiction: The jurisdiction of the license. + :param license_type_abbreviation: The abbreviation of the license type. + :param filter_condition: An optional filter to apply to the privilege records + :return: A list of privilege records associated with the license + """ + return [ + record + for record in self._privilege_records + if record.licenseJurisdiction == license_jurisdiction + and record.licenseTypeAbbreviation == license_type_abbreviation + and (filter_condition is None or filter_condition(record)) + ] + def get_license_records( self, filter_condition: Callable[[LicenseData], bool] | None = None, @@ -739,7 +760,7 @@ def find_best_license_in_current_known_licenses(self, jurisdiction: str | None = # Last issued inactive license, otherwise latest_licenses = sorted(license_records, key=lambda x: x.dateOfIssuance.isoformat(), reverse=True) if not latest_licenses: - raise CCInternalException('No licenses found') + raise CCNotFoundException('No licenses found') return latest_licenses[0] @@ -758,6 +779,45 @@ def get_latest_military_affiliation_status(self) -> str | None: return latest_military_affiliation.status + def get_all_license_update_records( + self, + filter_condition: Callable[[LicenseUpdateData], bool] | None = None, + ) -> list[LicenseUpdateData]: + """ + Get all license update records for this provider. + :param filter_condition: An optional filter to apply to the update records + :return: List of LicenseUpdateData records + """ + return [ + record for record in self._license_update_records if filter_condition is None or filter_condition(record) + ] + + def get_all_privilege_update_records( + self, + filter_condition: Callable[[PrivilegeUpdateData], bool] | None = None, + ) -> list[PrivilegeUpdateData]: + """ + Get all privilege update records for this provider. 
+ :param filter_condition: An optional filter to apply to the update records + :return: List of PrivilegeUpdateData records + """ + return [ + record for record in self._privilege_update_records if filter_condition is None or filter_condition(record) + ] + + def get_all_provider_update_records( + self, + filter_condition: Callable[[ProviderUpdateData], bool] | None = None, + ) -> list[ProviderUpdateData]: + """ + Get all provider update records for this provider. + :param filter_condition: An optional filter to apply to the update records + :return: List of ProviderUpdateData records + """ + return [ + record for record in self._provider_update_records if filter_condition is None or filter_condition(record) + ] + def get_update_records_for_license( self, jurisdiction: str, diff --git a/backend/compact-connect/lambdas/python/common/cc_common/data_model/schema/common.py b/backend/compact-connect/lambdas/python/common/cc_common/data_model/schema/common.py index ced223d8a..ae058c729 100644 --- a/backend/compact-connect/lambdas/python/common/cc_common/data_model/schema/common.py +++ b/backend/compact-connect/lambdas/python/common/cc_common/data_model/schema/common.py @@ -312,6 +312,14 @@ class UpdateCategory(CCEnum): LICENSE_UPLOAD_UPDATE_OTHER = 'other' +# License upload related update categories +LICENSE_UPLOAD_UPDATE_CATEGORIES = { + UpdateCategory.DEACTIVATION, + UpdateCategory.RENEWAL, + UpdateCategory.LICENSE_UPLOAD_UPDATE_OTHER, +} + + class ActiveInactiveStatus(CCEnum): ACTIVE = 'active' INACTIVE = 'inactive' diff --git a/backend/compact-connect/lambdas/python/common/cc_common/data_model/schema/data_event/api.py b/backend/compact-connect/lambdas/python/common/cc_common/data_model/schema/data_event/api.py index 6549399e7..ff73f21f3 100644 --- a/backend/compact-connect/lambdas/python/common/cc_common/data_model/schema/data_event/api.py +++ b/backend/compact-connect/lambdas/python/common/cc_common/data_model/schema/data_event/api.py @@ -66,3 +66,21 @@ class 
InvestigationEventDetailSchema(DataEventDetailBaseSchema): class LicenseDeactivationDetailSchema(DataEventDetailBaseSchema): providerId = UUID(required=True, allow_none=False) licenseType = String(required=True, allow_none=False) + + +class LicenseRevertDetailSchema(DataEventDetailBaseSchema): + providerId = UUID(required=True, allow_none=False) + licenseType = String(required=True, allow_none=False) + rollbackReason = String(required=True, allow_none=False) + startTime = DateTime(required=True, allow_none=False) + endTime = DateTime(required=True, allow_none=False) + rollbackExecutionName = String(required=True, allow_none=False) + + +class PrivilegeRevertDetailSchema(DataEventDetailBaseSchema): + providerId = UUID(required=True, allow_none=False) + licenseType = String(required=True, allow_none=False) + rollbackReason = String(required=True, allow_none=False) + startTime = DateTime(required=True, allow_none=False) + endTime = DateTime(required=True, allow_none=False) + rollbackExecutionName = String(required=True, allow_none=False) diff --git a/backend/compact-connect/lambdas/python/common/cc_common/data_model/schema/license/__init__.py b/backend/compact-connect/lambdas/python/common/cc_common/data_model/schema/license/__init__.py index 5672fd071..9da8acc5e 100644 --- a/backend/compact-connect/lambdas/python/common/cc_common/data_model/schema/license/__init__.py +++ b/backend/compact-connect/lambdas/python/common/cc_common/data_model/schema/license/__init__.py @@ -140,6 +140,10 @@ def encumberedStatus(self) -> str | None: def investigationStatus(self) -> str | None: return self._data.get('investigationStatus') + @property + def firstUploadDate(self) -> datetime | None: + return self._data.get('firstUploadDate') + class LicenseUpdateData(CCDataClass): """ diff --git a/backend/compact-connect/lambdas/python/common/cc_common/data_model/schema/license/record.py b/backend/compact-connect/lambdas/python/common/cc_common/data_model/schema/license/record.py index 
9af6ea97e..1cbbed27c 100644 --- a/backend/compact-connect/lambdas/python/common/cc_common/data_model/schema/license/record.py +++ b/backend/compact-connect/lambdas/python/common/cc_common/data_model/schema/license/record.py @@ -12,6 +12,7 @@ ForgivingSchema, ) from cc_common.data_model.schema.common import ( + LICENSE_UPLOAD_UPDATE_CATEGORIES, ActiveInactiveStatus, ChangeHashMixin, CompactEligibilityStatus, @@ -31,6 +32,7 @@ ) from cc_common.data_model.schema.investigation.record import InvestigationDetailsSchema from cc_common.data_model.schema.license.common import LicenseCommonSchema +from cc_common.data_model.update_tier_enum import UpdateTierEnum @BaseRecordSchema.register_schema('license') @@ -47,6 +49,13 @@ class LicenseRecordSchema(BaseRecordSchema, LicenseCommonSchema): providerId = UUID(required=True, allow_none=False) licenseGSIPK = String(required=True, allow_none=False) licenseGSISK = String(required=True, allow_none=False) + licenseUploadDateGSIPK = String(required=False, allow_none=False) + licenseUploadDateGSISK = String(required=False, allow_none=False) + + # Optional field for tracking the first license upload that caused this record to be created + # Note that records which were uploaded before this field was supported will not have this included + # and will not be included in the license upload date GSI + firstUploadDate = DateTime(required=False, allow_none=False) # Provided fields npi = NationalProviderIdentifier(required=False, allow_none=False) @@ -124,12 +133,34 @@ def generate_license_gsi_fields(self, in_data, **kwargs): # noqa: ARG001 unused in_data['licenseGSISK'] = f'FN#{quote(in_data["familyName"].lower())}#GN#{quote(in_data["givenName"].lower())}' return in_data + @pre_dump + def generate_license_upload_date_gsi_fields(self, in_data, **kwargs): # noqa: ARG001 unused-argument + """Generate GSI fields for license upload date tracking (only if firstUploadDate is present)""" + if 'firstUploadDate' in in_data and 
in_data['firstUploadDate'] is not None: + # Extract YYYY-MM from firstUploadDate + upload_date = in_data['firstUploadDate'] + year_month = upload_date.strftime('%Y-%m') + + # Generate GSI PK: C#{compact}#J#{jurisdiction}#D#{YYYY-MM} + in_data['licenseUploadDateGSIPK'] = ( + f'C#{in_data["compact"].lower()}#J#{in_data["jurisdiction"].lower()}#D#{year_month}' + ) + # Generate GSI SK: TIME#{epoch_timestamp}#LT#{licenseType}#PID#{providerId} + upload_epoch_time = int(upload_date.timestamp()) + license_type_abbr = config.license_type_abbreviations[in_data['compact']][in_data['licenseType']] + in_data['licenseUploadDateGSISK'] = ( + f'TIME#{upload_epoch_time}#LT#{license_type_abbr}#PID#{in_data["providerId"]}' + ) + return in_data + @post_load def drop_license_gsi_fields(self, in_data, **kwargs): # noqa: ARG001 unused-argument """Drop the db-specific license GSI fields before returning loaded data""" # only drop the field if it's present, else continue on in_data.pop('licenseGSIPK', None) in_data.pop('licenseGSISK', None) + in_data.pop('licenseUploadDateGSIPK', None) + in_data.pop('licenseUploadDateGSISK', None) return in_data @@ -199,6 +230,10 @@ class LicenseUpdateRecordSchema(BaseRecordSchema, ChangeHashMixin): # List of field names that were present in the previous record but removed in the update removedValues = List(String(), required=False, allow_none=False) + # Optional GSI fields for license upload date tracking + licenseUploadDateGSIPK = String(required=False, allow_none=False) + licenseUploadDateGSISK = String(required=False, allow_none=False) + @post_dump # Must be _post_ dump so we have values that are more easily hashed def generate_pk_sk(self, in_data, **kwargs): # noqa: ARG001 unused-argument """ @@ -209,16 +244,47 @@ def generate_pk_sk(self, in_data, **kwargs): # noqa: ARG001 unused-argument served out via API. 
""" in_data['pk'] = f'{in_data["compact"]}#PROVIDER#{in_data["providerId"]}' - # This needs to include a POSIX timestamp (seconds) and a hash of the changes - # to the record. We'll use the current time and the hash of the updatedValues + # This needs to include an iso formatted datetime string and a hash of the changes + # to the record. We'll use the createDate and the hash of the updatedValues # field for this. change_hash = self.hash_changes(in_data) license_type_abbr = config.license_type_abbreviations[in_data['compact']][in_data['licenseType']] in_data['sk'] = ( - f'{in_data["compact"]}#PROVIDER#license/{in_data["jurisdiction"]}/{license_type_abbr}#UPDATE#{int(config.current_standard_datetime.timestamp())}/{change_hash}' + f'{in_data["compact"]}#UPDATE#{UpdateTierEnum.TIER_THREE}#license/{in_data["jurisdiction"]}/{license_type_abbr}/{in_data["createDate"]}/{change_hash}' ) return in_data + @pre_dump + def generate_license_upload_date_gsi_fields(self, in_data, **kwargs): # noqa: ARG001 unused-argument + """Generate GSI fields for license upload date tracking""" + # If the update is related to an upload event, we generate the upload GSI fields to allow the system to + # query when certain uploads occurred + if in_data['updateType'] in LICENSE_UPLOAD_UPDATE_CATEGORIES: + # Extract YYYY-MM from createDate + upload_date = in_data['createDate'] + year_month = upload_date.strftime('%Y-%m') + + # Generate GSI PK: C#{compact}#J#{jurisdiction}#D#{YYYY-MM} + in_data['licenseUploadDateGSIPK'] = ( + f'C#{in_data["compact"].lower()}#J#{in_data["jurisdiction"].lower()}#D#{year_month}' + ) + + # Generate GSI SK: TIME#{epoch_timestamp}#LT#{licenseType}#PID#{providerId} + upload_epoch_time = int(upload_date.timestamp()) + license_type_abbr = config.license_type_abbreviations[in_data['compact']][in_data['licenseType']] + in_data['licenseUploadDateGSISK'] = ( + f'TIME#{upload_epoch_time}#LT#{license_type_abbr}#PID#{in_data["providerId"]}' + ) + return in_data + + @post_load + 
def drop_license_gsi_fields(self, in_data, **kwargs): # noqa: ARG001 unused-argument + """Drop the db-specific license GSI fields before returning loaded data""" + # only drop the field if it's present, else continue on + in_data.pop('licenseUploadDateGSIPK', None) + in_data.pop('licenseUploadDateGSISK', None) + return in_data + @validates_schema def validate_license_type(self, data, **kwargs): # noqa: ARG001 unused-argument license_types = config.license_types_for_compact(data['compact']) diff --git a/backend/compact-connect/lambdas/python/common/cc_common/data_model/schema/privilege/record.py b/backend/compact-connect/lambdas/python/common/cc_common/data_model/schema/privilege/record.py index fdc918357..cc6996389 100644 --- a/backend/compact-connect/lambdas/python/common/cc_common/data_model/schema/privilege/record.py +++ b/backend/compact-connect/lambdas/python/common/cc_common/data_model/schema/privilege/record.py @@ -27,6 +27,7 @@ UpdateType, ) from cc_common.data_model.schema.investigation.record import InvestigationDetailsSchema +from cc_common.data_model.update_tier_enum import UpdateTierEnum class AttestationVersionRecordSchema(Schema): @@ -234,13 +235,13 @@ class PrivilegeUpdateRecordSchema(BaseRecordSchema, ChangeHashMixin, ValidatesLi @post_dump # Must be _post_ dump so we have values that are more easily hashed def generate_pk_sk(self, in_data, **kwargs): # noqa: ARG001 unused-argument in_data['pk'] = f'{in_data["compact"]}#PROVIDER#{in_data["providerId"]}' - # This needs to include a POSIX timestamp (seconds) and a hash of the changes - # to the record. We'll use the current time and the hash of the updatedValues + # This needs to include an iso formatted datetime string and a hash of the changes + # to the record. We'll use the createDate and the hash of the updatedValues # field for this. 
change_hash = self.hash_changes(in_data) license_type_abbr = config.license_type_abbreviations[in_data['compact']][in_data['licenseType']] in_data['sk'] = ( - f'{in_data["compact"]}#PROVIDER#privilege/{in_data["jurisdiction"]}/{license_type_abbr}#UPDATE#{int(config.current_standard_datetime.timestamp())}/{change_hash}' + f'{in_data["compact"]}#UPDATE#{UpdateTierEnum.TIER_ONE}#privilege/{in_data["jurisdiction"]}/{license_type_abbr}/{in_data["createDate"]}/{change_hash}' ) return in_data diff --git a/backend/compact-connect/lambdas/python/common/cc_common/data_model/schema/provider/__init__.py b/backend/compact-connect/lambdas/python/common/cc_common/data_model/schema/provider/__init__.py index edab65d15..ebbf23e4b 100644 --- a/backend/compact-connect/lambdas/python/common/cc_common/data_model/schema/provider/__init__.py +++ b/backend/compact-connect/lambdas/python/common/cc_common/data_model/schema/provider/__init__.py @@ -175,6 +175,10 @@ def providerId(self) -> UUID: def compact(self) -> str: return self._data['compact'] + @property + def createDate(self) -> str: + return self._data['createDate'] + @property def previous(self) -> dict: return self._data['previous'] diff --git a/backend/compact-connect/lambdas/python/common/cc_common/data_model/schema/provider/record.py b/backend/compact-connect/lambdas/python/common/cc_common/data_model/schema/provider/record.py index be8b56992..2c4de0c77 100644 --- a/backend/compact-connect/lambdas/python/common/cc_common/data_model/schema/provider/record.py +++ b/backend/compact-connect/lambdas/python/common/cc_common/data_model/schema/provider/record.py @@ -26,6 +26,7 @@ Set, UpdateType, ) +from cc_common.data_model.update_tier_enum import UpdateTierEnum @BaseRecordSchema.register_schema('provider') @@ -225,19 +226,35 @@ class ProviderUpdateRecordSchema(BaseRecordSchema, ChangeHashMixin): providerId = UUID(required=True, allow_none=False) compact = Compact(required=True, allow_none=False) previous = 
Nested(ProviderUpdatePreviousRecordSchema, required=True, allow_none=False) + # this tracks when the update record was created + createDate = DateTime(required=True, allow_none=False) # We'll allow any fields that can show up in the previous field to be here as well, but none are required updatedValues = Nested(ProviderUpdatePreviousRecordSchema(partial=True), required=True, allow_none=False) # List of field names that were present in the previous record but removed in the update removedValues = List(String(), required=False, allow_none=False) + # TODO - remove this pre_load hook after migration is complete # noqa: FIX002 + @pre_load + def populate_create_date_for_backwards_compatibility(self, in_data, **kwargs): # noqa: ARG001 unused-argument + """ + For backwards compatibility, populate createDate from dateOfUpdate if createDate is missing. + This allows us to load old records that were created before the createDate field was added. + """ + if 'createDate' not in in_data: + in_data['createDate'] = in_data['dateOfUpdate'] + return in_data + @post_dump # Must be _post_ dump so we have values that are more easily hashed def generate_pk_sk(self, in_data, **kwargs): # noqa: ARG001 unused-argument in_data['pk'] = f'{in_data["compact"]}#PROVIDER#{in_data["providerId"]}' - # This needs to include a POSIX timestamp (seconds) and a hash of the changes - # to the record. We'll use the current time and the hash of the updatedValues + # This needs to include an iso formatted datetime string and a hash of the changes + # to the record. We'll use the createDate and the hash of the updatedValues # field for this. + # Provider update records are considered a tier 2 update. Privilege updates are tier 1 because they are accessed + # most frequently. Provider update records are not generated often, so it is more performant to place them at + # tier 2, with license updates being the last tier 3. 
change_hash = self.hash_changes(in_data) in_data['sk'] = ( - f'{in_data["compact"]}#PROVIDER#UPDATE#{int(config.current_standard_datetime.timestamp())}/{change_hash}' + f'{in_data["compact"]}#UPDATE#{UpdateTierEnum.TIER_TWO}#provider/{in_data["createDate"]}/{change_hash}' ) return in_data diff --git a/backend/compact-connect/lambdas/python/common/cc_common/data_model/transaction_client.py b/backend/compact-connect/lambdas/python/common/cc_common/data_model/transaction_client.py index 275a230b2..0ef086eb2 100644 --- a/backend/compact-connect/lambdas/python/common/cc_common/data_model/transaction_client.py +++ b/backend/compact-connect/lambdas/python/common/cc_common/data_model/transaction_client.py @@ -377,7 +377,7 @@ def reconcile_unsettled_transactions(self, compact: str, settled_transactions: l if unmatched_settled_transaction_ids: logger.error( 'Unable to reconcile some transactions from Authorize.Net with our unsettled transactions', - unreconciled_transactions=unmatched_settled_transaction_ids + unreconciled_transactions=unmatched_settled_transaction_ids, ) for unsettled_tx in unmatched_unsettled: diff --git a/backend/compact-connect/lambdas/python/common/cc_common/data_model/update_tier_enum.py b/backend/compact-connect/lambdas/python/common/cc_common/data_model/update_tier_enum.py new file mode 100644 index 000000000..a05fd8fe2 --- /dev/null +++ b/backend/compact-connect/lambdas/python/common/cc_common/data_model/update_tier_enum.py @@ -0,0 +1,40 @@ +from enum import StrEnum + + +class UpdateTierEnum(StrEnum): + """ + Enum for update record tiers in the sort key hierarchy. + + DynamoDB compares string sort keys lexicographically (by UTF-8 byte order), not as numbers. + This means we can perform comparison operations on string sort keys, such as less than (lt) + and grab records within a certain range.
+ + To reduce risk that massive invalid updates from a jurisdiction will cause the system to crash + when loading provider data, we migrated the sort keys of our update records to follow this + tier based pattern, which will allow us to query for update records only as needed. + + Update records are organized into tiers to enable efficient range queries. + Because all the primary provider records are prefixed under a common `{compact}#PROVIDER` prefix, + which is lexicographically less than the `{compact}#UPDATE` prefix, using the lt condition with the + UPDATE prefix will grab all the update records up to the specified tier and all primary records under + the PROVIDER prefix. + + Tier structure in sort keys: + - Tier 1: {compact}#UPDATE#1#privilege/... (Privilege updates) + - Tier 2: {compact}#UPDATE#2#provider/... (Provider updates) + - Tier 3: {compact}#UPDATE#3#license/... (License updates) + + Query patterns: + - TIER_ONE: Fetches privilege updates only + Query: Key('sk').lt('{compact}#UPDATE#2') + + - TIER_TWO: Fetches privilege + provider updates + Query: Key('sk').lt('{compact}#UPDATE#3') + + - TIER_THREE: Fetches all updates (privilege + provider + license) + Query: Key('sk').lt('{compact}#UPDATE#4') + """ + + TIER_ONE = '1' # Privilege updates only + TIER_TWO = '2' # Privilege + Provider updates + TIER_THREE = '3' # All updates (Privilege + Provider + License) diff --git a/backend/compact-connect/lambdas/python/common/cc_common/event_bus_client.py b/backend/compact-connect/lambdas/python/common/cc_common/event_bus_client.py index 88ae6307d..fe6af52b7 100644 --- a/backend/compact-connect/lambdas/python/common/cc_common/event_bus_client.py +++ b/backend/compact-connect/lambdas/python/common/cc_common/event_bus_client.py @@ -2,15 +2,19 @@ from datetime import date, datetime from uuid import UUID +from marshmallow import ValidationError + from cc_common.config import config from cc_common.data_model.schema.common import InvestigationAgainstEnum from 
cc_common.data_model.schema.data_event.api import ( EncumbranceEventDetailSchema, InvestigationEventDetailSchema, LicenseDeactivationDetailSchema, + LicenseRevertDetailSchema, PrivilegeIssuanceDetailSchema, PrivilegePurchaseEventDetailSchema, PrivilegeRenewalDetailSchema, + PrivilegeRevertDetailSchema, ) from cc_common.event_batch_writer import EventBatchWriter from cc_common.utils import ResponseEncoder @@ -437,3 +441,107 @@ def publish_investigation_closed_event( detail=deserialized_detail, event_batch_writer=event_batch_writer, ) + + def publish_license_revert_event( + self, + source: str, + compact: str, + provider_id: str, + jurisdiction: str, + license_type: str, + rollback_reason: str, + start_time: datetime, + end_time: datetime, + execution_name: str, + event_batch_writer: EventBatchWriter | None = None, + ): + """ + Publish a license revert event to the event bus. + + :param source: The source of the event + :param compact: The compact name + :param provider_id: The provider ID + :param jurisdiction: The jurisdiction of the license. + :param license_type: The license type. 
+ :param rollback_reason: The reason for the rollback + :param start_time: The start time of the rollback window + :param end_time: The end time of the rollback window + :param execution_name: The execution name for the rollback operation + :param event_batch_writer: Optional EventBatchWriter for efficient batch publishing + """ + event_detail = { + 'compact': compact, + 'providerId': provider_id, + 'jurisdiction': jurisdiction, + 'licenseType': license_type, + 'rollbackReason': rollback_reason, + 'startTime': start_time, + 'endTime': end_time, + 'rollbackExecutionName': execution_name, + 'eventTime': config.current_standard_datetime, + } + + license_revert_detail_schema = LicenseRevertDetailSchema() + deserialized_detail = license_revert_detail_schema.dump(event_detail) + validation_errors = license_revert_detail_schema.validate(deserialized_detail) + if validation_errors: + raise ValidationError(message=validation_errors) + + self._publish_event( + source=source, + detail_type='license.revert', + detail=deserialized_detail, + event_batch_writer=event_batch_writer, + ) + + def publish_privilege_revert_event( + self, + source: str, + compact: str, + provider_id: str, + jurisdiction: str, + license_type: str, + rollback_reason: str, + start_time: datetime, + end_time: datetime, + execution_name: str, + event_batch_writer: EventBatchWriter | None = None, + ): + """ + Publish a privilege revert event to the event bus. 
+ + :param source: The source of the event + :param compact: The compact name + :param provider_id: The provider ID + :param jurisdiction: The jurisdiction of the privilege + :param license_type: The license type + :param rollback_reason: The reason for the rollback + :param start_time: The start time of the rollback window + :param end_time: The end time of the rollback window + :param execution_name: The execution name for the rollback operation + :param event_batch_writer: Optional EventBatchWriter for efficient batch publishing + """ + event_detail = { + 'compact': compact, + 'providerId': provider_id, + 'jurisdiction': jurisdiction, + 'licenseType': license_type, + 'rollbackReason': rollback_reason, + 'startTime': start_time, + 'endTime': end_time, + 'rollbackExecutionName': execution_name, + 'eventTime': config.current_standard_datetime, + } + + privilege_revert_detail_schema = PrivilegeRevertDetailSchema() + deserialized_detail = privilege_revert_detail_schema.dump(event_detail) + validation_errors = privilege_revert_detail_schema.validate(deserialized_detail) + if validation_errors: + raise ValidationError(message=validation_errors) + + self._publish_event( + source=source, + detail_type='privilege.revert', + detail=deserialized_detail, + event_batch_writer=event_batch_writer, + ) diff --git a/backend/compact-connect/lambdas/python/common/common_test/test_constants.py b/backend/compact-connect/lambdas/python/common/common_test/test_constants.py index d88206c9b..2c8932c45 100644 --- a/backend/compact-connect/lambdas/python/common/common_test/test_constants.py +++ b/backend/compact-connect/lambdas/python/common/common_test/test_constants.py @@ -88,8 +88,10 @@ PRIVILEGE_RECORD_TYPE = 'privilege' PRIVILEGE_UPDATE_RECORD_TYPE = 'privilegeUpdate' PROVIDER_RECORD_TYPE = 'provider' +PROVIDER_UPDATE_RECORD_TYPE = 'providerUpdate' TRANSACTION_RECORD_TYPE = 'transaction' + # Privilege update default values DEFAULT_PRIVILEGE_UPDATE_TYPE = 'renewal' 
DEFAULT_PRIVILEGE_UPDATE_DATE_OF_UPDATE = '2020-05-05T12:59:59+00:00' @@ -106,6 +108,9 @@ DEFAULT_LICENSE_UPDATE_PREVIOUS_DATE_OF_EXPIRATION = '2020-06-06' DEFAULT_LICENSE_UPDATE_PREVIOUS_DATE_OF_RENEWAL = '2015-06-06' +# Provider update default values +DEFAULT_PROVIDER_UPDATE_TYPE = 'registration' + # Adverse Action defaults DEFAULT_ACTION_AGAINST_PRIVILEGE = 'privilege' DEFAULT_BLOCKS_FUTURE_PRIVILEGES = True diff --git a/backend/compact-connect/lambdas/python/common/common_test/test_data_generator.py b/backend/compact-connect/lambdas/python/common/common_test/test_data_generator.py index 883056a9f..b70e3dd05 100644 --- a/backend/compact-connect/lambdas/python/common/common_test/test_data_generator.py +++ b/backend/compact-connect/lambdas/python/common/common_test/test_data_generator.py @@ -14,7 +14,7 @@ from cc_common.data_model.schema.license import LicenseData, LicenseUpdateData from cc_common.data_model.schema.military_affiliation import MilitaryAffiliationData from cc_common.data_model.schema.privilege import PrivilegeData, PrivilegeUpdateData -from cc_common.data_model.schema.provider import ProviderData +from cc_common.data_model.schema.provider import ProviderData, ProviderUpdateData from cc_common.utils import ResponseEncoder from common_test.test_constants import * @@ -101,15 +101,15 @@ def query_privilege_update_records_for_given_record_from_database( ) -> list[PrivilegeUpdateData]: """ Helper method to query update records from the database using the provider data class instance. - - All of our update records use the same pk as the actual record that is being updated. The sk of the actual - record is the prefix for all the update records. Using this pattern, we can query for all of the update records - that have been written for the given record. 
""" serialized_record = privilege_data.serialize_to_database_record() + from cc_common.config import config + + license_type_abbr = config.license_type_abbreviations[privilege_data.compact][privilege_data.licenseType] + sk_prefix = f'{privilege_data.compact}#UPDATE#1#privilege/{privilege_data.jurisdiction}/{license_type_abbr}/' privilege_update_records = TestDataGenerator._query_records_by_pk_and_sk_prefix( - serialized_record['pk'], f'{serialized_record["sk"]}UPDATE' + serialized_record['pk'], sk_prefix ) return [PrivilegeUpdateData.from_database_record(update_record) for update_record in privilege_update_records] @@ -125,10 +125,32 @@ def query_provider_update_records_for_given_record_from_database(provider_record """ serialized_record = provider_record.serialize_to_database_record() - return TestDataGenerator._query_records_by_pk_and_sk_prefix( - serialized_record['pk'], f'{serialized_record["sk"]}#UPDATE' + sk_prefix = f'{provider_record.compact}#UPDATE#2#provider' + + return TestDataGenerator._query_records_by_pk_and_sk_prefix(serialized_record['pk'], sk_prefix) + + @staticmethod + def query_license_update_records_for_given_record_from_database( + license_data: LicenseData, + ) -> list[LicenseUpdateData]: + """ + Helper method to query update records from the database using the license data class instance. + + All of our update records use the same pk as the actual record that is being updated. 
The sk prefix + for license updates follows the tier pattern: {compact}#UPDATE#3#license/{jurisdiction}/{license_type_abbr}/ + """ + serialized_record = license_data.serialize_to_database_record() + from cc_common.config import config + + license_type_abbr = config.license_type_abbreviations[license_data.compact][license_data.licenseType] + sk_prefix = f'{license_data.compact}#UPDATE#3#license/{license_data.jurisdiction}/{license_type_abbr}/' + + license_update_records = TestDataGenerator._query_records_by_pk_and_sk_prefix( + serialized_record['pk'], sk_prefix ) + return [LicenseUpdateData.from_database_record(update_record) for update_record in license_update_records] + @staticmethod def generate_default_adverse_action(value_overrides: dict | None = None) -> AdverseActionData: """Generate a default adverse action""" @@ -307,6 +329,20 @@ def generate_default_license_update( return LicenseUpdateData.create_new(license_update) + @staticmethod + def put_default_license_update_record_in_provider_table( + value_overrides: dict | None = None, + ) -> LicenseUpdateData: + """ + Creates a default license update and stores it in the provider table. 
+ """ + update_data = TestDataGenerator.generate_default_license_update(value_overrides) + update_record = update_data.serialize_to_database_record() + + TestDataGenerator.store_record_in_provider_table(update_record) + + return update_data + @staticmethod def generate_default_privilege(value_overrides: dict | None = None) -> PrivilegeData: """Generate a default privilege""" @@ -460,6 +496,58 @@ def put_default_provider_record_in_provider_table( return provider_data + @staticmethod + def generate_default_provider_update( + value_overrides: dict | None = None, previous_provider: ProviderData | None = None + ) -> ProviderUpdateData: + """Generate a default provider update""" + if previous_provider is None: + previous_provider = TestDataGenerator.generate_default_provider() + + # Ensure previous provider has dateOfUpdate for the previous field + previous_dict = previous_provider.to_dict() + if 'dateOfUpdate' not in previous_dict: + previous_dict['dateOfUpdate'] = datetime.fromisoformat(DEFAULT_PROVIDER_UPDATE_DATETIME) + + provider_update = { + 'updateType': DEFAULT_PROVIDER_UPDATE_TYPE, + 'providerId': DEFAULT_PROVIDER_ID, + 'compact': DEFAULT_COMPACT, + 'type': PROVIDER_UPDATE_RECORD_TYPE, + 'previous': previous_dict, + 'createDate': datetime.fromisoformat(DEFAULT_PROVIDER_UPDATE_DATETIME), + 'updatedValues': { + 'compactConnectRegisteredEmailAddress': DEFAULT_REGISTERED_EMAIL_ADDRESS, + 'currentHomeJurisdiction': DEFAULT_LICENSE_JURISDICTION, + }, + 'dateOfUpdate': datetime.fromisoformat(DEFAULT_PROVIDER_UPDATE_DATETIME), + } + if value_overrides: + provider_update.update(value_overrides) + + return ProviderUpdateData.create_new(provider_update) + + @staticmethod + def put_default_provider_update_record_in_provider_table( + value_overrides: dict | None = None, date_of_update_override: str = None + ) -> ProviderUpdateData: + """ + Creates a default provider update record and stores it in the provider table. 
+ + :param value_overrides: Optional dictionary to override default values + :param date_of_update_override: optional date for date of update to be shown on provider record + :return: The ProviderUpdateData instance that was stored + """ + provider_update_data = TestDataGenerator.generate_default_provider_update(value_overrides) + provider_update_record = provider_update_data.serialize_to_database_record() + if date_of_update_override: + provider_update_record['dateOfUpdate'] = date_of_update_override + + TestDataGenerator.store_record_in_provider_table(provider_update_record) + + # recreate data object to ensure it picks up the dateOfUpdate change + return ProviderUpdateData.from_database_record(provider_update_record) + @staticmethod def _override_date_of_update_for_record(data_class: CCDataClass, date_of_update: datetime): # we have to access this here, as in runtime code dateOfUpdate is not to be modified diff --git a/backend/compact-connect/lambdas/python/common/tests/function/test_data_client.py b/backend/compact-connect/lambdas/python/common/tests/function/test_data_client.py index 1599702ce..39267d400 100644 --- a/backend/compact-connect/lambdas/python/common/tests/function/test_data_client.py +++ b/backend/compact-connect/lambdas/python/common/tests/function/test_data_client.py @@ -1,9 +1,10 @@ import json from datetime import UTC, date, datetime -from unittest.mock import patch +from unittest.mock import ANY, patch from uuid import UUID, uuid4 from boto3.dynamodb.conditions import Key +from cc_common.data_model.update_tier_enum import UpdateTierEnum from cc_common.exceptions import CCAwsServiceException, CCInvalidRequestException from common_test.test_constants import DEFAULT_PROVIDER_ID from moto import mock_aws @@ -227,6 +228,7 @@ def test_data_client_updates_privilege_records_for_specific_license_type(self): referenced nor updated in any way as part of this purchase. 
""" from cc_common.data_model.data_client import DataClient + from cc_common.data_model.provider_record_util import ProviderUserRecords from cc_common.data_model.schema.privilege import PrivilegeData # Imagine that there have been 123 privileges issued for the compact @@ -315,103 +317,118 @@ def test_data_client_updates_privilege_records_for_specific_license_type(self): ) # Verify that the audiologist privilege update record was created for ky - new_aud_ky_privilege = self._provider_table.query( - KeyConditionExpression=Key('pk').eq(f'aslp#PROVIDER#{provider_uuid}') - & Key('sk').begins_with('aslp#PROVIDER#privilege/ky/aud#'), - )['Items'] + provider_user_records: ProviderUserRecords = self.config.data_client.get_provider_user_records( + compact='aslp', provider_id=provider_uuid + ) + + new_aud_ky_privilege = provider_user_records.get_specific_privilege_record( + jurisdiction='ky', license_abbreviation='aud' + ) + self.assertEqual( - [ - # Primary record - { - 'pk': f'aslp#PROVIDER#{provider_uuid}', - 'sk': 'aslp#PROVIDER#privilege/ky/aud#', - 'type': 'privilege', - 'providerId': provider_uuid, - 'compact': 'aslp', - 'jurisdiction': 'ky', - 'licenseJurisdiction': 'oh', - 'licenseType': 'audiologist', - 'administratorSetStatus': 'active', - # Should be updated dates for renewal, expiration, update + # Primary record + { + 'pk': f'aslp#PROVIDER#{provider_uuid}', + 'sk': 'aslp#PROVIDER#privilege/ky/aud#', + 'type': 'privilege', + 'providerId': provider_uuid, + 'compact': 'aslp', + 'jurisdiction': 'ky', + 'licenseJurisdiction': 'oh', + 'licenseType': 'audiologist', + 'administratorSetStatus': 'active', + # Should be updated dates for renewal, expiration, update + 'dateOfIssuance': '2023-11-08T23:59:59+00:00', + 'dateOfRenewal': '2024-11-08T23:59:59+00:00', + 'dateOfExpiration': '2025-10-31', + 'dateOfUpdate': '2024-11-08T23:59:59+00:00', + 'compactTransactionId': 'test_transaction_id', + 'compactTransactionIdGSIPK': 'COMPACT#aslp#TX#test_transaction_id#', + 
'attestations': self.sample_privilege_attestations, + # Should remain the same, since we're renewing the same privilege + 'privilegeId': 'AUD-KY-1', + }, + new_aud_ky_privilege.serialize_to_database_record(), + ) + + # Get update records using test_data_generator + update_records = self.test_data_generator.query_privilege_update_records_for_given_record_from_database( + new_aud_ky_privilege + ) + self.assertEqual(1, len(update_records)) + ky_aud_update_record = update_records[0] + + self.assertEqual( + # A new history record + { + 'pk': f'aslp#PROVIDER#{provider_uuid}', + 'sk': 'aslp#UPDATE#1#privilege/ky/aud/2024-11-08T23:59:59+00:00/f61e34798e1775ff6230d1187d444146', + 'type': 'privilegeUpdate', + 'updateType': 'renewal', + 'providerId': provider_uuid, + 'compact': 'aslp', + 'compactTransactionIdGSIPK': 'COMPACT#aslp#TX#1234567890#', + 'jurisdiction': 'ky', + 'licenseType': 'audiologist', + 'dateOfUpdate': '2024-11-08T23:59:59+00:00', + 'createDate': '2024-11-08T23:59:59+00:00', + 'effectiveDate': '2024-11-08T23:59:59+00:00', + 'previous': { 'dateOfIssuance': '2023-11-08T23:59:59+00:00', + 'dateOfRenewal': '2023-11-08T23:59:59+00:00', + 'dateOfExpiration': '2024-10-31', + 'dateOfUpdate': '2023-11-08T23:59:59+00:00', + 'compactTransactionId': '1234567890', + 'attestations': self.sample_privilege_attestations, + 'administratorSetStatus': 'active', + 'licenseJurisdiction': 'oh', + 'privilegeId': 'AUD-KY-1', + }, + 'updatedValues': { + 'attestations': self.sample_privilege_attestations, 'dateOfRenewal': '2024-11-08T23:59:59+00:00', 'dateOfExpiration': '2025-10-31', - 'dateOfUpdate': '2024-11-08T23:59:59+00:00', 'compactTransactionId': 'test_transaction_id', - 'compactTransactionIdGSIPK': 'COMPACT#aslp#TX#test_transaction_id#', - 'attestations': self.sample_privilege_attestations, - # Should remain the same, since we're renewing the same privilege 'privilegeId': 'AUD-KY-1', }, - # A new history record - { - 'pk': f'aslp#PROVIDER#{provider_uuid}', - 'sk': 
'aslp#PROVIDER#privilege/ky/aud#UPDATE#1731110399/f61e34798e1775ff6230d1187d444146', - 'type': 'privilegeUpdate', - 'updateType': 'renewal', - 'providerId': provider_uuid, - 'compact': 'aslp', - 'compactTransactionIdGSIPK': 'COMPACT#aslp#TX#1234567890#', - 'jurisdiction': 'ky', - 'licenseType': 'audiologist', - 'dateOfUpdate': '2024-11-08T23:59:59+00:00', - 'createDate': '2024-11-08T23:59:59+00:00', - 'effectiveDate': '2024-11-08T23:59:59+00:00', - 'previous': { - 'dateOfIssuance': '2023-11-08T23:59:59+00:00', - 'dateOfRenewal': '2023-11-08T23:59:59+00:00', - 'dateOfExpiration': '2024-10-31', - 'dateOfUpdate': '2023-11-08T23:59:59+00:00', - 'compactTransactionId': '1234567890', - 'attestations': self.sample_privilege_attestations, - 'administratorSetStatus': 'active', - 'licenseJurisdiction': 'oh', - 'privilegeId': 'AUD-KY-1', - }, - 'updatedValues': { - 'attestations': self.sample_privilege_attestations, - 'dateOfRenewal': '2024-11-08T23:59:59+00:00', - 'dateOfExpiration': '2025-10-31', - 'compactTransactionId': 'test_transaction_id', - 'privilegeId': 'AUD-KY-1', - }, - }, - ], - new_aud_ky_privilege, + }, + ky_aud_update_record.serialize_to_database_record(), ) # Verify that a new audiologist privilege record was created for ne with expected values - new_aud_ne_privilege = self._provider_table.query( - KeyConditionExpression=Key('pk').eq(f'aslp#PROVIDER#{provider_uuid}') - & Key('sk').begins_with('aslp#PROVIDER#privilege/ne/aud#'), - )['Items'] + new_aud_ne_privilege = provider_user_records.get_specific_privilege_record( + jurisdiction='ne', license_abbreviation='aud' + ) self.assertEqual( - [ - # Primary record with no history record - { - 'pk': f'aslp#PROVIDER#{provider_uuid}', - 'sk': 'aslp#PROVIDER#privilege/ne/aud#', - 'type': 'privilege', - 'providerId': provider_uuid, - 'compact': 'aslp', - 'jurisdiction': 'ne', - 'licenseJurisdiction': 'oh', - 'licenseType': 'audiologist', - 'administratorSetStatus': 'active', - # issuance and renewal dates should be the 
same - 'dateOfIssuance': '2024-11-08T23:59:59+00:00', - 'dateOfRenewal': '2024-11-08T23:59:59+00:00', - 'dateOfExpiration': '2025-10-31', - 'dateOfUpdate': '2024-11-08T23:59:59+00:00', - 'compactTransactionId': 'test_transaction_id', - 'compactTransactionIdGSIPK': 'COMPACT#aslp#TX#test_transaction_id#', - 'attestations': self.sample_privilege_attestations, - # Should remain the same, since we're renewing the same privilege - 'privilegeId': 'AUD-NE-124', - } - ], - new_aud_ne_privilege, + # Primary record + { + 'pk': f'aslp#PROVIDER#{provider_uuid}', + 'sk': 'aslp#PROVIDER#privilege/ne/aud#', + 'type': 'privilege', + 'providerId': provider_uuid, + 'compact': 'aslp', + 'jurisdiction': 'ne', + 'licenseJurisdiction': 'oh', + 'licenseType': 'audiologist', + 'administratorSetStatus': 'active', + # issuance and renewal dates should be the same + 'dateOfIssuance': '2024-11-08T23:59:59+00:00', + 'dateOfRenewal': '2024-11-08T23:59:59+00:00', + 'dateOfExpiration': '2025-10-31', + 'dateOfUpdate': '2024-11-08T23:59:59+00:00', + 'compactTransactionId': 'test_transaction_id', + 'compactTransactionIdGSIPK': 'COMPACT#aslp#TX#test_transaction_id#', + 'attestations': self.sample_privilege_attestations, + # Should remain the same, since we're renewing the same privilege + 'privilegeId': 'AUD-NE-124', + }, + new_aud_ne_privilege.serialize_to_database_record(), ) + # assert there are no update records for this privilege using test_data_generator + ne_aud_update_records = self.test_data_generator.query_privilege_update_records_for_given_record_from_database( + new_aud_ne_privilege + ) + self.assertEqual([], ne_aud_update_records) # ensure that slp privilege was not updated with an update record slp_privilege = self._provider_table.query( @@ -448,6 +465,7 @@ def test_data_client_create_privilege_record_invalid_license_type(self): def test_data_client_handles_large_privilege_purchase(self): """Test that we can process privilege purchases with more than 100 transaction items.""" from 
cc_common.data_model.data_client import DataClient + from cc_common.data_model.provider_record_util import ProviderUserRecords from cc_common.data_model.schema.common import ActiveInactiveStatus from cc_common.data_model.schema.privilege import PrivilegeData @@ -501,24 +519,28 @@ def test_data_client_handles_large_privilege_purchase(self): ) # Verify that all privileges were updated - for jurisdiction in jurisdictions: - privilege_records = self._provider_table.query( - KeyConditionExpression=Key('pk').eq(f'aslp#PROVIDER#{provider_uuid}') - & Key('sk').begins_with(f'aslp#PROVIDER#privilege/{jurisdiction}/aud#'), - )['Items'] - - self.assertEqual(2, len(privilege_records)) # One privilege record and one update record + provider_user_records: ProviderUserRecords = self.config.data_client.get_provider_user_records( + compact='aslp', provider_id=provider_uuid + ) - # Find the main privilege record - privilege_record = next(r for r in privilege_records if r['type'] == 'privilege') - self.assertEqual('2025-10-31', privilege_record['dateOfExpiration']) - self.assertEqual('test_transaction_id', privilege_record['compactTransactionId']) + for jurisdiction in jurisdictions: + # Get the privilege record using ProviderUserRecords + privilege_record = provider_user_records.get_specific_privilege_record( + jurisdiction=jurisdiction, license_abbreviation='aud' + ) + self.assertIsNotNone(privilege_record, f'Privilege record not found for jurisdiction {jurisdiction}') + self.assertEqual('2025-10-31', privilege_record.dateOfExpiration.isoformat()) + self.assertEqual('test_transaction_id', privilege_record.compactTransactionId) - # Find the update record - update_record = next(r for r in privilege_records if r['type'] == 'privilegeUpdate') - self.assertEqual('renewal', update_record['updateType']) - self.assertEqual('2024-10-31', update_record['previous']['dateOfExpiration']) - self.assertEqual('2025-10-31', update_record['updatedValues']['dateOfExpiration']) + # Get the update 
record using test_data_generator + update_records = self.test_data_generator.query_privilege_update_records_for_given_record_from_database( + privilege_record + ) + self.assertEqual(1, len(update_records), f'Expected 1 update record for jurisdiction {jurisdiction}') + update_record = update_records[0] + self.assertEqual('renewal', update_record.updateType) + self.assertEqual('2024-10-31', update_record.previous['dateOfExpiration'].isoformat()) + self.assertEqual('2025-10-31', update_record.updatedValues['dateOfExpiration'].isoformat()) # Verify the provider record was updated correctly provider = self._provider_table.get_item( @@ -721,6 +743,7 @@ def test_get_ssn_by_provider_id_raises_exception_multiple_records_found(self): def test_deactivate_privilege_updates_record(self): from cc_common.data_model.data_client import DataClient + from cc_common.data_model.provider_record_util import ProviderUserRecords provider_id = self._load_provider_data() @@ -762,68 +785,80 @@ def test_deactivate_privilege_updates_record(self): ) # Verify that the privilege record was updated - new_privilege = self._provider_table.query( - KeyConditionExpression=Key('pk').eq(f'aslp#PROVIDER#{provider_id}') - & Key('sk').begins_with('aslp#PROVIDER#privilege/ne/aud#'), - )['Items'] + provider_user_records: ProviderUserRecords = self.config.data_client.get_provider_user_records( + compact='aslp', provider_id=provider_id + ) + + new_privilege = provider_user_records.get_specific_privilege_record( + jurisdiction='ne', license_abbreviation='aud' + ) + self.assertIsNotNone(new_privilege, 'Privilege record not found') + self.assertEqual( - [ - # Primary record - { - 'pk': f'aslp#PROVIDER#{provider_id}', - 'sk': 'aslp#PROVIDER#privilege/ne/aud#', - 'type': 'privilege', - 'providerId': str(provider_id), - 'compact': 'aslp', - 'licenseJurisdiction': 'oh', - 'licenseType': 'audiologist', - 'jurisdiction': 'ne', - 'administratorSetStatus': 'inactive', + { + 'pk': f'aslp#PROVIDER#{provider_id}', + 'sk': 
'aslp#PROVIDER#privilege/ne/aud#', + 'type': 'privilege', + 'providerId': str(provider_id), + 'compact': 'aslp', + 'licenseJurisdiction': 'oh', + 'licenseType': 'audiologist', + 'jurisdiction': 'ne', + 'administratorSetStatus': 'inactive', + 'dateOfIssuance': '2023-11-08T23:59:59+00:00', + 'dateOfRenewal': '2023-11-08T23:59:59+00:00', + 'dateOfExpiration': '2024-10-31', + 'dateOfUpdate': '2024-11-08T23:59:59+00:00', + 'compactTransactionId': '1234567890', + 'compactTransactionIdGSIPK': 'COMPACT#aslp#TX#1234567890#', + 'attestations': self.sample_privilege_attestations, + 'privilegeId': 'AUD-NE-1', + }, + new_privilege.serialize_to_database_record(), + ) + + # Get the update record using test_data_generator + update_records = self.test_data_generator.query_privilege_update_records_for_given_record_from_database( + new_privilege + ) + self.assertEqual(1, len(update_records), 'Expected 1 update record') + update_record = update_records[0] + + self.assertEqual( + { + 'pk': f'aslp#PROVIDER#{provider_id}', + 'sk': 'aslp#UPDATE#1#privilege/ne/aud/2024-11-08T23:59:59+00:00/aac682a76e1182a641a1b40dd606ae51', + 'type': 'privilegeUpdate', + 'updateType': 'deactivation', + 'providerId': str(provider_id), + 'compact': 'aslp', + 'compactTransactionIdGSIPK': 'COMPACT#aslp#TX#1234567890#', + 'jurisdiction': 'ne', + 'licenseType': 'audiologist', + 'dateOfUpdate': '2024-11-08T23:59:59+00:00', + 'createDate': '2024-11-08T23:59:59+00:00', + 'effectiveDate': '2024-11-08T23:59:59+00:00', + 'deactivationDetails': { + 'note': 'test deactivation note', + 'deactivatedByStaffUserId': 'a4182428-d061-701c-82e5-a3d1d547d797', + 'deactivatedByStaffUserName': 'John Doe', + }, + 'previous': { 'dateOfIssuance': '2023-11-08T23:59:59+00:00', 'dateOfRenewal': '2023-11-08T23:59:59+00:00', 'dateOfExpiration': '2024-10-31', - 'dateOfUpdate': '2024-11-08T23:59:59+00:00', + 'dateOfUpdate': '2023-11-08T23:59:59+00:00', 'compactTransactionId': '1234567890', - 'compactTransactionIdGSIPK': 
'COMPACT#aslp#TX#1234567890#', 'attestations': self.sample_privilege_attestations, + 'administratorSetStatus': 'active', + 'licenseJurisdiction': 'oh', 'privilegeId': 'AUD-NE-1', }, - # A new history record - { - 'pk': f'aslp#PROVIDER#{provider_id}', - 'sk': 'aslp#PROVIDER#privilege/ne/aud#UPDATE#1731110399/aac682a76e1182a641a1b40dd606ae51', - 'type': 'privilegeUpdate', - 'updateType': 'deactivation', - 'providerId': str(provider_id), - 'compact': 'aslp', - 'compactTransactionIdGSIPK': 'COMPACT#aslp#TX#1234567890#', - 'jurisdiction': 'ne', - 'licenseType': 'audiologist', - 'dateOfUpdate': '2024-11-08T23:59:59+00:00', - 'createDate': '2024-11-08T23:59:59+00:00', - 'effectiveDate': '2024-11-08T23:59:59+00:00', - 'deactivationDetails': { - 'note': 'test deactivation note', - 'deactivatedByStaffUserId': 'a4182428-d061-701c-82e5-a3d1d547d797', - 'deactivatedByStaffUserName': 'John Doe', - }, - 'previous': { - 'dateOfIssuance': '2023-11-08T23:59:59+00:00', - 'dateOfRenewal': '2023-11-08T23:59:59+00:00', - 'dateOfExpiration': '2024-10-31', - 'dateOfUpdate': '2023-11-08T23:59:59+00:00', - 'compactTransactionId': '1234567890', - 'attestations': self.sample_privilege_attestations, - 'administratorSetStatus': 'active', - 'licenseJurisdiction': 'oh', - 'privilegeId': 'AUD-NE-1', - }, - 'updatedValues': { - 'administratorSetStatus': 'inactive', - }, + 'updatedValues': { + 'administratorSetStatus': 'inactive', }, - ], - new_privilege, + }, + update_record.serialize_to_database_record(), ) # The deactivation should not remove 'ne' from privilegeJurisdictions, as that set is intended to include @@ -855,6 +890,7 @@ def test_deactivate_privilege_raises_if_privilege_not_found(self): def test_deactivate_privilege_on_inactive_privilege_raises_exception(self): from cc_common.data_model.data_client import DataClient + from cc_common.data_model.provider_record_util import ProviderUserRecords provider_id = self._load_provider_data() @@ -889,11 +925,14 @@ def 
test_deactivate_privilege_on_inactive_privilege_raises_exception(self): # We'll create it as if it were already deactivated original_history = { 'pk': f'aslp#PROVIDER#{provider_id}', - 'sk': 'aslp#PROVIDER#privilege/ne/aud#UPDATE#1731110399/483bebc6cb3fd6b517f8ce9ad706c518', + 'sk': 'aslp#UPDATE#1#privilege/ne/aud/2024-11-08T23:59:59+00:00/4ebb3dc8f1ffcc30fe7aad5ec49d0ca6', 'type': 'privilegeUpdate', 'updateType': 'renewal', 'providerId': str(provider_id), 'compact': 'aslp', + 'licenseType': 'audiologist', + 'createDate': '2024-11-08T23:59:59+00:00', + 'effectiveDate': '2024-11-08T23:59:59+00:00', 'compactTransactionIdGSIPK': 'COMPACT#aslp#TX#1234567890#', 'jurisdiction': 'ne', 'dateOfUpdate': '2024-11-08T23:59:59+00:00', @@ -905,7 +944,6 @@ def test_deactivate_privilege_on_inactive_privilege_raises_exception(self): 'compactTransactionId': '1234567890', 'attestations': self.sample_privilege_attestations, 'licenseJurisdiction': 'oh', - 'licenseType': 'audiologist', 'privilegeId': 'AUD-NE-1', }, 'updatedValues': { @@ -932,17 +970,30 @@ def test_deactivate_privilege_on_inactive_privilege_raises_exception(self): self.assertEqual('Privilege already deactivated', context.exception.message) # Verify that the privilege record was unchanged - new_privilege = self._provider_table.query( - KeyConditionExpression=Key('pk').eq(f'aslp#PROVIDER#{provider_id}') - & Key('sk').begins_with('aslp#PROVIDER#privilege/ne/aud#'), - )['Items'] - self.assertEqual([original_privilege, original_history], new_privilege) + provider_user_records: ProviderUserRecords = self.config.data_client.get_provider_user_records( + compact='aslp', provider_id=provider_id + ) + + new_privilege = provider_user_records.get_specific_privilege_record( + jurisdiction='ne', license_abbreviation='aud' + ) + self.assertIsNotNone(new_privilege, 'Privilege record not found') + serialized_record = new_privilege.serialize_to_database_record() + # the serialize_to_database_record() call automatically generates a new 
dateOfUpdate stamp, + # setting it back to the original timestamp for comparison + serialized_record['dateOfUpdate'] = original_privilege['dateOfUpdate'] + self.assertEqual(original_privilege, serialized_record) + + # Verify the update record is unchanged using test_data_generator + update_records = self.test_data_generator.query_privilege_update_records_for_given_record_from_database( + new_privilege + ) + self.assertEqual(1, len(update_records), 'Expected 1 update record') + self.assertEqual(original_history, update_records[0].serialize_to_database_record()) # 'ne' should still be removed from privilegeJurisdictions - provider = self._provider_table.get_item( - Key={'pk': f'aslp#PROVIDER#{provider_id}', 'sk': 'aslp#PROVIDER'}, - )['Item'] - self.assertEqual(set(), provider.get('privilegeJurisdictions', set())) + provider = provider_user_records.get_provider_record() + self.assertEqual(set(), provider.privilegeJurisdictions) def test_get_provider_user_records_correctly_handles_pagination(self): """Test that get_provider_user_records correctly handles pagination by returning all records. 
@@ -1057,10 +1108,13 @@ def test_create_privilege_investigation_success(self): client.create_investigation(investigation) # Verify investigation record was created - investigation_records = self.config.provider_table.query( - KeyConditionExpression=Key('pk').eq(f'aslp#PROVIDER#{provider_id}') - & Key('sk').begins_with('aslp#PROVIDER#privilege/ne/slp#INVESTIGATION#') - )['Items'] + provider_user_records = self.config.data_client.get_provider_user_records( + compact='aslp', provider_id=provider_id, include_update_tier=UpdateTierEnum.TIER_THREE + ) + investigation_records = provider_user_records.get_investigation_records_for_privilege( + privilege_jurisdiction='ne', + privilege_license_type_abbreviation='slp', + ) self.assertEqual(1, len(investigation_records)) investigation_record = investigation_records[0] @@ -1078,27 +1132,23 @@ def test_create_privilege_investigation_success(self): 'investigationId': str(investigation.investigationId), 'submittingUser': str(investigation.submittingUser), 'creationDate': investigation.creationDate.isoformat(), + 'dateOfUpdate': ANY, } # Pop dynamic fields that we don't want to assert on - investigation_record.pop('dateOfUpdate') - - self.assertEqual(expected_investigation, investigation_record) + self.assertEqual(expected_investigation, investigation_record.serialize_to_database_record()) # Verify privilege record was updated with investigation status - privilege_records = self.config.provider_table.query( - KeyConditionExpression=Key('pk').eq(f'aslp#PROVIDER#{provider_id}') - & Key('sk').eq('aslp#PROVIDER#privilege/ne/slp#') - )['Items'] + privilege_records = provider_user_records.get_privilege_records() self.assertEqual(1, len(privilege_records)) privilege_record = privilege_records[0] - self.assertEqual('underInvestigation', privilege_record['investigationStatus']) + self.assertEqual('underInvestigation', privilege_record.investigationStatus) # Verify update record was created - update_records = self.config.provider_table.query( 
- KeyConditionExpression=Key('pk').eq(f'aslp#PROVIDER#{provider_id}') - & Key('sk').begins_with('aslp#PROVIDER#privilege/ne/slp#UPDATE#') - )['Items'] + update_records = provider_user_records.get_update_records_for_privilege( + jurisdiction=privilege_record.jurisdiction, + license_type=privilege_record.licenseType, + ) self.assertEqual(1, len(update_records)) update_record = update_records[0] @@ -1106,6 +1156,7 @@ def test_create_privilege_investigation_success(self): # Verify the complete update record structure expected_update = { 'pk': f'aslp#PROVIDER#{provider_id}', + 'sk': ANY, 'compactTransactionIdGSIPK': 'COMPACT#aslp#TX#1234567890#', 'type': 'privilegeUpdate', 'updateType': 'investigation', @@ -1132,12 +1183,10 @@ def test_create_privilege_investigation_success(self): 'investigationDetails': { 'investigationId': str(investigation.investigationId), }, + 'dateOfUpdate': ANY, } - # Pop dynamic fields that we don't want to assert on - update_record.pop('dateOfUpdate') - update_record.pop('sk') - self.assertEqual(expected_update, update_record) + self.assertEqual(expected_update, update_record.serialize_to_database_record()) def test_create_license_investigation_success(self): """Test successful creation of license investigation""" @@ -1168,10 +1217,13 @@ def test_create_license_investigation_success(self): client.create_investigation(investigation) # Verify investigation record was created - investigation_records = self.config.provider_table.query( - KeyConditionExpression=Key('pk').eq(f'aslp#PROVIDER#{provider_id}') - & Key('sk').begins_with('aslp#PROVIDER#license/oh/slp#INVESTIGATION#') - )['Items'] + provider_user_records = self.config.data_client.get_provider_user_records( + compact='aslp', provider_id=provider_id, include_update_tier=UpdateTierEnum.TIER_THREE + ) + investigation_records = provider_user_records.get_investigation_records_for_license( + license_jurisdiction='oh', + license_type_abbreviation='slp', + ) self.assertEqual(1, 
len(investigation_records)) investigation_record = investigation_records[0] @@ -1189,27 +1241,23 @@ def test_create_license_investigation_success(self): 'investigationId': str(investigation.investigationId), 'submittingUser': str(investigation.submittingUser), 'creationDate': investigation.creationDate.isoformat(), + 'dateOfUpdate': ANY, } - # Pop dynamic fields that we don't want to assert on - investigation_record.pop('dateOfUpdate') - self.assertEqual(expected_investigation, investigation_record) + self.assertEqual(expected_investigation, investigation_record.serialize_to_database_record()) # Verify license record was updated with investigation status - license_records = self.config.provider_table.query( - KeyConditionExpression=Key('pk').eq(f'aslp#PROVIDER#{provider_id}') - & Key('sk').eq('aslp#PROVIDER#license/oh/slp#') - )['Items'] + license_records = provider_user_records.get_license_records() self.assertEqual(1, len(license_records)) license_record = license_records[0] - self.assertEqual('underInvestigation', license_record['investigationStatus']) + self.assertEqual('underInvestigation', license_record.investigationStatus) # Verify update record was created - update_records = self.config.provider_table.query( - KeyConditionExpression=Key('pk').eq(f'aslp#PROVIDER#{provider_id}') - & Key('sk').begins_with('aslp#PROVIDER#license/oh/slp#UPDATE#') - )['Items'] + update_records = provider_user_records.get_update_records_for_license( + jurisdiction=license_record.jurisdiction, + license_type=license_record.licenseType, + ) self.assertEqual(1, len(update_records)) update_record = update_records[0] @@ -1217,6 +1265,7 @@ def test_create_license_investigation_success(self): # Verify the complete update record structure expected_update = { 'pk': f'aslp#PROVIDER#{provider_id}', + 'sk': ANY, 'type': 'licenseUpdate', 'updateType': 'investigation', 'compact': 'aslp', @@ -1254,12 +1303,10 @@ def test_create_license_investigation_success(self): 'investigationDetails': { 
'investigationId': str(investigation.investigationId), }, + 'dateOfUpdate': ANY, } - # Pop dynamic fields that we don't want to assert on - update_record.pop('dateOfUpdate') - update_record.pop('sk') - self.assertEqual(expected_update, update_record) + self.assertEqual(expected_update, update_record.serialize_to_database_record()) def test_create_privilege_investigation_privilege_not_found(self): """Test creation of privilege investigation when privilege doesn't exist""" @@ -1367,10 +1414,12 @@ def test_close_privilege_investigation_success(self): ) # Verify investigation record was updated with close information - investigation_records = self.config.provider_table.query( - KeyConditionExpression=Key('pk').eq(f'aslp#PROVIDER#{provider_id}') - & Key('sk').begins_with('aslp#PROVIDER#privilege/ne/slp#INVESTIGATION#') - )['Items'] + provider_user_records = self.config.data_client.get_provider_user_records( + compact='aslp', provider_id=provider_id, include_update_tier=UpdateTierEnum.TIER_THREE + ) + investigation_records = provider_user_records.get_investigation_records_for_privilege( + privilege_jurisdiction='ne', privilege_license_type_abbreviation='slp', include_closed=True + ) self.assertEqual(1, len(investigation_records)) investigation_record = investigation_records[0] @@ -1390,27 +1439,21 @@ def test_close_privilege_investigation_success(self): 'creationDate': investigation.creationDate.isoformat(), 'closeDate': investigation.creationDate.isoformat(), 'closingUser': closing_user, + 'dateOfUpdate': ANY, } - # Pop dynamic fields that we don't want to assert on - investigation_record.pop('dateOfUpdate') - - self.assertEqual(expected_investigation_close, investigation_record) + self.assertEqual(expected_investigation_close, investigation_record.serialize_to_database_record()) # Verify privilege record no longer has investigation status - privilege_records = self.config.provider_table.query( - KeyConditionExpression=Key('pk').eq(f'aslp#PROVIDER#{provider_id}') - & 
Key('sk').eq('aslp#PROVIDER#privilege/ne/slp#') - )['Items'] - + privilege_records = provider_user_records.get_privilege_records() self.assertEqual(1, len(privilege_records)) privilege_record = privilege_records[0] - self.assertNotIn('investigationStatus', privilege_record) + self.assertIsNone(privilege_record.investigationStatus) # Verify update record was created for closure - update_records = self.config.provider_table.query( - KeyConditionExpression=Key('pk').eq(f'aslp#PROVIDER#{provider_id}') - & Key('sk').begins_with('aslp#PROVIDER#privilege/ne/slp#UPDATE#') - )['Items'] + update_records = provider_user_records.get_update_records_for_privilege( + jurisdiction='ne', + license_type=privilege_record.licenseType, + ) # Should have 2 update records: one for creation, one for closure self.assertEqual(2, len(update_records)) @@ -1418,7 +1461,7 @@ def test_close_privilege_investigation_success(self): # Find the closure update record closure_update = None for update_record in update_records: - if update_record.get('updateType') == 'closingInvestigation': + if update_record.updateType == 'closingInvestigation': closure_update = update_record break @@ -1427,6 +1470,7 @@ def test_close_privilege_investigation_success(self): # Verify the complete closure update record structure expected_closure_update = { 'pk': f'aslp#PROVIDER#{provider_id}', + 'sk': ANY, 'type': 'privilegeUpdate', 'updateType': 'closingInvestigation', 'compact': 'aslp', @@ -1449,15 +1493,11 @@ def test_close_privilege_investigation_success(self): }, 'updatedValues': {}, 'removedValues': ['investigationStatus'], + 'dateOfUpdate': ANY, + 'compactTransactionIdGSIPK': ANY, } - # Pop dynamic fields that we don't want to assert on - closure_update.pop('dateOfUpdate') - closure_update.pop('sk') - # Only pop compactTransactionIdGSIPK if it exists - if 'compactTransactionIdGSIPK' in closure_update: - closure_update.pop('compactTransactionIdGSIPK') - self.assertEqual(expected_closure_update, closure_update) + 
self.assertEqual(expected_closure_update, closure_update.serialize_to_database_record()) def test_close_license_investigation_success(self): """Test successful closing of license investigation""" @@ -1501,11 +1541,15 @@ def test_close_license_investigation_success(self): investigation_against=InvestigationAgainstEnum.LICENSE, ) + # grab all provider records to make assertions + provider_user_records = self.config.data_client.get_provider_user_records( + compact='aslp', provider_id=provider_id, include_update_tier=UpdateTierEnum.TIER_THREE + ) + # Verify investigation record was updated with close information - investigation_records = self.config.provider_table.query( - KeyConditionExpression=Key('pk').eq(f'aslp#PROVIDER#{provider_id}') - & Key('sk').begins_with('aslp#PROVIDER#license/oh/slp#INVESTIGATION#') - )['Items'] + investigation_records = provider_user_records.get_investigation_records_for_license( + license_jurisdiction='oh', license_type_abbreviation='slp', include_closed=True + ) self.assertEqual(1, len(investigation_records)) investigation_record = investigation_records[0] @@ -1525,27 +1569,22 @@ def test_close_license_investigation_success(self): 'creationDate': investigation.creationDate.isoformat(), 'closeDate': close_date.isoformat(), 'closingUser': closing_user, + 'dateOfUpdate': ANY, } - # Pop dynamic fields that we don't want to assert on - investigation_record.pop('dateOfUpdate') - self.assertEqual(expected_investigation_close, investigation_record) + self.assertEqual(expected_investigation_close, investigation_record.serialize_to_database_record()) # Verify license record no longer has investigation status - license_records = self.config.provider_table.query( - KeyConditionExpression=Key('pk').eq(f'aslp#PROVIDER#{provider_id}') - & Key('sk').eq('aslp#PROVIDER#license/oh/slp#') - )['Items'] + license_records = provider_user_records.get_license_records() self.assertEqual(1, len(license_records)) license_record = license_records[0] - 
self.assertNotIn('investigationStatus', license_record) + self.assertNotIn('investigationStatus', license_record.to_dict()) # Verify update record was created for closure - update_records = self.config.provider_table.query( - KeyConditionExpression=Key('pk').eq(f'aslp#PROVIDER#{provider_id}') - & Key('sk').begins_with('aslp#PROVIDER#license/oh/slp#UPDATE#') - )['Items'] + update_records = provider_user_records.get_update_records_for_license( + jurisdiction=license_record.jurisdiction, license_type=license_record.licenseType + ) # Should have 2 update records: one for creation, one for closure self.assertEqual(2, len(update_records)) @@ -1553,7 +1592,7 @@ def test_close_license_investigation_success(self): # Find the closure update record closure_update = None for update_record in update_records: - if update_record.get('updateType') == 'closingInvestigation': + if update_record.updateType == 'closingInvestigation': closure_update = update_record break @@ -1562,6 +1601,7 @@ def test_close_license_investigation_success(self): # Verify the complete closure update record structure expected_closure_update = { 'pk': f'aslp#PROVIDER#{provider_id}', + 'sk': ANY, 'type': 'licenseUpdate', 'updateType': 'closingInvestigation', 'compact': 'aslp', @@ -1596,12 +1636,10 @@ def test_close_license_investigation_success(self): }, 'updatedValues': {}, 'removedValues': ['investigationStatus'], + 'dateOfUpdate': ANY, } - # Pop dynamic fields that we don't want to assert on - closure_update.pop('dateOfUpdate') - closure_update.pop('sk') - self.assertEqual(expected_closure_update, closure_update) + self.assertEqual(expected_closure_update, closure_update.serialize_to_database_record()) def test_close_privilege_investigation_not_found(self): """Test closing privilege investigation when investigation doesn't exist""" @@ -1916,3 +1954,158 @@ def test_close_license_investigation_with_encumbrance(self): investigation_record.pop('dateOfUpdate') self.assertEqual(expected_investigation_close, 
investigation_record) + + # TODO - remove this test once migration from old update SK pattern is complete # noqa: FIX002 + def test_get_provider_user_records_returns_old_sk_pattern_update_records_with_tier_one(self): + """Test that get_provider_user_records with TIER_ONE returns privilege update records with old SK pattern.""" + from cc_common.data_model.data_client import DataClient + from cc_common.data_model.provider_record_util import ProviderUserRecords + from cc_common.data_model.update_tier_enum import UpdateTierEnum + + provider_uuid = str(uuid4()) + compact = 'aslp' + jurisdiction = 'ky' + license_type_abbr = 'aud' + + # Create provider and privilege records + self.test_data_generator.put_default_provider_record_in_provider_table( + value_overrides={ + 'providerId': provider_uuid, + 'compact': compact, + } + ) + + privilege = self.test_data_generator.put_default_privilege_record_in_provider_table( + value_overrides={ + 'providerId': provider_uuid, + 'compact': compact, + 'jurisdiction': jurisdiction, + 'licenseType': 'audiologist', + } + ) + + # Manually create a privilege update record with the old SK pattern + old_sk_update_record = { + 'pk': f'{compact}#PROVIDER#{provider_uuid}', + 'sk': f'{compact}#PROVIDER#privilege/{jurisdiction}/{license_type_abbr}' + f'#UPDATE/1731110399/939a3c350708e34875f0a652bf7d7454', + 'type': 'privilegeUpdate', + 'updateType': 'renewal', + 'providerId': provider_uuid, + 'compact': compact, + 'jurisdiction': jurisdiction, + 'licenseType': 'audiologist', + 'createDate': '2024-11-08T23:59:59+00:00', + 'effectiveDate': '2024-11-08T23:59:59+00:00', + 'dateOfUpdate': '2024-11-08T23:59:59+00:00', + 'compactTransactionIdGSIPK': 'COMPACT#aslp#TX#1234567890#', + 'previous': { + 'attestations': [{'attestationId': 'jurisprudence-confirmation', 'version': '1'}], + 'dateOfIssuance': '2016-05-05T12:59:59+00:00', + 'dateOfRenewal': '2016-05-05T12:59:59+00:00', + 'dateOfExpiration': '2020-06-06', + 'dateOfUpdate': '2016-05-05T12:59:59+00:00', 
+ 'compactTransactionId': '0123456789', + 'privilegeId': 'SLP-NE-1', + 'administratorSetStatus': 'active', + 'licenseJurisdiction': 'oh', + }, + 'updatedValues': { + 'dateOfRenewal': '2024-11-08T23:59:59+00:00', + 'dateOfExpiration': '2025-10-31', + 'compactTransactionId': 'test_transaction_id', + }, + } + self._provider_table.put_item(Item=old_sk_update_record) + + # Call get_provider_user_records with TIER_ONE + client = DataClient(self.config) + provider_user_records: ProviderUserRecords = client.get_provider_user_records( + compact=compact, + provider_id=provider_uuid, + include_update_tier=UpdateTierEnum.TIER_ONE, + ) + + # Verify the old SK pattern update record is returned + update_records = provider_user_records.get_update_records_for_privilege( + jurisdiction=jurisdiction, license_type=privilege.licenseType + ) + self.assertEqual(1, len(update_records)) + self.assertEqual('renewal', update_records[0].updateType) + + # TODO - remove this test once migration from old update SK pattern is complete # noqa: FIX002 + def test_get_privilege_data_returns_old_sk_pattern_update_records_with_detail(self): + """Test that get_privilege_data with detail=True returns privilege update records with old SK pattern.""" + from cc_common.data_model.data_client import DataClient + + provider_uuid = str(uuid4()) + compact = 'aslp' + jurisdiction = 'ne' + license_type_abbr = 'aud' + + # Create provider and privilege records + self.test_data_generator.put_default_provider_record_in_provider_table( + value_overrides={ + 'providerId': provider_uuid, + 'compact': compact, + } + ) + + self.test_data_generator.put_default_privilege_record_in_provider_table( + value_overrides={ + 'providerId': provider_uuid, + 'compact': compact, + 'jurisdiction': jurisdiction, + 'licenseType': 'audiologist', + } + ) + + # Manually create a privilege update record with the old SK pattern + old_sk_update_record = { + 'pk': f'{compact}#PROVIDER#{provider_uuid}', + 'sk': 
f'{compact}#PROVIDER#privilege/{jurisdiction}/{license_type_abbr}' + f'#UPDATE/1731110399/939a3c350708e34875f0a652bf7d7454', + 'type': 'privilegeUpdate', + 'updateType': 'renewal', + 'providerId': provider_uuid, + 'compact': compact, + 'jurisdiction': jurisdiction, + 'licenseType': 'audiologist', + 'createDate': '2024-11-08T23:59:59+00:00', + 'effectiveDate': '2024-11-08T23:59:59+00:00', + 'dateOfUpdate': '2024-11-08T23:59:59+00:00', + 'compactTransactionIdGSIPK': 'COMPACT#aslp#TX#1234567890#', + 'previous': { + 'attestations': [{'attestationId': 'jurisprudence-confirmation', 'version': '1'}], + 'dateOfIssuance': '2016-05-05T12:59:59+00:00', + 'dateOfRenewal': '2016-05-05T12:59:59+00:00', + 'dateOfExpiration': '2020-06-06', + 'dateOfUpdate': '2016-05-05T12:59:59+00:00', + 'compactTransactionId': '0123456789', + 'privilegeId': 'SLP-NE-1', + 'administratorSetStatus': 'active', + 'licenseJurisdiction': 'oh', + }, + 'updatedValues': { + 'dateOfRenewal': '2024-11-08T23:59:59+00:00', + 'dateOfExpiration': '2025-10-31', + 'compactTransactionId': 'test_transaction_id', + }, + } + self._provider_table.put_item(Item=old_sk_update_record) + + # Call get_privilege_data with detail=True + client = DataClient(self.config) + result = client.get_privilege_data( + compact=compact, + provider_id=provider_uuid, + jurisdiction=jurisdiction, + license_type_abbr=license_type_abbr, + detail=True, + ) + + # Verify the result contains the privilege record and the old SK pattern update record + self.assertEqual(2, len(result)) + self.assertEqual('privilege', result[0]['type']) + self.assertEqual('privilegeUpdate', result[1]['type']) + self.assertEqual('renewal', result[1]['updateType']) diff --git a/backend/compact-connect/lambdas/python/common/tests/function/test_data_model/test_transaction_client.py b/backend/compact-connect/lambdas/python/common/tests/function/test_data_model/test_transaction_client.py index a525544dd..e72dd93a6 100644 --- 
a/backend/compact-connect/lambdas/python/common/tests/function/test_data_model/test_transaction_client.py +++ b/backend/compact-connect/lambdas/python/common/tests/function/test_data_model/test_transaction_client.py @@ -261,9 +261,7 @@ def test_reconcile_unsettled_transactions_deletes_matching_record_and_returns_ol ) # Two unmatched transactions remain @patch('cc_common.data_model.transaction_client.logger') - def test_reconcile_unsettled_transactions_logs_error_when_settled_transactions_not_matched( - self, mock_logger - ): + def test_reconcile_unsettled_transactions_logs_error_when_settled_transactions_not_matched(self, mock_logger): """ Test that reconcile_unsettled_transactions logs an error when settled transactions don't match unsettled ones. """ diff --git a/backend/compact-connect/lambdas/python/common/tests/resources/dynamo/license-update.json b/backend/compact-connect/lambdas/python/common/tests/resources/dynamo/license-update.json index 6c43d8072..68b682b23 100644 --- a/backend/compact-connect/lambdas/python/common/tests/resources/dynamo/license-update.json +++ b/backend/compact-connect/lambdas/python/common/tests/resources/dynamo/license-update.json @@ -1,6 +1,8 @@ { "pk": "aslp#PROVIDER#89a6377e-c3a5-40e5-bca5-317ec854c570", - "sk": "aslp#PROVIDER#license/oh/slp#UPDATE#1586264399/34702de3dc08e64922605a6b18f3838b", + "sk": "aslp#UPDATE#3#license/oh/slp/2024-11-08T23:59:59+00:00/34702de3dc08e64922605a6b18f3838b", + "licenseUploadDateGSIPK": "C#aslp#J#oh#D#2024-11", + "licenseUploadDateGSISK": "TIME#1731110399#LT#slp#PID#89a6377e-c3a5-40e5-bca5-317ec854c570", "type": "licenseUpdate", "updateType": "renewal", "providerId": "89a6377e-c3a5-40e5-bca5-317ec854c570", diff --git a/backend/compact-connect/lambdas/python/common/tests/resources/dynamo/privilege-update.json b/backend/compact-connect/lambdas/python/common/tests/resources/dynamo/privilege-update.json index 8f2469682..95ad3ebd8 100644 --- 
a/backend/compact-connect/lambdas/python/common/tests/resources/dynamo/privilege-update.json +++ b/backend/compact-connect/lambdas/python/common/tests/resources/dynamo/privilege-update.json @@ -1,6 +1,6 @@ { "pk": "aslp#PROVIDER#89a6377e-c3a5-40e5-bca5-317ec854c570", - "sk": "aslp#PROVIDER#privilege/ne/slp#UPDATE#1731110399/939a3c350708e34875f0a652bf7d7454", + "sk": "aslp#UPDATE#1#privilege/ne/slp/2020-05-05T12:59:59+00:00/939a3c350708e34875f0a652bf7d7454", "type": "privilegeUpdate", "updateType": "renewal", "providerId": "89a6377e-c3a5-40e5-bca5-317ec854c570", diff --git a/backend/compact-connect/lambdas/python/common/tests/unit/test_data_model/test_schema/test_license.py b/backend/compact-connect/lambdas/python/common/tests/unit/test_data_model/test_schema/test_license.py index d251530bc..73f239035 100644 --- a/backend/compact-connect/lambdas/python/common/tests/unit/test_data_model/test_schema/test_license.py +++ b/backend/compact-connect/lambdas/python/common/tests/unit/test_data_model/test_schema/test_license.py @@ -207,6 +207,8 @@ def test_hash_is_deterministic(self): 'compact': 'aslp', 'jurisdiction': 'ky', 'licenseType': 'speech-language pathologist', + 'updateType': loaded_record['updateType'], + 'createDate': loaded_record['createDate'], # These two fields should determine the change hash: 'previous': loaded_record['previous'].copy(), 'updatedValues': loaded_record['updatedValues'].copy(), @@ -234,6 +236,8 @@ def test_hash_is_unique(self): 'compact': 'aslp', 'jurisdiction': 'ky', 'licenseType': 'speech-language pathologist', + 'updateType': loaded_record['updateType'], + 'createDate': loaded_record['createDate'], # These two fields should determine the change hash: 'previous': loaded_record['previous'].copy(), 'updatedValues': loaded_record['updatedValues'].copy(), diff --git a/backend/compact-connect/lambdas/python/data-events/tests/function/test_encumbrance_events.py 
b/backend/compact-connect/lambdas/python/data-events/tests/function/test_encumbrance_events.py index be9b4fc75..f7a5ffe26 100644 --- a/backend/compact-connect/lambdas/python/data-events/tests/function/test_encumbrance_events.py +++ b/backend/compact-connect/lambdas/python/data-events/tests/function/test_encumbrance_events.py @@ -4,7 +4,6 @@ from unittest.mock import ANY, MagicMock, patch from uuid import UUID -from boto3.dynamodb.conditions import Key from common_test.test_constants import ( DEFAULT_ADVERSE_ACTION_ID, DEFAULT_CLINICAL_PRIVILEGE_ACTION_CATEGORY, @@ -242,21 +241,18 @@ def test_license_encumbrance_listener_handles_all_privileges_already_encumbered( privileges = provider_records.get_privilege_records() self.assertEqual(1, len(privileges)) - serialized_privilege = privilege.serialize_to_database_record() self.assertEqual(PrivilegeEncumberedStatusEnum.ENCUMBERED, privileges[0].encumberedStatus) - privilege_update_records = self._provider_table.query( - Select='ALL_ATTRIBUTES', - KeyConditionExpression=Key('pk').eq(serialized_privilege['pk']) - & Key('sk').begins_with(f'{serialized_privilege["sk"]}UPDATE'), + # Get update records using test_data_generator + update_records = self.test_data_generator.query_privilege_update_records_for_given_record_from_database( + privilege ) - - self.assertEqual(1, len(privilege_update_records['Items'])) - update_record = privilege_update_records['Items'][0] - update_encumbrance_details = update_record['encumbranceDetails'] + self.assertEqual(1, len(update_records)) + update_record = update_records[0] + update_encumbrance_details = update_record.encumbranceDetails self.assertEqual( { - 'adverseActionId': DEFAULT_ADVERSE_ACTION_ID, + 'adverseActionId': UUID(DEFAULT_ADVERSE_ACTION_ID), 'licenseJurisdiction': 'oh', 'clinicalPrivilegeActionCategories': ['Unsafe Practice or Substandard Care'], }, @@ -312,21 +308,18 @@ def test_license_encumbrance_listener_handles_all_privileges_already_encumbered_ privileges = 
provider_records.get_privilege_records() self.assertEqual(1, len(privileges)) - serialized_privilege = privilege.serialize_to_database_record() self.assertEqual(PrivilegeEncumberedStatusEnum.ENCUMBERED, privileges[0].encumberedStatus) - privilege_update_records = self._provider_table.query( - Select='ALL_ATTRIBUTES', - KeyConditionExpression=Key('pk').eq(serialized_privilege['pk']) - & Key('sk').begins_with(f'{serialized_privilege["sk"]}UPDATE'), + # Get update records using test_data_generator + update_records = self.test_data_generator.query_privilege_update_records_for_given_record_from_database( + privilege ) - - self.assertEqual(1, len(privilege_update_records['Items'])) - update_record = privilege_update_records['Items'][0] - update_encumbrance_details = update_record['encumbranceDetails'] + self.assertEqual(1, len(update_records)) + update_record = update_records[0] + update_encumbrance_details = update_record.encumbranceDetails self.assertEqual( { - 'adverseActionId': DEFAULT_ADVERSE_ACTION_ID, + 'adverseActionId': UUID(DEFAULT_ADVERSE_ACTION_ID), 'licenseJurisdiction': 'oh', 'clinicalPrivilegeActionCategory': 'Unsafe Practice or Substandard Care', }, @@ -358,17 +351,13 @@ def test_license_encumbrance_listener_creates_privilege_update_records(self): license_encumbrance_listener(event, self.mock_context) # Verify privilege update record was created - serialized_privilege = privilege.serialize_to_database_record() - privilege_update_records = self._provider_table.query( - Select='ALL_ATTRIBUTES', - KeyConditionExpression=Key('pk').eq(serialized_privilege['pk']) - & Key('sk').begins_with(f'{serialized_privilege["sk"]}UPDATE'), + update_records = self.test_data_generator.query_privilege_update_records_for_given_record_from_database( + privilege ) - - self.assertEqual(1, len(privilege_update_records['Items'])) - update_record = privilege_update_records['Items'][0] - self.assertEqual('encumbrance', update_record['updateType']) - 
self.assertEqual({'encumberedStatus': 'licenseEncumbered'}, update_record['updatedValues']) + self.assertEqual(1, len(update_records)) + update_record = update_records[0] + self.assertEqual('encumbrance', update_record.updateType) + self.assertEqual({'encumberedStatus': 'licenseEncumbered'}, update_record.updatedValues) @patch('cc_common.event_bus_client.EventBusClient._publish_event') def test_license_encumbrance_lifted_listener_unencumbers_license_encumbered_privileges_successfully( @@ -552,16 +541,13 @@ def test_license_encumbrance_lifted_listener_creates_privilege_update_records(se license_encumbrance_lifted_listener(event, self.mock_context) # Verify privilege update record was created - privilege_update_records = self._provider_table.query( - Select='ALL_ATTRIBUTES', - KeyConditionExpression=Key('pk').eq(privilege.serialize_to_database_record()['pk']) - & Key('sk').begins_with(f'{privilege.compact}#PROVIDER#privilege/{privilege.jurisdiction}/slp#UPDATE'), + update_records = self.test_data_generator.query_privilege_update_records_for_given_record_from_database( + privilege ) - - self.assertEqual(1, len(privilege_update_records['Items'])) - update_record = privilege_update_records['Items'][0] - self.assertEqual('lifting_encumbrance', update_record['updateType']) - self.assertEqual({'encumberedStatus': 'unencumbered'}, update_record['updatedValues']) + self.assertEqual(1, len(update_records)) + update_record = update_records[0] + self.assertEqual('lifting_encumbrance', update_record.updateType) + self.assertEqual({'encumberedStatus': 'unencumbered'}, update_record.updatedValues) @patch('cc_common.event_bus_client.EventBusClient._publish_event') def test_license_encumbrance_listener_handles_multiple_matching_privileges(self, mock_publish_event): diff --git a/backend/compact-connect/lambdas/python/data-events/tests/function/test_license_deactivation_events.py b/backend/compact-connect/lambdas/python/data-events/tests/function/test_license_deactivation_events.py 
index 4f3d0106f..c65bec043 100644 --- a/backend/compact-connect/lambdas/python/data-events/tests/function/test_license_deactivation_events.py +++ b/backend/compact-connect/lambdas/python/data-events/tests/function/test_license_deactivation_events.py @@ -2,7 +2,6 @@ from datetime import datetime from unittest.mock import patch -from boto3.dynamodb.conditions import Key from common_test.test_constants import ( DEFAULT_COMPACT, DEFAULT_DATE_OF_UPDATE_TIMESTAMP, @@ -299,16 +298,14 @@ def test_license_deactivation_listener_creates_update_records_for_all_affected_p # Verify privilege update records were created for both privileges for privilege in [privilege1, privilege2]: - privilege_update_records = self._provider_table.query( - Select='ALL_ATTRIBUTES', - KeyConditionExpression=Key('pk').eq(privilege.serialize_to_database_record()['pk']) - & Key('sk').begins_with(f'{privilege.compact}#PROVIDER#privilege/{privilege.jurisdiction}/slp#UPDATE'), + privilege_update_records = ( + self.test_data_generator.query_privilege_update_records_for_given_record_from_database(privilege) ) - self.assertEqual(1, len(privilege_update_records['Items'])) - update_record = privilege_update_records['Items'][0] - self.assertEqual('licenseDeactivation', update_record['updateType']) - self.assertEqual({'licenseDeactivatedStatus': 'licenseDeactivated'}, update_record['updatedValues']) + self.assertEqual(1, len(privilege_update_records)) + update_record = privilege_update_records[0] + self.assertEqual('licenseDeactivation', update_record.updateType) + self.assertEqual({'licenseDeactivatedStatus': 'licenseDeactivated'}, update_record.updatedValues) def test_license_deactivation_listener_fails_with_missing_required_fields(self): """Test that license deactivation event handler fails when required fields are missing.""" diff --git a/backend/compact-connect/lambdas/python/disaster-recovery/handlers/rollback_license_upload.py 
b/backend/compact-connect/lambdas/python/disaster-recovery/handlers/rollback_license_upload.py new file mode 100644 index 000000000..6fc26dc4e --- /dev/null +++ b/backend/compact-connect/lambdas/python/disaster-recovery/handlers/rollback_license_upload.py @@ -0,0 +1,1097 @@ +import json +import time +from dataclasses import dataclass, field +from datetime import datetime + +from aws_lambda_powertools.utilities.typing import LambdaContext +from boto3.dynamodb.conditions import Key +from botocore.exceptions import ClientError +from cc_common.config import config, logger +from cc_common.data_model.provider_record_util import ProviderRecordUtility, ProviderUserRecords +from cc_common.data_model.schema.common import LICENSE_UPLOAD_UPDATE_CATEGORIES, UpdateCategory +from cc_common.data_model.schema.license import LicenseData +from cc_common.data_model.schema.license.record import LicenseRecordSchema +from cc_common.data_model.schema.privilege import PrivilegeData +from cc_common.data_model.schema.provider import ProviderData +from cc_common.data_model.update_tier_enum import UpdateTierEnum +from cc_common.event_batch_writer import EventBatchWriter +from cc_common.exceptions import CCInternalException, CCNotFoundException +from marshmallow import ValidationError + +# Maximum time window for rollback (1 week in seconds) +# this is set as a safety net to prevent accidental rollback over a large time period +# it can be modified if needed +MAX_ROLLBACK_WINDOW_SECONDS = 7 * 24 * 60 * 60 + +# Privilege update category for license deactivations +PRIVILEGE_LICENSE_DEACTIVATION_CATEGORY = UpdateCategory.LICENSE_DEACTIVATION + + +class ProviderRollbackFailedException(Exception): + """Custom exception that is thrown when a provider fails to roll back""" + + def __init__(self, message: str): + self.message = message + super().__init__(message) + + +# Data classes for rollback operations +@dataclass +class IneligibleUpdate: + """Represents an update that makes a provider ineligible for 
rollback.""" + + record_type: str # 'licenseUpdate', 'privilegeUpdate', or 'providerUpdate' + type_of_update: str + update_time: str + reason: str + license_type: str | None = None # License type if applicable (None for provider updates) + + +@dataclass +class ProviderSkippedDetails: + """Details for a provider that was skipped.""" + + provider_id: str + reason: str + ineligible_updates: list[IneligibleUpdate] = field(default_factory=list) + + +@dataclass +class ProviderFailedDetails: + """Details for a provider that failed to revert.""" + + provider_id: str + error: str + + +@dataclass +class RevertedLicense: + """Details of a reverted license for event publishing.""" + + jurisdiction: str + license_type: str + action: str + + +@dataclass +class RevertedPrivilege: + """Details of a reverted privilege for event publishing.""" + + jurisdiction: str + license_type: str + action: str + + +@dataclass +class ProviderRevertedSummary: + """Summary for a provider that was successfully reverted.""" + + provider_id: str + licenses_reverted: list[RevertedLicense] = field(default_factory=list) + privileges_reverted: list[RevertedPrivilege] = field(default_factory=list) + updates_deleted: list[str] = field(default_factory=list) # List of SKs for deleted update records + + +@dataclass +class RollbackResults: + """Complete results of a rollback operation.""" + + execution_name: str + skipped_provider_details: list[ProviderSkippedDetails] = field(default_factory=list) + failed_provider_details: list[ProviderFailedDetails] = field(default_factory=list) + reverted_provider_summaries: list[ProviderRevertedSummary] = field(default_factory=list) + + def to_dict(self) -> dict: + """Convert to dictionary for S3 storage.""" + return { + 'executionName': self.execution_name, + 'skippedProviderDetails': [ + { + 'providerId': detail.provider_id, + 'reason': detail.reason, + 'ineligibleUpdates': [ + { + 'recordType': update.record_type, + 'typeOfUpdate': update.type_of_update, + 'updateTime': 
update.update_time, + 'reason': update.reason, + 'licenseType': update.license_type, + } + for update in detail.ineligible_updates + ], + } + for detail in self.skipped_provider_details + ], + 'failedProviderDetails': [ + { + 'providerId': detail.provider_id, + 'error': detail.error, + } + for detail in self.failed_provider_details + ], + 'revertedProviderSummaries': [ + { + 'providerId': str(summary.provider_id), + 'licensesReverted': [ + { + 'jurisdiction': license_record.jurisdiction, + 'licenseType': license_record.license_type, + 'action': license_record.action, + } + for license_record in summary.licenses_reverted + ], + 'privilegesReverted': [ + { + 'jurisdiction': privilege.jurisdiction, + 'licenseType': privilege.license_type, + 'action': privilege.action, + } + for privilege in summary.privileges_reverted + ], + 'updatesDeleted': summary.updates_deleted, + } + for summary in self.reverted_provider_summaries + ], + } + + @classmethod + def from_dict(cls, data: dict) -> 'RollbackResults': + """Create from dictionary loaded from S3.""" + return cls( + execution_name=data['executionName'], + skipped_provider_details=[ + ProviderSkippedDetails( + provider_id=detail['providerId'], + reason=detail['reason'], + ineligible_updates=[ + IneligibleUpdate( + record_type=update['recordType'], + type_of_update=update['typeOfUpdate'], + update_time=update['updateTime'], + reason=update['reason'], + license_type=update['licenseType'], + ) + for update in detail.get('ineligibleUpdates', []) + ], + ) + for detail in data.get('skippedProviderDetails', []) + ], + failed_provider_details=[ + ProviderFailedDetails( + provider_id=detail['providerId'], + error=detail['error'], + ) + for detail in data.get('failedProviderDetails', []) + ], + reverted_provider_summaries=[ + ProviderRevertedSummary( + provider_id=summary['providerId'], + licenses_reverted=[ + RevertedLicense( + jurisdiction=reverted_license['jurisdiction'], + license_type=reverted_license['licenseType'], + 
action=reverted_license['action'], + ) + for reverted_license in summary.get('licensesReverted', []) + ], + privileges_reverted=[ + RevertedPrivilege( + jurisdiction=reverted_privilege['jurisdiction'], + license_type=reverted_privilege['licenseType'], + action=reverted_privilege['action'], + ) + for reverted_privilege in summary.get('privilegesReverted', []) + ], + updates_deleted=summary.get('updatesDeleted', []), + ) + for summary in data.get('revertedProviderSummaries', []) + ], + ) + + +def rollback_license_upload(event: dict, context: LambdaContext): # noqa: ARG001 unused-argument + """ + Rollback invalid license uploads for a compact/jurisdiction/time window. + + This function queries the licenseUploadDateGSI to find all affected records, validates + rollback eligibility, reverts records to their pre-upload state, and publishes events. + Results are written to S3 to avoid state management in the step function. + + Input event structure: + { + 'compact': 'aslp', + 'jurisdiction': 'oh', + 'startDateTime': '2024-01-01T00:00:00Z', + 'endDateTime': '2024-01-01T23:59:59Z', + 'rollbackReason': 'Invalid data uploaded', + 'executionName': 'unique-execution-id', + 'providersProcessed': 0, + 'continueFromProviderId': None + } + + Returns: + { + 'rollbackStatus': 'IN_PROGRESS' | 'COMPLETE', + 'providersProcessed': int, + 'providersReverted': int, + 'providersSkipped': int, + 'providersFailed': int, + 'continueFromProviderId': str | None, + } + """ + execution_start_time = time.time() + max_execution_time = 12 * 60 # 12 minutes in seconds + + # Extract and validate input parameters + compact = event['compact'] + jurisdiction = event['jurisdiction'] + start_datetime_str = event['startDateTime'] + end_datetime_str = event['endDateTime'] + rollback_reason = event['rollbackReason'] + execution_name = event['executionName'] + providers_processed = event.get('providersProcessed', 0) + continue_from_provider_id = event.get('continueFromProviderId') + + # Parse and validate 
datetime parameters + try: + start_datetime = datetime.fromisoformat(start_datetime_str) + end_datetime = datetime.fromisoformat(end_datetime_str) + except ValueError as e: + logger.error(f'Invalid datetime format: {str(e)}') + return { + 'rollbackStatus': 'FAILED', + 'error': f'Invalid datetime format: {str(e)}', + } + + # Validate time window + if start_datetime >= end_datetime: + logger.error('Start time must be before end time') + return { + 'rollbackStatus': 'FAILED', + 'error': 'Start time must be before end time', + } + + time_window_seconds = (end_datetime - start_datetime).total_seconds() + if time_window_seconds > MAX_ROLLBACK_WINDOW_SECONDS: + logger.error(f'Time window exceeds maximum of {MAX_ROLLBACK_WINDOW_SECONDS / 86400} days') + return { + 'rollbackStatus': 'FAILED', + 'error': f'Time window cannot exceed {MAX_ROLLBACK_WINDOW_SECONDS / 86400} days', + } + + logger.info( + 'Starting license upload rollback', + compact=compact, + jurisdiction=jurisdiction, + start_datetime=start_datetime_str, + end_datetime=end_datetime_str, + execution_name=execution_name, + ) + + # Initialize S3 client and bucket + results_s3_key = f'licenseUploadRollbacks/{execution_name}/results.json' + + # Load existing results if this is a continuation + existing_results = _load_results_from_s3(results_s3_key, execution_name) + + # Initialize counters + providers_reverted = len(existing_results.reverted_provider_summaries) + providers_skipped = len(existing_results.skipped_provider_details) + providers_failed = len(existing_results.failed_provider_details) + + try: + # Query GSI for affected records across the time window + affected_provider_ids = _query_gsi_for_affected_providers( + compact, + jurisdiction, + start_datetime, + end_datetime, + ) + + # Convert to sorted list for consistent ordering across invocations + affected_provider_ids_list = sorted(affected_provider_ids) + + # If continuing from a previous invocation, slice the list to start from that provider + if 
continue_from_provider_id: + try: + start_index = affected_provider_ids_list.index(continue_from_provider_id) + affected_provider_ids_list = affected_provider_ids_list[start_index:] + logger.info( + f'Continuing from provider {continue_from_provider_id} (index {start_index}). ' + f'{len(affected_provider_ids_list)} providers remaining to process.' + ) + except ValueError: + # Provider ID in event input not found in list + # Log error and raise exception + logger.error( + f'Continue-from provider {continue_from_provider_id} not found in affected providers list.', + continue_from_provider_id=continue_from_provider_id, + affected_provider_ids_list=affected_provider_ids_list, + ) + raise + + # Process each provider + for provider_id in affected_provider_ids_list: + # Check time limit + elapsed_time = time.time() - execution_start_time + if elapsed_time > max_execution_time: + logger.info(f'Approaching time limit after {elapsed_time:.2f} seconds. Returning IN_PROGRESS status.') + + # Write current results to S3 + _write_results_to_s3(results_s3_key, existing_results) + + return { + 'rollbackStatus': 'IN_PROGRESS', + 'providersProcessed': providers_processed, + 'providersReverted': providers_reverted, + 'providersSkipped': providers_skipped, + 'providersFailed': providers_failed, + 'continueFromProviderId': provider_id, # Continue from next provider + 'compact': compact, + 'jurisdiction': jurisdiction, + 'startDateTime': start_datetime_str, + 'endDateTime': end_datetime_str, + 'rollbackReason': rollback_reason, + 'executionName': execution_name, + } + + # Process the provider + result = _process_provider_rollback( + provider_id=provider_id, + compact=compact, + jurisdiction=jurisdiction, + start_datetime=start_datetime, + end_datetime=end_datetime, + rollback_reason=rollback_reason, + execution_name=execution_name, + ) + + providers_processed += 1 + + # Update results based on outcome + if isinstance(result, ProviderRevertedSummary): + providers_reverted += 1 + 
existing_results.reverted_provider_summaries.append(result) + logger.info('Provider reverted successfully', provider_id=provider_id) + elif isinstance(result, ProviderSkippedDetails): + providers_skipped += 1 + existing_results.skipped_provider_details.append(result) + logger.info('Provider skipped due to ineligibility', provider_id=provider_id) + elif isinstance(result, ProviderFailedDetails): + providers_failed += 1 + existing_results.failed_provider_details.append(result) + logger.info('Provider failed to revert', provider_id=provider_id, error=result.error) + + logger.info( + 'processed provider', + total_providers_processed=providers_processed, + providers_reverted=providers_reverted, + providers_skipped=providers_skipped, + providers_failed=providers_failed, + ) + + # All providers processed successfully + logger.info( + 'Rollback complete', + providers_processed=providers_processed, + providers_skipped=providers_skipped, + providers_reverted=providers_reverted, + providers_failed=providers_failed, + ) + + # Write final results to S3 + _write_results_to_s3(results_s3_key, existing_results) + + return { + 'rollbackStatus': 'COMPLETE', + 'providersProcessed': providers_processed, + 'providersReverted': providers_reverted, + 'providersSkipped': providers_skipped, + 'providersFailed': providers_failed, + 'resultsS3Key': f's3://{config.disaster_recovery_results_bucket_name}/{results_s3_key}', + } + + except ClientError as e: + logger.error(f'Error during rollback: {str(e)}') + raise + + +def _query_gsi_for_affected_providers( + compact: str, + jurisdiction: str, + start_datetime: datetime, + end_datetime: datetime, +) -> set[str]: + """ + Query the licenseUploadDateGSI to find all affected provider IDs. + + Since the time window might span multiple months, we need to query each month separately. 
+ """ + affected_provider_ids = set() + + # Generate list of year-month strings to query + current_date = start_datetime.replace(day=1) + end_month = end_datetime.replace(day=1) + + year_months = [] + while current_date <= end_month: + year_months.append(current_date.strftime('%Y-%m')) + # Move to next month + if current_date.month == 12: + current_date = current_date.replace(year=current_date.year + 1, month=1) + else: + current_date = current_date.replace(month=current_date.month + 1) + + start_epoch = int(start_datetime.timestamp()) + end_epoch = int(end_datetime.timestamp()) + + # Query each month + for year_month in year_months: + gsi_pk = f'C#{compact.lower()}#J#{jurisdiction.lower()}#D#{year_month}' + + query_kwargs = { + 'IndexName': config.license_upload_date_index_name, + 'KeyConditionExpression': ( + Key('licenseUploadDateGSIPK').eq(gsi_pk) + & Key('licenseUploadDateGSISK').between(f'TIME#{start_epoch}#', f'TIME#{end_epoch}#~') + ), + } + + while True: + response = config.provider_table.query(**query_kwargs) + + # Extract provider IDs from the results + for item in response.get('Items', []): + # The providerId is in the SK: TIME#{epoch}#LT#{license_type}#PID#{provider_id} + provider_id = item['providerId'] + affected_provider_ids.add(provider_id) + + # Check for pagination + last_evaluated_key = response.get('LastEvaluatedKey') + if not last_evaluated_key: + break + + query_kwargs['ExclusiveStartKey'] = last_evaluated_key + + logger.info(f'Found {len(affected_provider_ids)} unique providers affected by upload window') + return affected_provider_ids + + +def _process_provider_rollback( + provider_id: str, + compact: str, + jurisdiction: str, + start_datetime: datetime, + end_datetime: datetime, + rollback_reason: str, + execution_name: str, +) -> ProviderRevertedSummary | ProviderSkippedDetails | ProviderFailedDetails: + """ + Process rollback for a single provider. 
+ + Returns one of: + - ProviderRevertedSummary: If provider was successfully reverted + - ProviderSkippedDetails: If provider was skipped due to ineligibility + - ProviderFailedDetails: If an error occurred during processing + """ + logger.info('Processing provider rollback', provider_id=provider_id) + + try: + # Build transactions and check eligibility in a single pass + # If ineligible updates are found, this will return a ProviderSkippedDetails + result = _build_and_execute_revert_transactions( + upload_window_start_datetime=start_datetime, + upload_window_end_datetime=end_datetime, + compact=compact, + jurisdiction=jurisdiction, + provider_id=provider_id, + ) + + # If provider was skipped due to ineligibility, return early + if isinstance(result, ProviderSkippedDetails): + return result + except ProviderRollbackFailedException as e: # noqa BLE001 + logger.error('Error processing provider rollback', provider_id=provider_id, exc_info=e) + return ProviderFailedDetails( + provider_id=provider_id, + error=f'Failed to rollback updates for provider. Manual review required: {str(e)}', + ) + + # Publish events for successful rollback + _publish_revert_events(result, compact, rollback_reason, start_datetime, end_datetime, execution_name) + return result + + +def _extract_sk_from_transaction_item(transaction_item: dict) -> str | None: + """ + Extract the sort key (SK) from a transaction item. + + Transaction items can be Put, Delete, or Update operations. + Returns the SK if found, None otherwise. 
+ """ + if 'Put' in transaction_item: + return transaction_item['Put']['Item'].get('sk') + if 'Delete' in transaction_item: + return transaction_item['Delete']['Key'].get('sk') + if 'Update' in transaction_item: + return transaction_item['Update']['Key'].get('sk') + return None + + +def _perform_transaction(transaction_items: list[dict], provider_id: str) -> None: + logger.info(f'Executing {len(transaction_items)} transaction items in batches of 100') + + for i in range(0, len(transaction_items), 100): + batch = transaction_items[i : i + 100] + # Use Table resource's client for automatic type conversion + try: + config.provider_table.meta.client.transact_write_items(TransactItems=batch) + logger.info(f'Executed batch {i // 100 + 1} with {len(batch)} items') + except ClientError as e: + # Extract all SKs from the failed transaction batch for debugging + failed_sks = [_extract_sk_from_transaction_item(item) for item in batch] + # filter out null values + failed_sks = [sk for sk in failed_sks if sk is not None] + + logger.error( + 'Transaction batch failed for provider', + provider_id=provider_id, + batch_number=i // 100 + 1, + batch_size=len(batch), + failed_sks=failed_sks, + error=str(e), + ) + raise ProviderRollbackFailedException(message=str(e)) from e + + +def _check_for_orphaned_update_records( + provider_records: ProviderUserRecords, +) -> IneligibleUpdate | None: + """ + Check if there are any license update records without associated top-level license records. 
+ + :param provider_records: The provider's records + :return: IneligibleUpdate if orphaned updates are found, None otherwise + """ + # Get all license update records + all_license_updates = provider_records.get_all_license_update_records() + + # Extract unique (jurisdiction, license_type) pairs from update records + license_keys_from_updates: set[tuple[str, str]] = set() + + for update in all_license_updates: + license_keys_from_updates.add((update.jurisdiction, update.licenseType)) + + # Check if each license key has a corresponding top-level license record + for license_jurisdiction, license_type in license_keys_from_updates: + # Try to find the license record + license_record = next( + ( + record + for record in provider_records.get_license_records() + if record.jurisdiction == license_jurisdiction and record.licenseType == license_type + ), + None, + ) + + if license_record is None: + # Found an orphaned update record + return IneligibleUpdate( + record_type='licenseUpdate', + type_of_update='Orphaned', + update_time='N/A', + license_type=license_type, + reason=f'License update record(s) exist for license in jurisdiction ' + f'{license_jurisdiction} with type {license_type}, but no corresponding top-level ' + f'license record was found. This indicates data inconsistency. Manual review required.', + ) + + return None + + +def _build_and_execute_revert_transactions( + upload_window_start_datetime: datetime, + upload_window_end_datetime: datetime, + compact: str, + jurisdiction: str, + provider_id: str, +) -> ProviderRevertedSummary | ProviderSkippedDetails: + """ + Build and execute DynamoDB transactions to revert provider records. + + This function processes all records in a single pass: + - Checks eligibility (returns ProviderSkippedDetails if ineligible) + - Builds transaction items + - Executes transactions + + Returns either a summary of what was reverted or details about why the provider was skipped. 
+ """ + # Split transaction lists into first tier/second tier lists (license/privilege/provider first tier, updates second) + # then merge the two lists into a single list of transaction items + primary_record_transaction_items = [] # License, privilege, and provider records + update_record_transactions_items = [] # Update records (license updates, privilege updates, provider updates) + table_name = config.provider_table_name + reverted_licenses = [] + reverted_privileges = [] + updates_deleted_sks = [] # List of SKs for deleted update records + ineligible_updates: list[IneligibleUpdate] = [] + + # Helper functions for cleaner item building + def add_put(item: dict, update_record: bool): + """ + Add a Put operation to the appropriate list. + + :param item: The item to put + :param update_record: True if the item is an update record, False if it is a primary record + """ + transaction_item = { + 'Put': { + 'TableName': table_name, + 'Item': item, + } + } + if update_record: + update_record_transactions_items.append(transaction_item) + else: + primary_record_transaction_items.append(transaction_item) + + def add_delete(pk: str, sk: str, update_record: bool): + """ + Add a Delete operation. 
+ + :param pk: Partition key + :param sk: Sort key - used to determine if this is an update record + :param update_record: True if the item is an update record, False if it is a primary record + """ + transaction_item = { + 'Delete': { + 'TableName': table_name, + 'Key': {'pk': pk, 'sk': sk}, + } + } + if update_record: + update_record_transactions_items.append(transaction_item) + else: + primary_record_transaction_items.append(transaction_item) + + # Fetch all provider records including all update tiers + try: + provider_records = config.data_client.get_provider_user_records( + compact=compact, + provider_id=provider_id, + # tier three includes all update records for the provider + include_update_tier=UpdateTierEnum.TIER_THREE, + ) + except ValidationError as e: + logger.info('provider record data failed schema validation. Skipping provider', exc_info=e) + raise ProviderRollbackFailedException(message=f'Validation error: {str(e)}') from e + + # Step 1: Check for license update records without top-level license records + orphaned_update_check = _check_for_orphaned_update_records(provider_records) + if orphaned_update_check is not None: + ineligible_updates.append(orphaned_update_check) + + # Step 2: Check provider updates - any after start_datetime make provider ineligible + provider_updates = provider_records.get_all_provider_update_records() + for update in provider_updates: + if update.createDate >= upload_window_start_datetime: + ineligible_updates.append( + IneligibleUpdate( + record_type='providerUpdate', + type_of_update=update.updateType, + update_time=update.createDate.isoformat(), + reason='Provider update occurred after rollback start time. 
Manual review required.', + # provider updates are not specific to a license type + license_type='N/A', + ) + ) + + # Step 3: Process each license record for the jurisdiction + license_records = provider_records.get_license_records(filter_condition=lambda x: x.jurisdiction == jurisdiction) + + reverted_licenses_dict = [] + + for license_record in license_records: + privileges_associated_with_license = provider_records.get_privileges_associated_with_license( + license_jurisdiction=license_record.jurisdiction, + license_type_abbreviation=license_record.licenseTypeAbbreviation, + ) + + # Check if any privileges were issued for this license since the upload start date + for privilege in privileges_associated_with_license: + if privilege.dateOfIssuance >= upload_window_start_datetime: + ineligible_updates.append( + IneligibleUpdate( + record_type='privilegeUpdate', + type_of_update='Issuance', + update_time=privilege.dateOfIssuance.isoformat(), + license_type=license_record.licenseType, + reason=f"Privilege in jurisdiction '{privilege.jurisdiction}' issued after license upload. 
" + 'Manual review required.', + ) + ) + # Check updates associated with this privilege that are after the start_datetime + privilege_updates_after_start_time = provider_records.get_update_records_for_privilege( + jurisdiction=privilege.jurisdiction, + license_type=privilege.licenseType, + filter_condition=lambda x: x.createDate >= upload_window_start_datetime, + ) + + # Check privilege updates for eligibility + for privilege_update in privilege_updates_after_start_time: + if privilege_update.updateType != PRIVILEGE_LICENSE_DEACTIVATION_CATEGORY: + # Non-license-deactivation privilege update makes provider ineligible for rollback + ineligible_updates.append( + IneligibleUpdate( + record_type='privilegeUpdate', + type_of_update=privilege_update.updateType, + update_time=privilege_update.createDate.isoformat(), + license_type=privilege_update.licenseType, + # include privilege jurisdiction in reason + reason=f"Privilege in jurisdiction '{privilege_update.jurisdiction}' was updated " + f'with a change unrelated to license upload. Manual review required.', + ) + ) + elif privilege_update.createDate > upload_window_end_datetime: + # privilege update after upload window makes provider ineligible + ineligible_updates.append( + IneligibleUpdate( + record_type='privilegeUpdate', + type_of_update=privilege_update.updateType, + update_time=privilege_update.createDate.isoformat(), + license_type=privilege_update.licenseType, + # include privilege jurisdiction in reason + reason=f"Privilege in jurisdiction '{privilege_update.jurisdiction}' was deactivated " + 'after rollback end time. 
Manual review required.', + ) + ) + else: + # License deactivation within window cause privilege deactivation - revert the deactivation + serialized_privilege_update = privilege_update.serialize_to_database_record() + add_delete(serialized_privilege_update['pk'], serialized_privilege_update['sk'], update_record=True) + updates_deleted_sks.append(serialized_privilege_update['sk']) + logger.info('Will delete privilege deactivation update record if provider is eligible for rollback') + + # Reactivate the privilege + privilege_record = provider_records.get_specific_privilege_record( + jurisdiction=privilege_update.jurisdiction, + license_abbreviation=license_record.licenseTypeAbbreviation, + ) + if privilege_record: + logger.info( + 'privilege record found associated with deactivation, reactivating privilege', + provider_id=provider_id, + privilege_jurisdiction=privilege_record.jurisdiction, + license_type=privilege_record.licenseType, + ) + # Remove the licenseDeactivatedStatus field to reactivate using UPDATE operation + serialized_privilege = privilege_record.serialize_to_database_record() + primary_record_transaction_items.append( + { + 'Update': { + 'TableName': table_name, + 'Key': {'pk': serialized_privilege['pk'], 'sk': serialized_privilege['sk']}, + 'UpdateExpression': 'REMOVE licenseDeactivatedStatus', + } + } + ) + logger.info('Will reactivate privilege record if provider is eligible for rollback') + + reverted_privileges.append( + RevertedPrivilege( + jurisdiction=privilege_record.jurisdiction, + license_type=privilege_record.licenseType, + action='REACTIVATED', + ) + ) + + # Get license updates for this license after start_datetime + license_updates_after_start = provider_records.get_update_records_for_license( + jurisdiction=license_record.jurisdiction, + license_type=license_record.licenseType, + filter_condition=lambda x: x.createDate >= upload_window_start_datetime, + ) + + # check license updates for eligibility + license_updates_in_window = [] + for 
license_update in license_updates_after_start: + if ( + license_update.updateType not in LICENSE_UPLOAD_UPDATE_CATEGORIES + or license_update.createDate > upload_window_end_datetime + ): + # Non-upload-related license updates make provider ineligible + ineligible_updates.append( + IneligibleUpdate( + record_type='licenseUpdate', + type_of_update=license_update.updateType, + update_time=license_update.createDate.isoformat(), + license_type=license_update.licenseType, + reason='License was updated with a change unrelated to license upload or the update ' + 'occurred after rollback end time. Manual review required.', + ) + ) + else: + # Upload-related update within window - mark for deletion + license_updates_in_window.append(license_update) + serialized_license_update = license_update.serialize_to_database_record() + add_delete(serialized_license_update['pk'], serialized_license_update['sk'], update_record=True) + updates_deleted_sks.append(serialized_license_update['sk']) + logger.info( + 'Will delete license update record if provider is eligible for rollback', + update_type=license_update.updateType, + license_type=license_update.licenseType, + ) + + # if license record was created during the window, delete it + if ( + license_record.firstUploadDate is not None + and upload_window_start_datetime <= license_record.firstUploadDate <= upload_window_end_datetime + ): + serialized_license_record = license_record.serialize_to_database_record() + add_delete(serialized_license_record['pk'], serialized_license_record['sk'], update_record=False) + logger.info('Will delete license record (created during upload) if provider is eligible for rollback') + reverted_licenses.append( + RevertedLicense( + jurisdiction=license_record.jurisdiction, + license_type=license_record.licenseType, + action='DELETE', + ) + ) + # license was not first uploaded during the upload window, revert it to last previous state before the upload + else: + # if the provider is ineligible for rollback, the 
list of license updates may be empty, and we need to + # defensively check for that here and continue to the next license + if not license_updates_in_window: + continue + + # Find the earliest update in the window to get the previous state + license_updates_in_window.sort(key=lambda x: x.createDate) + earliest_update_in_window = license_updates_in_window[0] + + # License existed before - revert to previous state + reverted_license_data = license_record.to_dict() + reverted_license_data.update(earliest_update_in_window.previous) + + reverted_license = LicenseData.create_new(reverted_license_data) + serialized_reverted_license = reverted_license.serialize_to_database_record() + + add_put(serialized_reverted_license, update_record=False) + logger.info('Reverting license record to pre-upload state') + + # Track for provider record regeneration + license_schema = LicenseRecordSchema() + reverted_licenses_dict.append(license_schema.load(serialized_reverted_license)) + + reverted_licenses.append( + RevertedLicense( + jurisdiction=license_record.jurisdiction, + license_type=license_record.licenseType, + action='REVERT', + ) + ) + + # Check if provider is ineligible for rollback + if ineligible_updates: + logger.info( + 'Provider not eligible for automatic rollback', + provider_id=provider_id, + ineligible_updates=ineligible_updates, + ) + return ProviderSkippedDetails( + provider_id=provider_id, + reason='Provider has updates that are either unrelated to license upload or occurred after' + ' rollback end time. Manual review required.', + ineligible_updates=ineligible_updates, + ) + + # process primary records first, then update records + transaction_items = primary_record_transaction_items + update_record_transactions_items + + if not transaction_items: + # This should never happen, as it means that somehow the GSI query returned this provider id within + # the search results, but the provider was not either skipped over or had something to revert as we expect. 
+ # If we do get here, we will exit the lambda in a failed state, as there is something unexpected happening that + # needs to be investigated before we attempt to roll back any other providers. + message = ( + 'No transaction items to execute for provider. This is an unexpected state that should be ' + 'investigated before attempting to roll back any other providers' + ) + logger.error(message, provider_id=provider_id) + raise CCInternalException(message=f'{message} provider_id: {provider_id}') + + _perform_transaction(transaction_items, provider_id) + try: + # Now read all the license records for the provider and update the provider record + provider_records_after_rollback = config.data_client.get_provider_user_records( + compact=compact, provider_id=provider_id + ) + top_level_provider_record: ProviderData = provider_records_after_rollback.get_provider_record() + except (CCNotFoundException, CCInternalException) as e: + # This would most likely happen if the top level provider record was somehow deleted by another process. + # We don't ever expect to get into this state, so we are going to let this bubble to the top and end the entire + # process, to ensure we are not putting the system into a worse state. + logger.error( + 'Expected top level provider record not found after rollback. 
' + 'Ending workflow to prevent risk of data corruption.', + provider_id=provider_id, + exc_info=e, + ) + raise + + # Create a new list for provider record updates (all first tier items) + primary_record_transaction_items.clear() + + try: + privilege_records: list[PrivilegeData] = provider_records_after_rollback.get_privilege_records() + best_license = provider_records_after_rollback.find_best_license_in_current_known_licenses() + updated_provider_record = ProviderRecordUtility.populate_provider_record( + current_provider_record=top_level_provider_record, + license_record=best_license.to_dict(), + privilege_records=[privilege.to_dict() for privilege in privilege_records], + ) + add_put(updated_provider_record.serialize_to_database_record(), update_record=False) + except CCNotFoundException: + # All licenses for the provider were removed as part of the rollback, meaning the provider + # needs to be removed as well. We first check to make sure there are no other record types + if len(provider_records_after_rollback.provider_records) > 1: + # We never expect this to happen, since license records should not have been removed if there were any + # privilege or other non-upload records found for the provider. If we hit this case, we will end the + # entire process to ensure we are not putting the system into a worse state. + message = ( + 'No licenses found for provider after rollback, but other record types still exist. ' + 'Killing process to prevent potential data corruption.' + ) + logger.error(message, provider_id=provider_id) + raise CCInternalException(message=str(message)) # noqa: B904 + + logger.info('Only top level provider record found. 
Deleting record', provider_id=provider_id) + serialized_provider_record = top_level_provider_record.serialize_to_database_record() + add_delete(pk=serialized_provider_record['pk'], sk=serialized_provider_record['sk'], update_record=False) + + _perform_transaction(primary_record_transaction_items, provider_id) + + logger.info( + 'Completed rollback for provider', + provider_id=provider_id, + licenses_reverted=reverted_licenses, + privileges_reverted=reverted_privileges, + updates_deleted=updates_deleted_sks, + ) + return ProviderRevertedSummary( + provider_id=provider_id, + licenses_reverted=reverted_licenses, + privileges_reverted=reverted_privileges, + updates_deleted=updates_deleted_sks, + ) + + +def _publish_revert_events( + revert_summary: ProviderRevertedSummary, + compact: str, + rollback_reason: str, + start_datetime: datetime, + end_datetime: datetime, + execution_name: str, +): + """ + Publish revert events for all reverted licenses and privileges. + + :param revert_summary: Summary of reverted provider records + :param compact: The compact name + :param rollback_reason: The reason for the rollback + :param start_datetime: The start time of the rollback window + :param end_datetime: The end time of the rollback window + :param execution_name: The execution name for the rollback operation + """ + with EventBatchWriter(config.events_client) as event_writer: + # Publish license revert events + for reverted_license in revert_summary.licenses_reverted: + try: + config.event_bus_client.publish_license_revert_event( + source='org.compactconnect.disaster-recovery', + compact=compact, + provider_id=revert_summary.provider_id, + jurisdiction=reverted_license.jurisdiction, + license_type=reverted_license.license_type, + rollback_reason=rollback_reason, + start_time=start_datetime, + end_time=end_datetime, + execution_name=execution_name, + event_batch_writer=event_writer, + ) + except Exception as e: # noqa BLE001 + # this event publishing is not business critical, 
so we log the error and move on + logger.error( + 'Unable to publish license revert event', + compact=compact, + provider_id=revert_summary.provider_id, + jurisdiction=reverted_license.jurisdiction, + license_type=reverted_license.license_type, + rollback_reason=rollback_reason, + start_time=start_datetime, + end_time=end_datetime, + error=str(e), + ) + + # Publish privilege revert events + for reverted_privilege in revert_summary.privileges_reverted: + try: + config.event_bus_client.publish_privilege_revert_event( + source='org.compactconnect.disaster-recovery', + compact=compact, + provider_id=revert_summary.provider_id, + jurisdiction=reverted_privilege.jurisdiction, + license_type=reverted_privilege.license_type, + rollback_reason=rollback_reason, + start_time=start_datetime, + end_time=end_datetime, + execution_name=execution_name, + event_batch_writer=event_writer, + ) + except Exception as e: # noqa BLE001 + # this event publishing is not business critical, so we log the error and move on + logger.error( + 'Unable to publish privilege revert event', + compact=compact, + provider_id=revert_summary.provider_id, + jurisdiction=reverted_privilege.jurisdiction, + license_type=reverted_privilege.license_type, + rollback_reason=rollback_reason, + start_time=start_datetime, + end_time=end_datetime, + error=str(e), + ) + + +def _load_results_from_s3(key: str, execution_name: str) -> RollbackResults: + """Load existing results from S3.""" + try: + response = config.s3_client.get_object(Bucket=config.disaster_recovery_results_bucket_name, Key=key) + data = json.loads(response['Body'].read().decode('utf-8')) + return RollbackResults.from_dict(data) + except config.s3_client.exceptions.NoSuchKey: + # First execution, no existing results + return RollbackResults(execution_name=execution_name) + except Exception as e: + logger.error(f'Error loading results from S3: {str(e)}') + raise + + +def _write_results_to_s3(key: str, results: RollbackResults): + """Write results to 
S3 with server-side encryption.""" + try: + config.s3_client.put_object( + Bucket=config.disaster_recovery_results_bucket_name, + Key=key, + Body=json.dumps(results.to_dict(), indent=2), + ContentType='application/json', + ) + logger.info('Results written to S3', bucket=config.disaster_recovery_results_bucket_name, key=key) + # handle json serialization errors + except TypeError as e: + logger.error(f'Error writing results to S3: {str(e)}') + raise + # handle other errors by logging the full object and raising the exception + except Exception as e: + logger.error(f'Error writing results to S3: {str(e)}', results=results.to_dict()) + raise diff --git a/backend/compact-connect/lambdas/python/disaster-recovery/tests/__init__.py b/backend/compact-connect/lambdas/python/disaster-recovery/tests/__init__.py index 5462ba841..d8d98d949 100644 --- a/backend/compact-connect/lambdas/python/disaster-recovery/tests/__init__.py +++ b/backend/compact-connect/lambdas/python/disaster-recovery/tests/__init__.py @@ -1,3 +1,4 @@ +import json import os from unittest import TestCase from unittest.mock import MagicMock @@ -14,6 +15,91 @@ def setUpClass(cls): 'DEBUG': 'true', 'ALLOWED_ORIGINS': '["https://example.org"]', 'AWS_DEFAULT_REGION': 'us-east-1', + 'DISASTER_RECOVERY_RESULTS_BUCKET_NAME': 'rollback-results-bucket', + 'EVENT_BUS_NAME': 'license-data-events', + 'PROVIDER_TABLE_NAME': 'provider-table', + 'RATE_LIMITING_TABLE_NAME': 'rate-limiting-table', + 'SSN_TABLE_NAME': 'ssn-table', + 'COMPACT_CONFIGURATION_TABLE_NAME': 'compact-configuration-table', + 'ENVIRONMENT_NAME': 'test', + 'PROV_FAM_GIV_MID_INDEX_NAME': 'providerFamGivMid', + 'FAM_GIV_INDEX_NAME': 'famGiv', + 'LICENSE_GSI_NAME': 'licenseGSI', + 'LICENSE_UPLOAD_DATE_INDEX_NAME': 'licenseUploadDateGSI', + 'PROV_DATE_OF_UPDATE_INDEX_NAME': 'providerDateOfUpdate', + 'SSN_INDEX_NAME': 'ssnIndex', + 'COMPACTS': '["aslp", "octp", "coun"]', + 'JURISDICTIONS': json.dumps( + [ + 'al', + 'ak', + 'az', + 'ar', + 'ca', + 'co', + 
'ct', + 'de', + 'dc', + 'fl', + 'ga', + 'hi', + 'id', + 'il', + 'in', + 'ia', + 'ks', + 'ky', + 'la', + 'me', + 'md', + 'ma', + 'mi', + 'mn', + 'ms', + 'mo', + 'mt', + 'ne', + 'nv', + 'nh', + 'nj', + 'nm', + 'ny', + 'nc', + 'nd', + 'oh', + 'ok', + 'or', + 'pa', + 'pr', + 'ri', + 'sc', + 'sd', + 'tn', + 'tx', + 'ut', + 'vt', + 'va', + 'vi', + 'wa', + 'wv', + 'wi', + 'wy', + ] + ), + 'LICENSE_TYPES': json.dumps( + { + 'aslp': [ + {'name': 'audiologist', 'abbreviation': 'aud'}, + {'name': 'speech-language pathologist', 'abbreviation': 'slp'}, + ], + }, + ), }, ) + # Monkey-patch config object to be sure we have it based + # on the env vars we set above + import cc_common.config + + cls.config = cc_common.config._Config() # noqa: SLF001 protected-access + cc_common.config.config = cls.config cls.mock_context = MagicMock(name='MockLambdaContext', spec=LambdaContext) diff --git a/backend/compact-connect/lambdas/python/disaster-recovery/tests/function/__init__.py b/backend/compact-connect/lambdas/python/disaster-recovery/tests/function/__init__.py index 4d0c0808d..4cf83bde7 100644 --- a/backend/compact-connect/lambdas/python/disaster-recovery/tests/function/__init__.py +++ b/backend/compact-connect/lambdas/python/disaster-recovery/tests/function/__init__.py @@ -25,6 +25,15 @@ def setUp(self): # noqa: N801 invalid-name self.mock_source_table_arn = f'arn:aws:dynamodb:us-east-1:767398110685:table/{self.mock_source_table_name}' self.build_resources() + # these must be imported within the tests, since they import modules which require + # environment variables that are not set until the TstLambdas class is initialized + import cc_common.config + from common_test.test_data_generator import TestDataGenerator + + cc_common.config.config = cc_common.config._Config() # noqa: SLF001 protected-access + self.config = cc_common.config.config + self.test_data_generator = TestDataGenerator + self.addCleanup(self.delete_resources) def build_resources(self): @@ -32,6 +41,17 @@ def 
build_resources(self): # cleanup and restoration process regardless of the table that is being recovered self.mock_source_table = self.create_mock_table(table_name=self.mock_source_table_name) self.mock_destination_table = self.create_mock_table(table_name=self.mock_destination_table_name) + self.create_provider_table() + self.create_rollback_results_bucket() + self.create_event_bus() + + def create_rollback_results_bucket(self): + self._rollback_results_bucket = boto3.resource('s3').create_bucket( + Bucket=os.environ['DISASTER_RECOVERY_RESULTS_BUCKET_NAME'] + ) + + def create_event_bus(self): + self._event_bus = boto3.client('events').create_event_bus(Name=os.environ['EVENT_BUS_NAME']) def create_mock_table(self, table_name: str): return boto3.resource('dynamodb').create_table( @@ -44,6 +64,66 @@ def create_mock_table(self, table_name: str): BillingMode='PAY_PER_REQUEST', ) + def create_provider_table(self): + self._provider_table = boto3.resource('dynamodb').create_table( + AttributeDefinitions=[ + {'AttributeName': 'pk', 'AttributeType': 'S'}, + {'AttributeName': 'sk', 'AttributeType': 'S'}, + {'AttributeName': 'providerFamGivMid', 'AttributeType': 'S'}, + {'AttributeName': 'providerDateOfUpdate', 'AttributeType': 'S'}, + {'AttributeName': 'licenseGSIPK', 'AttributeType': 'S'}, + {'AttributeName': 'licenseGSISK', 'AttributeType': 'S'}, + {'AttributeName': 'licenseUploadDateGSIPK', 'AttributeType': 'S'}, + {'AttributeName': 'licenseUploadDateGSISK', 'AttributeType': 'S'}, + ], + TableName=os.environ['PROVIDER_TABLE_NAME'], + KeySchema=[{'AttributeName': 'pk', 'KeyType': 'HASH'}, {'AttributeName': 'sk', 'KeyType': 'RANGE'}], + BillingMode='PAY_PER_REQUEST', + GlobalSecondaryIndexes=[ + { + 'IndexName': os.environ['PROV_FAM_GIV_MID_INDEX_NAME'], + 'KeySchema': [ + {'AttributeName': 'sk', 'KeyType': 'HASH'}, + {'AttributeName': 'providerFamGivMid', 'KeyType': 'RANGE'}, + ], + 'Projection': {'ProjectionType': 'ALL'}, + }, + { + 'IndexName': 
os.environ['PROV_DATE_OF_UPDATE_INDEX_NAME'], + 'KeySchema': [ + {'AttributeName': 'sk', 'KeyType': 'HASH'}, + {'AttributeName': 'providerDateOfUpdate', 'KeyType': 'RANGE'}, + ], + 'Projection': {'ProjectionType': 'ALL'}, + }, + { + 'IndexName': os.environ['LICENSE_GSI_NAME'], + 'KeySchema': [ + {'AttributeName': 'licenseGSIPK', 'KeyType': 'HASH'}, + {'AttributeName': 'licenseGSISK', 'KeyType': 'RANGE'}, + ], + 'Projection': {'ProjectionType': 'ALL'}, + }, + { + 'IndexName': 'licenseUploadDateGSI', + 'KeySchema': [ + {'AttributeName': 'licenseUploadDateGSIPK', 'KeyType': 'HASH'}, + {'AttributeName': 'licenseUploadDateGSISK', 'KeyType': 'RANGE'}, + ], + 'Projection': { + 'ProjectionType': 'INCLUDE', + 'NonKeyAttributes': [ + 'providerId', + ], + }, + }, + ], + ) + def delete_resources(self): self.mock_source_table.delete() self.mock_destination_table.delete() + self._provider_table.delete() + self._rollback_results_bucket.objects.delete() + self._rollback_results_bucket.delete() + boto3.client('events').delete_event_bus(Name=os.environ['EVENT_BUS_NAME']) diff --git a/backend/compact-connect/lambdas/python/disaster-recovery/tests/function/test_rollback_license_upload.py b/backend/compact-connect/lambdas/python/disaster-recovery/tests/function/test_rollback_license_upload.py new file mode 100644 index 000000000..85ff56503 --- /dev/null +++ b/backend/compact-connect/lambdas/python/disaster-recovery/tests/function/test_rollback_license_upload.py @@ -0,0 +1,1531 @@ +""" +Tests for the license upload rollback handler. 
+ +These tests verify the rollback functionality including: +- GSI queries for affected providers +- Eligibility validation +- Revert plan determination +- Transaction execution +- Event publishing +- S3 result management +""" + +import json +from datetime import datetime, timedelta +from unittest.mock import ANY, Mock, patch + +import pytest +from cc_common.data_model.update_tier_enum import UpdateTierEnum +from cc_common.exceptions import CCNotFoundException +from moto import mock_aws + +from . import TstFunction + +MOCK_DATETIME_STRING = '2025-10-23T08:15:00+00:00' +MOCK_ORIGINAL_GIVEN_NAME = 'originalGiven' +MOCK_ORIGINAL_FAMILY_NAME = 'originalFamily' +MOCK_UPDATED_GIVEN_NAME = 'updatedGiven' +MOCK_UPDATED_FAMILY_NAME = 'updatedFamily' +MOCK_PROVIDER_ID = 'ba880c7c-5ed3-4be4-8ad5-c8558f58ef6f' +MOCK_EXECUTION_NAME = 'test-execution-123' + + +@mock_aws +@patch('cc_common.config._Config.current_standard_datetime', datetime.fromisoformat(MOCK_DATETIME_STRING)) +class TestRollbackLicenseUpload(TstFunction): + """Test class for license upload rollback handler.""" + + def setUp(self): + """Set up test fixtures before each test method.""" + super().setUp() + # Create sample test data + self.compact = 'aslp' + self.license_jurisdiction = 'oh' + self.provider_id = MOCK_PROVIDER_ID + # default upload time between start and end time + self.default_upload_datetime = datetime.fromisoformat(MOCK_DATETIME_STRING) - timedelta(hours=1) + self.default_start_datetime = self.default_upload_datetime - timedelta(days=1) + self.default_end_datetime = self.default_upload_datetime + from cc_common.data_model.schema.common import UpdateCategory + + self.update_categories = UpdateCategory + + self.provider_data = self._add_provider_record() + + def _generate_test_event(self): + return { + 'compact': self.compact, + 'jurisdiction': self.license_jurisdiction, + 'startDateTime': self.default_start_datetime.isoformat(), + 'endDateTime': self.default_end_datetime.isoformat(), + 
'rollbackReason': 'Test rollback', + 'executionName': MOCK_EXECUTION_NAME, + 'providersProcessed': 0, + } + + def _add_provider_record(self, provider_id: str | None = None): + if provider_id is None: + provider_id = self.provider_id + + # add provider record to provider table + return self.test_data_generator.put_default_provider_record_in_provider_table( + { + 'providerId': provider_id, + 'compact': self.compact, + 'jurisdiction': self.license_jurisdiction, + 'dateOfUpdate': self.default_start_datetime - timedelta(days=30), + } + ) + + # Helper methods for setting up test scenarios + def _when_provider_had_license_created_from_upload(self): + """ + Set up a scenario where a provider had a license created during the upload window. + Returns the created license data. + """ + return self.test_data_generator.put_default_license_record_in_provider_table( + { + 'providerId': self.provider_id, + 'compact': self.compact, + 'jurisdiction': self.license_jurisdiction, + 'firstUploadDate': self.default_upload_datetime, + 'dateOfUpdate': self.default_upload_datetime, + } + ) + + def _when_provider_had_license_updated_from_upload( + self, upload_datetime: datetime = None, license_upload_datetime: datetime = None, provider_id: str = None + ): + """ + Set up a scenario where a provider had an existing license updated during the upload window. + Returns the license and its update record. 
+ """ + if upload_datetime is None: + upload_datetime = self.default_upload_datetime + if license_upload_datetime is None: + # by default, the license was originally uploaded a day before the bad upload + license_upload_datetime = self.default_start_datetime - timedelta(days=1) + if provider_id is None: + provider_id = self.provider_id + + # Create original license before upload window, unless different time is provided + original_license = self.test_data_generator.put_default_license_record_in_provider_table( + { + 'providerId': provider_id, + 'compact': self.compact, + 'jurisdiction': self.license_jurisdiction, + 'familyName': MOCK_ORIGINAL_FAMILY_NAME, + 'givenName': MOCK_ORIGINAL_GIVEN_NAME, + 'dateOfUpdate': self.default_start_datetime - timedelta(days=30), + # simulate license record that has not expired yet + 'dateOfExpiration': (self.default_start_datetime + timedelta(days=30)).date(), + 'firstUploadDate': license_upload_datetime, + 'licenseStatus': 'active', + } + ) + + # Create update record within upload window to simulate license deactivation + license_update = self.test_data_generator.put_default_license_update_record_in_provider_table( + { + 'providerId': provider_id, + 'compact': self.compact, + 'jurisdiction': self.license_jurisdiction, + 'licenseType': original_license.licenseType, + 'updateType': self.update_categories.DEACTIVATION, + 'createDate': upload_datetime, + 'effectiveDate': upload_datetime, + 'previous': { + 'dateOfExpiration': original_license.dateOfExpiration, + 'licenseStatus': 'active', + **original_license.to_dict(), + }, + 'updatedValues': { + # simulate accidentally changing the expiration to last year + 'dateOfExpiration': (upload_datetime - timedelta(days=365)).date(), + 'licenseStatus': 'inactive', + 'familyName': MOCK_UPDATED_FAMILY_NAME, + 'givenName': MOCK_UPDATED_GIVEN_NAME, + }, + } + ) + + # Update the license record to reflect the new expiration and status + updated_license = 
self.test_data_generator.put_default_license_record_in_provider_table( + { + 'providerId': provider_id, + 'compact': self.compact, + 'jurisdiction': self.license_jurisdiction, + 'familyName': MOCK_UPDATED_FAMILY_NAME, + 'givenName': MOCK_UPDATED_GIVEN_NAME, + 'dateOfUpdate': upload_datetime, + 'dateOfExpiration': (upload_datetime - timedelta(days=365)).date(), + 'licenseStatus': 'inactive', + 'firstUploadDate': license_upload_datetime, + } + ) + + return original_license, license_update, updated_license + + def _when_license_was_updated_twice(self, provider_id: str = None): + """ + Set up a scenario where a provider had an existing license updated twice during the upload window. + Returns the original license, both update records, and the final updated license. + """ + first_upload_datetime = self.default_start_datetime + timedelta(minutes=30) + second_upload_datetime = self.default_start_datetime + timedelta(hours=1) + if provider_id is None: + provider_id = self.provider_id + + # License was originally uploaded before the upload window + license_upload_datetime = self.default_start_datetime - timedelta(days=1) + + # Create original license before upload window + original_license = self.test_data_generator.put_default_license_record_in_provider_table( + { + 'providerId': provider_id, + 'compact': self.compact, + 'jurisdiction': self.license_jurisdiction, + 'familyName': MOCK_ORIGINAL_FAMILY_NAME, + 'givenName': MOCK_ORIGINAL_GIVEN_NAME, + 'dateOfExpiration': (self.default_start_datetime + timedelta(days=30)).date(), + 'firstUploadDate': license_upload_datetime, + 'licenseStatus': 'active', + } + ) + + # old update record before upload window + existing_update = self.test_data_generator.put_default_license_update_record_in_provider_table( + { + 'providerId': provider_id, + 'compact': self.compact, + 'jurisdiction': self.license_jurisdiction, + 'licenseType': original_license.licenseType, + 'updateType': self.update_categories.LICENSE_UPLOAD_UPDATE_OTHER, + # last 
update was 5 days before upload, this should be ignored + 'createDate': first_upload_datetime - timedelta(days=5), + 'effectiveDate': first_upload_datetime, + 'previous': { + **original_license.to_dict(), + 'familyName': 'someFamilyName', + 'givenName': 'someGivenName', + }, + 'updatedValues': { + 'familyName': original_license.familyName, + 'givenName': original_license.givenName, + }, + } + ) + + # Create first update record within upload window (e.g., RENEWAL) + first_update = self.test_data_generator.put_default_license_update_record_in_provider_table( + { + 'providerId': provider_id, + 'compact': self.compact, + 'jurisdiction': self.license_jurisdiction, + 'licenseType': original_license.licenseType, + 'updateType': self.update_categories.RENEWAL, + 'createDate': first_upload_datetime, + 'effectiveDate': first_upload_datetime, + 'previous': { + 'dateOfExpiration': original_license.dateOfExpiration, + 'licenseStatus': original_license.licenseStatus, + **original_license.to_dict(), + }, + 'updatedValues': { + 'dateOfExpiration': (first_upload_datetime + timedelta(days=365)).date(), + 'dateOfRenewal': first_upload_datetime.date(), + }, + } + ) + + # Create intermediate license state after first update + intermediate_license = self.test_data_generator.put_default_license_record_in_provider_table( + { + 'providerId': provider_id, + 'compact': self.compact, + 'jurisdiction': self.license_jurisdiction, + 'familyName': MOCK_ORIGINAL_FAMILY_NAME, + 'givenName': MOCK_ORIGINAL_GIVEN_NAME, + 'dateOfUpdate': first_upload_datetime, + 'dateOfExpiration': (first_upload_datetime + timedelta(days=365)).date(), + 'dateOfRenewal': first_upload_datetime.date(), + 'firstUploadDate': license_upload_datetime, + 'licenseStatus': 'active', + } + ) + + # Create second update record within upload window (e.g., DEACTIVATION) + second_update = self.test_data_generator.put_default_license_update_record_in_provider_table( + { + 'providerId': provider_id, + 'compact': self.compact, + 
'jurisdiction': self.license_jurisdiction, + 'licenseType': original_license.licenseType, + 'updateType': self.update_categories.DEACTIVATION, + 'createDate': second_upload_datetime, + 'effectiveDate': second_upload_datetime, + 'previous': { + 'dateOfExpiration': intermediate_license.dateOfExpiration, + 'licenseStatus': intermediate_license.licenseStatus, + **intermediate_license.to_dict(), + }, + 'updatedValues': { + 'dateOfExpiration': (second_upload_datetime - timedelta(days=365)).date(), + 'licenseStatus': 'inactive', + 'familyName': MOCK_UPDATED_FAMILY_NAME, + 'givenName': MOCK_UPDATED_GIVEN_NAME, + }, + } + ) + + # Update the license record to reflect the final state after second update + final_license = self.test_data_generator.put_default_license_record_in_provider_table( + { + 'providerId': provider_id, + 'compact': self.compact, + 'jurisdiction': self.license_jurisdiction, + 'familyName': MOCK_UPDATED_FAMILY_NAME, + 'givenName': MOCK_UPDATED_GIVEN_NAME, + 'dateOfUpdate': second_upload_datetime, + 'dateOfExpiration': (second_upload_datetime - timedelta(days=365)).date(), + 'firstUploadDate': license_upload_datetime, + 'licenseStatus': 'inactive', + } + ) + + return existing_update, original_license, first_update, second_update, final_license + + def _when_provider_had_privilege_deactivated_from_upload(self, upload_datetime: datetime = None): + """ + Set up a scenario where a provider's privilege was deactivated due to license deactivation during upload. + Returns the privilege and its update record. 
+ """ + from cc_common.data_model.schema.common import LicenseDeactivatedStatusEnum + + if upload_datetime is None: + upload_datetime = self.default_upload_datetime + + # provider has privilege in Nebraska that was deactivated by upload + privilege = self.test_data_generator.put_default_privilege_record_in_provider_table( + { + 'providerId': self.provider_id, + 'compact': self.compact, + 'jurisdiction': 'ne', + 'licenseJurisdiction': self.license_jurisdiction, + 'dateOfUpdate': self.default_start_datetime - timedelta(days=30), + 'licenseDeactivatedStatus': LicenseDeactivatedStatusEnum.LICENSE_DEACTIVATED, + 'dateOfExpiration': datetime.fromisoformat(MOCK_DATETIME_STRING), + } + ) + + # Create deactivation update record + privilege_update = self.test_data_generator.put_default_privilege_update_record_in_provider_table( + { + 'providerId': self.provider_id, + 'compact': self.compact, + 'jurisdiction': 'ne', + 'licenseType': privilege.licenseType, + 'updateType': self.update_categories.LICENSE_DEACTIVATION, + 'createDate': upload_datetime, + 'effectiveDate': upload_datetime, + 'previous': {**privilege.to_dict()}, + 'updatedValues': { + 'licenseDeactivatedStatus': LicenseDeactivatedStatusEnum.LICENSE_DEACTIVATED, + }, + } + ) + + return privilege, privilege_update + + def _when_provider_had_privilege_issued_during_upload(self): + """ + Set up a scenario where a provider had a privilege issued DURING the upload window. + This makes them ineligible for automatic rollback. + Returns the privilege record. 
+ """ + + return self.test_data_generator.put_default_privilege_record_in_provider_table( + { + 'providerId': self.provider_id, + 'compact': self.compact, + 'jurisdiction': 'ne', + 'licenseJurisdiction': self.license_jurisdiction, + 'dateOfIssuance': self.default_upload_datetime, + } + ) + + def _when_provider_had_privilege_update_after_upload(self, after_upload_datetime: datetime = None): + """ + Set up a scenario where a provider had a non-upload-related privilege update AFTER the upload window. + This makes them ineligible for automatic rollback. + Returns the privilege and its update record. + """ + if after_upload_datetime is None: + after_upload_datetime = self.default_end_datetime + timedelta(hours=1) + + privilege = self.test_data_generator.put_default_privilege_record_in_provider_table( + { + 'providerId': self.provider_id, + 'compact': self.compact, + 'jurisdiction': 'ne', + 'licenseJurisdiction': self.license_jurisdiction, + } + ) + + # Create a non-upload-related update (e.g., renewal) after the window + privilege_update = self.test_data_generator.put_default_privilege_update_record_in_provider_table( + { + 'providerId': self.provider_id, + 'compact': self.compact, + 'jurisdiction': 'ne', + 'licenseType': privilege.licenseType, + 'updateType': self.update_categories.RENEWAL, # Not LICENSE_DEACTIVATION + 'createDate': after_upload_datetime, + 'effectiveDate': after_upload_datetime, + } + ) + + return privilege, privilege_update + + def _when_provider_had_license_update_after_upload(self, after_upload_datetime: datetime = None): + """ + Set up a scenario where a provider had a non-upload-related license update AFTER the upload window. + This makes them ineligible for automatic rollback. + Returns the license and its update record. 
+ """ + if after_upload_datetime is None: + after_upload_datetime = self.default_end_datetime + timedelta(hours=1) + + # Create a non-upload-related update (e.g., encumbrance) after the window + return self.test_data_generator.put_default_license_update_record_in_provider_table( + { + 'providerId': self.provider_id, + 'compact': self.compact, + 'jurisdiction': self.license_jurisdiction, + 'updateType': self.update_categories.ENCUMBRANCE, # Not an upload-related category + 'createDate': after_upload_datetime, + 'effectiveDate': after_upload_datetime, + } + ) + + def _when_provider_top_level_record_needs_reverted(self, before_upload_datetime: datetime = None): + """ + Set up a scenario where the provider's top-level record needs to be reverted. + Returns the provider record. + """ + if before_upload_datetime is None: + before_upload_datetime = self.default_start_datetime - timedelta(days=30) + + # Existing license updated during window + self._when_provider_had_license_updated_from_upload() + + # Create provider record with old values + provider = self.test_data_generator.put_default_provider_record_in_provider_table( + { + 'providerId': self.provider_id, + 'compact': self.compact, + 'familyName': MOCK_ORIGINAL_FAMILY_NAME, + 'givenName': MOCK_ORIGINAL_GIVEN_NAME, + 'dateOfUpdate': before_upload_datetime, + } + ) + + # Simulate that the provider record was updated during upload + updated_provider = self.test_data_generator.put_default_provider_record_in_provider_table( + { + 'providerId': self.provider_id, + 'compact': self.compact, + 'familyName': MOCK_UPDATED_FAMILY_NAME, + 'givenName': MOCK_UPDATED_GIVEN_NAME, + 'dateOfUpdate': self.default_upload_datetime, + } + ) + + return provider, updated_provider + + def _when_provider_changed_home_jurisdiction_after_license_upload(self): + self._when_provider_had_license_created_from_upload() + + provider_update_record = self.test_data_generator.put_default_provider_update_record_in_provider_table( + value_overrides={ + 
'providerId': self.provider_id, + 'compact': self.compact, + # home jurisdiction was changed during license upload window + 'createDate': self.default_upload_datetime, + 'updateType': self.update_categories.HOME_JURISDICTION_CHANGE, + 'previous': {**self.provider_data.to_dict()}, + 'updatedValues': { + 'currentHomeJurisdiction': self.license_jurisdiction, + }, + }, + # home jurisdiction was changed during license upload window + date_of_update_override=self.default_upload_datetime.isoformat(), + ) + + # Simulate that the provider record was updated during upload + self.test_data_generator.put_default_provider_record_in_provider_table( + { + 'currentHomeJurisdiction': self.license_jurisdiction, + } + ) + + return provider_update_record + + def test_provider_top_level_record_reset_to_prior_values_when_upload_reverted(self): + """Test that provider top-level record is reset to values before upload.""" + from handlers.rollback_license_upload import rollback_license_upload + + # Setup: + # Provider record was updated during upload + old_provider, new_provider = self._when_provider_top_level_record_needs_reverted() + + # Execute: Perform rollback + event = self._generate_test_event() + + result = rollback_license_upload(event, Mock()) + + # Assert: Rollback completed successfully + self.assertEqual(result['rollbackStatus'], 'COMPLETE') + self.assertEqual(1, result['providersReverted']) + + # Verify: Provider record has been reset to old values + provider_records = self.config.data_client.get_provider_user_records( + compact=self.compact, + provider_id=self.provider_id, + ) + provider_record = provider_records.get_provider_record() + self.assertEqual(old_provider.givenName, provider_record.givenName) + self.assertEqual(old_provider.familyName, provider_record.familyName) + + def test_provider_top_level_record_deleted_when_license_created_during_bad_upload(self): + """Test that provider top-level record is deleted if the license record + is also deleted when reverting 
upload.""" + from handlers.rollback_license_upload import rollback_license_upload + + # Setup: + # License and provider records were created during upload + self._when_provider_had_license_created_from_upload() + + # Execute: Perform rollback + event = self._generate_test_event() + + result = rollback_license_upload(event, Mock()) + + # Assert: Rollback completed successfully + self.assertEqual(result['rollbackStatus'], 'COMPLETE') + self.assertEqual(1, result['providersReverted']) + + # Verify: All provider records have been deleted + with pytest.raises(CCNotFoundException): + self.config.data_client.get_provider_user_records( + compact=self.compact, + provider_id=self.provider_id, + ) + + def test_provider_license_record_reset_to_prior_values_when_upload_reverted(self): + """Test that license record is reset to values before upload.""" + from handlers.rollback_license_upload import rollback_license_upload + + # Setup: License was updated during upload (e.g., renewed), but was first uploaded before start time + original_license, license_update, updated_license = self._when_provider_had_license_updated_from_upload( + license_upload_datetime=self.default_start_datetime - timedelta(hours=1) + ) + + # Store the original expiration date from the update's previous values + original_expiration = license_update.previous['dateOfExpiration'] + + # Execute: Perform rollback + event = self._generate_test_event() + + result = rollback_license_upload(event, Mock()) + + # should return complete message + self.assertEqual(result['rollbackStatus'], 'COMPLETE') + self.assertEqual(result['providersReverted'], 1) + + # Verify: License record has been reset to original values + provider_records = self.config.data_client.get_provider_user_records( + compact=self.compact, + provider_id=self.provider_id, + include_update_tier=UpdateTierEnum.TIER_THREE, + ) + licenses = provider_records.get_license_records() + self.assertEqual(len(licenses), 1) + license_record = licenses[0] + 
self.assertEqual(license_record.dateOfExpiration, original_expiration) + + # Verify: Update record has been deleted + license_updates = provider_records.get_all_license_update_records() + self.assertEqual(len(license_updates), 0, 'License update records should be deleted') + + def test_provider_license_record_reverted_to_earliest_update_previous_values_when_multiple_updates(self): + from handlers.rollback_license_upload import rollback_license_upload + + # Setup: License was updated twice during upload window, but was first uploaded before start time + existing_update, original_license, first_update, second_update, final_license = ( + self._when_license_was_updated_twice() + ) + + # Execute: Perform rollback + event = self._generate_test_event() + + result = rollback_license_upload(event, Mock()) + + # Assert: Rollback completed successfully + self.assertEqual(result['rollbackStatus'], 'COMPLETE') + self.assertEqual(result['providersReverted'], 1) + + # Verify: License record has been reset to the values from the first (earliest) update's previous field + provider_records = self.config.data_client.get_provider_user_records( + compact=self.compact, + provider_id=self.provider_id, + include_update_tier=UpdateTierEnum.TIER_THREE, + ) + licenses = provider_records.get_license_records() + self.assertEqual(len(licenses), 1) + license_record = licenses[0] + # license should look the same as it did before the updates that were rolled back + self.assertEqual(original_license.serialize_to_database_record(), license_record.serialize_to_database_record()) + + # Verify: Both update records have been deleted + license_updates = provider_records.get_all_license_update_records() + # license update that existed before upload should still be there + self.assertEqual(len(license_updates), 1, 'Expected one existing license update to remain') + self.assertEqual( + existing_update.serialize_to_database_record(), license_updates[0].serialize_to_database_record() + ) + + def 
test_provider_privilege_record_reactivated_when_upload_reverted(self): + """Test that privilege is reactivated when license deactivation is reverted.""" + from handlers.rollback_license_upload import rollback_license_upload + + # Setup: Privilege was deactivated during upload due to license deactivation + # license was uploaded before rollback window + self._when_provider_had_license_updated_from_upload( + license_upload_datetime=self.default_start_datetime - timedelta(hours=1) + ) + self._when_provider_had_privilege_deactivated_from_upload() + + # Execute: Perform rollback + event = self._generate_test_event() + + result = rollback_license_upload(event, Mock()) + + # Assert: Rollback completed successfully + self.assertEqual(result['rollbackStatus'], 'COMPLETE') + self.assertEqual(result['providersReverted'], 1) + + # Verify: Privilege has been reactivated (status should be 'active') + provider_records = self.config.data_client.get_provider_user_records( + compact=self.compact, + provider_id=self.provider_id, + include_update_tier=UpdateTierEnum.TIER_THREE, + ) + privileges = provider_records.get_privilege_records() + self.assertEqual(len(privileges), 1) + privilege_record = privileges[0] + self.assertEqual(privilege_record.status, 'active', 'Privilege should be reactivated') + self.assertIsNone(privilege_record.licenseDeactivatedStatus) + + # Verify: Privilege update record has been deleted + privilege_updates = provider_records.get_all_privilege_update_records() + self.assertEqual(len(privilege_updates), 0, 'Privilege update records should be deleted') + + # make sure license record was reactivated as well + license_record = provider_records.get_specific_license_record( + jurisdiction=self.license_jurisdiction, license_abbreviation=privilege_record.licenseTypeAbbreviation + ) + self.assertEqual('active', license_record.licenseStatus) + + def test_provider_license_updates_and_license_record_within_time_period_removed_when_upload_reverted(self): + """Test that 
license update records and license record within the time window are deleted.""" + from handlers.rollback_license_upload import rollback_license_upload + + # Setup: License was uploaded and then updated during upload + self._when_provider_had_license_updated_from_upload( + license_upload_datetime=self.default_start_datetime + timedelta(hours=1) + ) + + # Verify update record exists before rollback + provider_records_before = self.config.data_client.get_provider_user_records( + compact=self.compact, + provider_id=self.provider_id, + include_update_tier=UpdateTierEnum.TIER_THREE, + ) + licenses_before = provider_records_before.get_license_records() + self.assertEqual(len(licenses_before), 1, 'Should have license record before rollback') + license_updates_before = provider_records_before.get_all_license_update_records() + self.assertEqual(len(license_updates_before), 1, 'Should have update record before rollback') + + # Execute: Perform rollback + event = self._generate_test_event() + + result = rollback_license_upload(event, Mock()) + + # Assert: Rollback completed successfully + self.assertEqual(result['rollbackStatus'], 'COMPLETE') + + # Verify: All records within time window have been deleted + with pytest.raises(CCNotFoundException): + self.config.data_client.get_provider_user_records( + compact=self.compact, + provider_id=self.provider_id, + include_update_tier=UpdateTierEnum.TIER_THREE, + ) + + def test_provider_skipped_if_license_updates_detected_after_end_of_time_window_when_upload_reverted(self): + """Test that provider is skipped if non-upload-related license updates exist after time window.""" + from handlers.rollback_license_upload import rollback_license_upload + + # Setup: Provider had valid license before upload, and update occurred during upload window + self._when_provider_had_license_updated_from_upload( + license_upload_datetime=self.default_start_datetime - timedelta(hours=1) + ) + # update also occurred after upload window + 
self._when_provider_had_license_update_after_upload() + + event = self._generate_test_event() + result = rollback_license_upload(event, Mock()) + + # Assert: Rollback completed but provider was skipped + self.assertEqual('COMPLETE', result['rollbackStatus']) + self.assertEqual(0, result['providersReverted']) + self.assertEqual(1, result['providersSkipped']) + + # Verify: License record and update still exist (not rolled back) + provider_records = self.config.data_client.get_provider_user_records( + compact=self.compact, + provider_id=self.provider_id, + include_update_tier=UpdateTierEnum.TIER_THREE, + ) + licenses = provider_records.get_license_records() + self.assertEqual(len(licenses), 1, 'License should still exist') + license_updates = provider_records.get_all_license_update_records() + self.assertEqual(2, len(license_updates), 'License updates should still exist') + + def test_provider_skipped_if_privilege_updates_detected_after_time_period_when_upload_reverted(self): + """Test that provider is skipped if non-upload-related privilege updates exist after time window.""" + from handlers.rollback_license_upload import rollback_license_upload + + # Setup: Provider had privilege update after upload window + self._when_provider_had_license_updated_from_upload() + self._when_provider_had_privilege_update_after_upload() + + # Execute: Perform rollback + event = self._generate_test_event() + + result = rollback_license_upload(event, Mock()) + + # Assert: Rollback completed but provider was skipped + self.assertEqual(result['rollbackStatus'], 'COMPLETE') + self.assertEqual(1, result['providersSkipped']) + self.assertEqual(0, result['providersReverted']) + + # Verify: Privilege record and update still exist (not rolled back) + provider_records = self.config.data_client.get_provider_user_records( + compact=self.compact, + provider_id=self.provider_id, + include_update_tier=UpdateTierEnum.TIER_THREE, + ) + privileges = provider_records.get_privilege_records() + 
self.assertEqual(1, len(privileges), 'Privilege should still exist') + privilege_updates = provider_records.get_all_privilege_update_records() + self.assertEqual(1, len(privilege_updates), 'Privilege update should still exist') + + # Validation tests + def test_rollback_validates_datetime_format(self): + from handlers.rollback_license_upload import rollback_license_upload + + event = self._generate_test_event() + event['startDateTime'] = 'invalid-datetime' + + result = rollback_license_upload(event, Mock()) + + self.assertEqual(result['rollbackStatus'], 'FAILED') + self.assertIn('Invalid datetime format', result['error']) + + def test_rollback_validates_time_window_order(self): + from handlers.rollback_license_upload import rollback_license_upload + + event = self._generate_test_event() + event['startDateTime'] = self.default_end_datetime.isoformat() + event['endDateTime'] = self.default_start_datetime.isoformat() + + result = rollback_license_upload(event, Mock()) + + self.assertEqual(result['rollbackStatus'], 'FAILED') + self.assertIn('Start time must be before end time', result['error']) + + def test_rollback_validates_maximum_time_window(self): + from handlers.rollback_license_upload import rollback_license_upload + + start = self.config.current_standard_datetime - timedelta(days=8) # More than 7 days + end = self.config.current_standard_datetime + + event = self._generate_test_event() + event['startDateTime'] = start.isoformat() + event['endDateTime'] = end.isoformat() + + result = rollback_license_upload(event, Mock()) + + self.assertEqual(result['rollbackStatus'], 'FAILED') + self.assertIn('cannot exceed', result['error']) + + def _perform_rollback_and_get_s3_object(self): + from handlers.rollback_license_upload import rollback_license_upload + + # Execute: Perform rollback + event = self._generate_test_event() + + rollback_license_upload(event, Mock()) + + # Read object from S3 and verify its contents match what is expected + s3_key = 
f'licenseUploadRollbacks/{MOCK_EXECUTION_NAME}/results.json' + s3_obj = self.config.s3_client.get_object(Bucket=self.config.disaster_recovery_results_bucket_name, Key=s3_key) + return json.loads(s3_obj['Body'].read().decode('utf-8')) + + # Tests for checking data written to S3 + def test_expected_s3_object_stored_when_provider_license_record_reset_to_prior_values(self): + # Setup: License was updated during upload (e.g., renewed), but was first uploaded before start time + original_license, license_update, updated_license = self._when_provider_had_license_updated_from_upload( + license_upload_datetime=self.default_start_datetime - timedelta(hours=1) + ) + + results_data = self._perform_rollback_and_get_s3_object() + + # Verify the structure of the results + self.assertEqual( + { + 'executionName': MOCK_EXECUTION_NAME, + 'failedProviderDetails': [], + 'revertedProviderSummaries': [ + { + 'licensesReverted': [ + { + 'action': 'REVERT', + 'jurisdiction': original_license.jurisdiction, + 'licenseType': original_license.licenseType, + } + ], + 'privilegesReverted': [], + 'providerId': self.provider_id, + # NOTE: if the test update data is modified, the sha here will need to be updated + 'updatesDeleted': [ + 'aslp#UPDATE#3#license/oh/slp/2025-10-23T07:15:00+00:00/d92450a96739428f1a77c051dce9d4a6' + ], + } + ], + 'skippedProviderDetails': [], + }, + results_data, + ) + + def test_expected_s3_object_stored_when_provider_license_record_deleted_from_rollback(self): + # Setup: License was updated during upload (e.g., renewed), but was first uploaded before start time + new_license = self._when_provider_had_license_created_from_upload() + + results_data = self._perform_rollback_and_get_s3_object() + + # Verify the structure of the results + self.assertEqual( + { + 'executionName': MOCK_EXECUTION_NAME, + 'failedProviderDetails': [], + 'revertedProviderSummaries': [ + { + 'licensesReverted': [ + { + 'action': 'DELETE', + 'jurisdiction': new_license.jurisdiction, + 
'licenseType': new_license.licenseType, + } + ], + 'privilegesReverted': [], + 'providerId': self.provider_id, + 'updatesDeleted': [], + } + ], + 'skippedProviderDetails': [], + }, + results_data, + ) + + def test_expected_s3_object_stored_when_provider_privilege_record_reactivated_from_rollback(self): + # Setup: Privilege was deactivated during upload due to license deactivation + # license was uploaded before rollback window + self._when_provider_had_license_updated_from_upload( + license_upload_datetime=self.default_start_datetime - timedelta(hours=1) + ) + privilege, privilege_update = self._when_provider_had_privilege_deactivated_from_upload() + + results_data = self._perform_rollback_and_get_s3_object() + + # Verify the structure of the results + self.assertEqual( + { + 'executionName': MOCK_EXECUTION_NAME, + 'failedProviderDetails': [], + 'revertedProviderSummaries': [ + { + 'licensesReverted': [ + { + 'action': 'REVERT', + 'jurisdiction': self.license_jurisdiction, + 'licenseType': privilege.licenseType, + } + ], + 'privilegesReverted': [ + { + 'action': 'REACTIVATED', + 'jurisdiction': privilege.jurisdiction, + 'licenseType': privilege.licenseType, + } + ], + 'providerId': self.provider_id, + # NOTE: if the test update data is modified, the shas here will need to be updated + 'updatesDeleted': [ + 'aslp#UPDATE#1#privilege/ne/slp/2025-10-23T07:15:00+00:00/06b886756a79b796ad10b17bd67057e6', + 'aslp#UPDATE#3#license/oh/slp/2025-10-23T07:15:00+00:00/d92450a96739428f1a77c051dce9d4a6', + ], + } + ], + 'skippedProviderDetails': [], + }, + results_data, + ) + + def test_expected_s3_object_stored_when_provider_skipped_due_to_extra_license_updates(self): + # Setup: Provider had valid license before upload, and update occurred during upload window + original_license, license_update, updated_license = self._when_provider_had_license_updated_from_upload( + license_upload_datetime=self.default_start_datetime - timedelta(hours=1) + ) + # update also occurred after upload 
window + encumbrance_update = self._when_provider_had_license_update_after_upload() + + results_data = self._perform_rollback_and_get_s3_object() + + # Verify the structure of the results + expected_reason_message = ( + 'License was updated with a change unrelated to license upload or the update ' + 'occurred after rollback end time. Manual review required.' + ) + self.assertEqual( + { + 'executionName': MOCK_EXECUTION_NAME, + 'failedProviderDetails': [], + 'revertedProviderSummaries': [], + 'skippedProviderDetails': [ + { + 'ineligibleUpdates': [ + { + 'updateTime': encumbrance_update.createDate.isoformat(), + 'licenseType': original_license.licenseType, + 'reason': expected_reason_message, + 'recordType': 'licenseUpdate', + 'typeOfUpdate': encumbrance_update.updateType, + } + ], + 'providerId': MOCK_PROVIDER_ID, + 'reason': 'Provider has updates that are either ' + 'unrelated to license upload or ' + 'occurred after rollback end time. ' + 'Manual review required.', + } + ], + }, + results_data, + ) + + def test_expected_s3_object_stored_when_provider_skipped_due_to_privilege_issuance(self): + # Setup: Provider had privilege update after upload window + self._when_provider_had_license_updated_from_upload() + privilege = self._when_provider_had_privilege_issued_during_upload() + + results_data = self._perform_rollback_and_get_s3_object() + + # Verify the structure of the results + expected_reason_message = ( + f"Privilege in jurisdiction '{privilege.jurisdiction}' issued after license upload. Manual review required." 
+ ) + self.assertEqual( + { + 'executionName': MOCK_EXECUTION_NAME, + 'failedProviderDetails': [], + 'revertedProviderSummaries': [], + 'skippedProviderDetails': [ + { + 'ineligibleUpdates': [ + { + 'updateTime': privilege.dateOfIssuance.isoformat(), + 'licenseType': privilege.licenseType, + 'reason': expected_reason_message, + 'recordType': 'privilegeUpdate', + 'typeOfUpdate': 'Issuance', + } + ], + 'providerId': MOCK_PROVIDER_ID, + 'reason': 'Provider has updates that are either ' + 'unrelated to license upload or ' + 'occurred after rollback end time. ' + 'Manual review required.', + } + ], + }, + results_data, + ) + + def test_expected_s3_object_stored_when_provider_skipped_due_to_extra_privilege_updates(self): + # Setup: Provider had privilege update after upload window + self._when_provider_had_license_updated_from_upload() + privilege, privilege_update = self._when_provider_had_privilege_update_after_upload() + + results_data = self._perform_rollback_and_get_s3_object() + + # Verify the structure of the results + expected_reason_message = ( + "Privilege in jurisdiction 'ne' was updated with a change unrelated to license upload. " + 'Manual review required.' + ) + self.assertEqual( + { + 'executionName': MOCK_EXECUTION_NAME, + 'failedProviderDetails': [], + 'revertedProviderSummaries': [], + 'skippedProviderDetails': [ + { + 'ineligibleUpdates': [ + { + 'updateTime': privilege_update.createDate.isoformat(), + 'licenseType': privilege.licenseType, + 'reason': expected_reason_message, + 'recordType': 'privilegeUpdate', + 'typeOfUpdate': privilege_update.updateType, + } + ], + 'providerId': MOCK_PROVIDER_ID, + 'reason': 'Provider has updates that are either ' + 'unrelated to license upload or ' + 'occurred after rollback end time. 
' + 'Manual review required.', + } + ], + }, + results_data, + ) + + def test_expected_s3_object_stored_when_provider_skipped_due_to_extra_provider_updates(self): + # Setup: Provider had privilege update after upload window + provider_update = self._when_provider_changed_home_jurisdiction_after_license_upload() + + results_data = self._perform_rollback_and_get_s3_object() + + # Verify the structure of the results + expected_reason_message = 'Provider update occurred after rollback start time. Manual review required.' + self.assertEqual( + { + 'executionName': MOCK_EXECUTION_NAME, + 'failedProviderDetails': [], + 'revertedProviderSummaries': [], + 'skippedProviderDetails': [ + { + 'ineligibleUpdates': [ + { + 'updateTime': provider_update.dateOfUpdate.isoformat(), + 'reason': expected_reason_message, + 'recordType': 'providerUpdate', + 'typeOfUpdate': provider_update.updateType, + 'licenseType': 'N/A', + } + ], + 'providerId': MOCK_PROVIDER_ID, + 'reason': 'Provider has updates that are either ' + 'unrelated to license upload or ' + 'occurred after rollback end time. 
' + 'Manual review required.', + } + ], + }, + results_data, + ) + + def test_expected_s3_object_stored_when_provider_schema_validation_fails_during_rollback(self): + """Test that failed provider details are correctly stored in S3 results when a validation exception occurs.""" + # Setup: License was updated during upload, but one update record has invalid field + original_license, license_update, updated_license = self._when_provider_had_license_updated_from_upload( + license_upload_datetime=self.default_start_datetime - timedelta(hours=1) + ) + serialized_license = updated_license.serialize_to_database_record() + serialized_license['jurisdictionUploadedLicenseStatus'] = 'foo' + self.config.provider_table.put_item(Item=serialized_license) + + results_data = self._perform_rollback_and_get_s3_object() + + # Verify the structure of the results contains failed provider details + self.assertEqual( + { + 'executionName': MOCK_EXECUTION_NAME, + 'failedProviderDetails': [ + { + 'error': 'Failed to rollback updates for provider. 
Manual review required: Validation error: ' + "{'jurisdictionUploadedLicenseStatus': ['Must be one of: active, inactive.']}", + 'providerId': self.provider_id, + } + ], + 'revertedProviderSummaries': [], + 'skippedProviderDetails': [], + }, + results_data, + ) + + def test_rollback_handles_loading_existing_s3_results_and_appends_new_data(self): + """Test that rollback can load existing S3 results and append new data without deleting previous data.""" + from uuid import uuid4 + + existing_skipped_provider_id = str(uuid4()) + existing_reverted_provider_id = str(uuid4()) + existing_failed_provider_id = str(uuid4()) + + # Setup: Create existing provider with license that will be reverted + # This provider will have a privilege that gets reactivated + self._when_provider_had_license_updated_from_upload( + license_upload_datetime=self.default_start_datetime - timedelta(hours=1) + ) + self._when_provider_had_privilege_deactivated_from_upload() + + # Create initial S3 results with data in all fields + s3_key = f'licenseUploadRollbacks/{MOCK_EXECUTION_NAME}/results.json' + + # Create existing results data in the format that from_dict expects (camelCase for all keys) + existing_results_data = { + 'executionName': MOCK_EXECUTION_NAME, + 'skippedProviderDetails': [ + { + 'providerId': existing_skipped_provider_id, + 'reason': 'Existing skipped provider reason', + 'ineligibleUpdates': [ + { + 'recordType': 'licenseUpdate', + 'typeOfUpdate': 'ENCUMBRANCE', + 'updateTime': (self.default_start_datetime - timedelta(days=2)).isoformat(), + 'reason': 'Existing ineligible update reason', + 'licenseType': 'audiologist', + } + ], + } + ], + 'failedProviderDetails': [ + { + 'providerId': existing_failed_provider_id, + 'error': 'Existing failure error message', + } + ], + 'revertedProviderSummaries': [ + { + 'providerId': existing_reverted_provider_id, + 'licensesReverted': [ + { + 'jurisdiction': 'tx', + 'licenseType': 'audiologist', + 'action': 'REVERT', + } + ], + 'privilegesReverted': 
[], + 'updatesDeleted': ['existing-update-sha-1'], + } + ], + } + + # Write existing results to S3 + self.config.s3_client.put_object( + Bucket=self.config.disaster_recovery_results_bucket_name, + Key=s3_key, + Body=json.dumps(existing_results_data, indent=2), + ContentType='application/json', + ) + + final_results_data = self._perform_rollback_and_get_s3_object() + + # Verify: All existing data is preserved and new data is appended + self.assertEqual( + { + 'executionName': MOCK_EXECUTION_NAME, + 'skippedProviderDetails': [ + { + 'providerId': existing_skipped_provider_id, + 'reason': 'Existing skipped provider reason', + 'ineligibleUpdates': [ + { + 'recordType': 'licenseUpdate', + 'typeOfUpdate': 'ENCUMBRANCE', + 'updateTime': (self.default_start_datetime - timedelta(days=2)).isoformat(), + 'reason': 'Existing ineligible update reason', + 'licenseType': 'audiologist', + } + ], + } + ], + 'failedProviderDetails': [ + { + 'providerId': existing_failed_provider_id, + 'error': 'Existing failure error message', + } + ], + 'revertedProviderSummaries': [ + { + 'providerId': existing_reverted_provider_id, + 'licensesReverted': [ + { + 'jurisdiction': 'tx', + 'licenseType': 'audiologist', + 'action': 'REVERT', + } + ], + 'privilegesReverted': [], + 'updatesDeleted': ['existing-update-sha-1'], + }, + { + 'providerId': self.provider_id, + 'licensesReverted': [ + { + 'action': 'REVERT', + 'jurisdiction': self.license_jurisdiction, + 'licenseType': ANY, + } + ], + 'privilegesReverted': [ + { + 'action': 'REACTIVATED', + 'jurisdiction': 'ne', + 'licenseType': ANY, + } + ], + 'updatesDeleted': ANY, + }, + ], + }, + final_results_data, + ) + + @patch('handlers.rollback_license_upload.time') + def test_rollback_handles_pagination_when_provider_id_present_in_event_input(self, mock_time): + """Test that rollback can paginate across multiple invocations using continueFromProviderId.""" + from handlers.rollback_license_upload import rollback_license_upload + + # Lambda functions 
have a timeout of 15 minutes, so we set a cutoff of 12 minutes before we loop around + # the step function to reset the timeout. This mock allows us to test that branch of logic. + # the first time the mock_time function is called, it will return current time + # the second time the mock_time function is called, it will return + 1 second + # the third time the mock_time function is called, it will return 12 minutes + 2 seconds (cutoff is 12 minutes) + # this should cause the lambda to return an IN_PROGRESS status with a pagination key + mock_time.time.side_effect = [0, 1, 12 * 60 + 2] # current time, 12 minutes + 2 seconds + + # Setup: Create two providers with licenses that will be reverted + # Provider IDs in sorted order (to ensure consistent pagination behavior) + mock_first_provider_id = '11111111-5ed3-4be4-8ad5-c8558f587890' + mock_second_provider_id = '22222222-5ed3-4be4-8ad5-c8558f587890' + + # Add first provider + self._add_provider_record(provider_id=mock_first_provider_id) + self._when_provider_had_license_updated_from_upload( + license_upload_datetime=self.default_start_datetime - timedelta(hours=1), provider_id=mock_first_provider_id + ) + + # Add second provider + self._add_provider_record(provider_id=mock_second_provider_id) + self._when_provider_had_license_updated_from_upload( + license_upload_datetime=self.default_start_datetime - timedelta(hours=1), + provider_id=mock_second_provider_id, + ) + + # Execute: First invocation (should timeout after processing first provider) + event = self._generate_test_event() + + result_first = rollback_license_upload(event, Mock()) + + # Assert: First invocation returned IN_PROGRESS status + self.assertEqual(result_first['rollbackStatus'], 'IN_PROGRESS') + self.assertEqual(1, result_first['providersProcessed']) + self.assertEqual(1, result_first['providersReverted']) + self.assertEqual(0, result_first['providersSkipped']) + self.assertEqual(0, result_first['providersFailed']) + 
self.assertEqual(mock_second_provider_id, result_first['continueFromProviderId']) + + # Execute: Second invocation (continue from where we left off) + # Reset mock time for second invocation + mock_time.time.side_effect = [0, 1] # Won't timeout this time + + result_second = rollback_license_upload(result_first, Mock()) + + # Assert: Second invocation completed successfully + self.assertEqual(result_second['rollbackStatus'], 'COMPLETE') + self.assertEqual(2, result_second['providersProcessed']) + self.assertEqual(2, result_second['providersReverted']) + self.assertEqual(0, result_second['providersSkipped']) + self.assertEqual(0, result_second['providersFailed']) + + # Verify: S3 results contain both providers + s3_key = f'licenseUploadRollbacks/{MOCK_EXECUTION_NAME}/results.json' + s3_obj = self.config.s3_client.get_object(Bucket=self.config.disaster_recovery_results_bucket_name, Key=s3_key) + final_results_data = json.loads(s3_obj['Body'].read().decode('utf-8')) + + # Should have 2 reverted providers + self.assertEqual( + { + 'executionName': MOCK_EXECUTION_NAME, + 'failedProviderDetails': [], + 'revertedProviderSummaries': [ + { + 'licensesReverted': [ + { + 'action': 'REVERT', + 'jurisdiction': 'oh', + 'licenseType': 'speech-language pathologist', + } + ], + 'privilegesReverted': [], + 'providerId': mock_first_provider_id, + 'updatesDeleted': [ + 'aslp#UPDATE#3#license/oh/slp/2025-10-23T07:15:00+00:00/d92450a96739428f1a77c051dce9d4a6' + ], + }, + { + 'licensesReverted': [ + { + 'action': 'REVERT', + 'jurisdiction': 'oh', + 'licenseType': 'speech-language pathologist', + } + ], + 'privilegesReverted': [], + 'providerId': mock_second_provider_id, + 'updatesDeleted': [ + 'aslp#UPDATE#3#license/oh/slp/2025-10-23T07:15:00+00:00/d92450a96739428f1a77c051dce9d4a6' + ], + }, + ], + 'skippedProviderDetails': [], + }, + final_results_data, + ) + + @patch('handlers.rollback_license_upload.config.event_bus_client') + def 
test_event_bus_client_called_with_expected_arguments_for_revert_events(self, mock_event_bus_client): + """Test that event bus client methods are called with expected arguments when publishing revert events.""" + from handlers.rollback_license_upload import rollback_license_upload + + # Setup: License was updated during upload and privilege was deactivated + # This scenario will trigger both license and privilege revert events + original_license, license_update, updated_license = self._when_provider_had_license_updated_from_upload( + license_upload_datetime=self.default_start_datetime - timedelta(hours=1) + ) + privilege, privilege_update = self._when_provider_had_privilege_deactivated_from_upload() + + # Execute: Perform rollback + event = self._generate_test_event() + result = rollback_license_upload(event, Mock()) + + # Assert: Rollback completed successfully + self.assertEqual(result['rollbackStatus'], 'COMPLETE') + self.assertEqual(result['providersReverted'], 1) + + # Verify: publish_license_revert_event was called with expected arguments + expected_license_kwargs = { + 'source': 'org.compactconnect.disaster-recovery', + 'compact': self.compact, + 'provider_id': self.provider_id, + 'jurisdiction': self.license_jurisdiction, + 'license_type': original_license.licenseType, + 'rollback_reason': 'Test rollback', + 'start_time': self.default_start_datetime, + 'end_time': self.default_end_datetime, + 'execution_name': MOCK_EXECUTION_NAME, + 'event_batch_writer': ANY, + } + mock_event_bus_client.publish_license_revert_event.assert_called_once_with(**expected_license_kwargs) + + # Verify: publish_privilege_revert_event was called with expected arguments + expected_privilege_kwargs = { + 'source': 'org.compactconnect.disaster-recovery', + 'compact': self.compact, + 'provider_id': self.provider_id, + 'jurisdiction': privilege.jurisdiction, + 'license_type': privilege.licenseType, + 'rollback_reason': 'Test rollback', + 'start_time': self.default_start_datetime, + 
'end_time': self.default_end_datetime, + 'execution_name': MOCK_EXECUTION_NAME, + 'event_batch_writer': ANY, + } + mock_event_bus_client.publish_privilege_revert_event.assert_called_once_with(**expected_privilege_kwargs) + + def test_transaction_failure_is_logged_and_provider_marked_as_failed(self): + """Test that transaction failures are properly logged and the provider is marked as failed.""" + from botocore.exceptions import ClientError + + # Setup: Create a scenario with privilege deactivation which will have PUT, DELETE, and UPDATE operations + # - License update (DELETE of update record) + # - Privilege update (DELETE of update record) + # - Privilege reactivation (UPDATE to remove licenseDeactivatedStatus) + # - Provider record update (PUT) + self._when_provider_had_license_updated_from_upload( + license_upload_datetime=self.default_start_datetime - timedelta(hours=1) + ) + self._when_provider_had_privilege_deactivated_from_upload() + + # Mock the transaction to fail with a ClientError + mock_error = ClientError( + error_response={'Error': {'Code': 'TransactionCanceledException', 'Message': 'Transaction cancelled'}}, + operation_name='TransactWriteItems', + ) + + # Patch at the handler module level to ensure it works across the full test suite + with patch( + 'handlers.rollback_license_upload.config.provider_table.meta.client.transact_write_items', + side_effect=mock_error, + ): + results_data = self._perform_rollback_and_get_s3_object() + + # Verify: Provider was marked as failed + self.assertEqual(1, len(results_data['failedProviderDetails'])) + self.assertEqual(self.provider_id, results_data['failedProviderDetails'][0]['providerId']) + self.assertIn('TransactionCanceledException', results_data['failedProviderDetails'][0]['error']) + + # Verify: No providers were reverted or skipped + self.assertEqual(0, len(results_data['revertedProviderSummaries'])) + self.assertEqual(0, len(results_data['skippedProviderDetails'])) + + def 
test_orphaned_license_updates_cause_provider_to_be_skipped(self): + """Test that orphaned license update records (without top-level license records) + cause provider to be skipped.""" + from uuid import uuid4 + + from handlers.rollback_license_upload import rollback_license_upload + + orphaned_provider_id = str(uuid4()) + + # Setup: License was uploaded and then updated during upload + # Create update record within upload window to simulate license deactivation + orphaned_license_update = self.test_data_generator.put_default_license_update_record_in_provider_table( + { + 'providerId': orphaned_provider_id, + 'compact': self.compact, + 'jurisdiction': self.license_jurisdiction, + 'updateType': self.update_categories.DEACTIVATION, + 'createDate': self.default_upload_datetime, + 'effectiveDate': self.default_upload_datetime, + 'updatedValues': { + # simulate accidentally changing the expiration to last year + 'dateOfExpiration': (self.default_upload_datetime - timedelta(days=365)).date(), + 'licenseStatus': 'inactive', + 'familyName': MOCK_UPDATED_FAMILY_NAME, + 'givenName': MOCK_UPDATED_GIVEN_NAME, + }, + } + ) + + # Verify update record exists before rollback + provider_records_before = self.config.data_client.get_provider_user_records( + compact=self.compact, + provider_id=orphaned_provider_id, + include_update_tier=UpdateTierEnum.TIER_THREE, + ) + licenses_before = provider_records_before.get_license_records() + self.assertEqual(len(licenses_before), 0, 'Should not have license record before rollback') + license_updates_before = provider_records_before.get_all_license_update_records() + self.assertEqual(len(license_updates_before), 1, 'Should have orphaned update record before rollback') + + # Execute: Perform rollback + event = self._generate_test_event() + + result = rollback_license_upload(event, Mock()) + + # Assert: Rollback completed with provider skipped + self.assertEqual(result['rollbackStatus'], 'COMPLETE') + self.assertEqual(result['providersSkipped'], 
1, 'Provider with orphaned updates should be skipped') + self.assertEqual(result['providersReverted'], 0, 'No providers should be reverted') + self.assertEqual(result['providersFailed'], 0, 'No providers should have failed') + + # Verify S3 results contain the orphaned update details + s3_key = f'licenseUploadRollbacks/{MOCK_EXECUTION_NAME}/results.json' + s3_obj = self.config.s3_client.get_object(Bucket=self.config.disaster_recovery_results_bucket_name, Key=s3_key) + results_data = json.loads(s3_obj['Body'].read().decode('utf-8')) + + # Verify the structure of the results + expected_reason = ( + f'License update record(s) exist for license in jurisdiction ' + f'{self.license_jurisdiction} with type {orphaned_license_update.licenseType}, ' + f'but no corresponding top-level license record was found. ' + f'This indicates data inconsistency. Manual review required.' + ) + + self.assertEqual(1, len(results_data['skippedProviderDetails'])) + skipped_detail = results_data['skippedProviderDetails'][0] + + self.assertEqual(orphaned_provider_id, skipped_detail['providerId']) + self.assertIn('Manual review required', skipped_detail['reason']) + + # Check ineligible updates details + self.assertEqual(1, len(skipped_detail['ineligibleUpdates'])) + ineligible_update = skipped_detail['ineligibleUpdates'][0] + + self.assertEqual('licenseUpdate', ineligible_update['recordType']) + self.assertEqual('Orphaned', ineligible_update['typeOfUpdate']) + self.assertEqual(orphaned_license_update.licenseType, ineligible_update['licenseType']) + self.assertEqual(expected_reason, ineligible_update['reason']) + + # Verify no providers were reverted or failed + self.assertEqual(0, len(results_data['revertedProviderSummaries'])) + self.assertEqual(0, len(results_data['failedProviderDetails'])) + + def test_provider_skipped_when_encumbrance_update_created_within_upload_window(self): + from handlers.rollback_license_upload import rollback_license_upload + + # Setup: License was created during 
upload window + self._when_provider_had_license_created_from_upload() + + # Create an encumbrance update that happens WITHIN the upload window + # but is NOT an upload-related update type + encumbrance_time = self.default_upload_datetime + timedelta(minutes=1) + self.test_data_generator.put_default_license_update_record_in_provider_table( + { + 'providerId': self.provider_id, + 'compact': self.compact, + 'jurisdiction': self.license_jurisdiction, + 'updateType': self.update_categories.ENCUMBRANCE, # Not an upload-related category + 'createDate': encumbrance_time, + 'effectiveDate': encumbrance_time, + 'updatedValues': { + 'encumberedStatus': 'encumbered', + }, + } + ) + + # Execute: Perform rollback + event = self._generate_test_event() + result = rollback_license_upload(event, Mock()) + + # Assert: Rollback completed but provider was skipped + self.assertEqual('COMPLETE', result['rollbackStatus']) + self.assertEqual(0, result['providersReverted']) + self.assertEqual(1, result['providersSkipped']) + + # Verify: License record and encumbrance update still exist (not rolled back) + provider_records = self.config.data_client.get_provider_user_records( + compact=self.compact, + provider_id=self.provider_id, + include_update_tier=UpdateTierEnum.TIER_THREE, + ) + licenses = provider_records.get_license_records() + self.assertEqual(len(licenses), 1, 'License should still exist') + license_updates = provider_records.get_all_license_update_records() + self.assertEqual(1, len(license_updates), 'Encumbrance update should still exist') + + # Verify S3 results contain skip details + s3_key = f'licenseUploadRollbacks/{MOCK_EXECUTION_NAME}/results.json' + s3_obj = self.config.s3_client.get_object(Bucket=self.config.disaster_recovery_results_bucket_name, Key=s3_key) + results_data = json.loads(s3_obj['Body'].read().decode('utf-8')) + + self.assertEqual(1, len(results_data['skippedProviderDetails'])) + skipped_detail = results_data['skippedProviderDetails'][0] + 
self.assertEqual(self.provider_id, skipped_detail['providerId']) + self.assertIn('Manual review required', skipped_detail['reason']) diff --git a/backend/compact-connect/lambdas/python/migration/migrate_update_sort_keys/main.py b/backend/compact-connect/lambdas/python/migration/migrate_update_sort_keys/main.py new file mode 100644 index 000000000..87a7bf2f0 --- /dev/null +++ b/backend/compact-connect/lambdas/python/migration/migrate_update_sort_keys/main.py @@ -0,0 +1,187 @@ +from boto3.dynamodb.conditions import Attr +from cc_common.config import config, logger +from cc_common.data_model.provider_record_util import ( + LicenseUpdateData, + PrivilegeUpdateData, + ProviderRecordType, + ProviderUpdateData, +) +from cc_common.exceptions import CCInternalException +from custom_resource_handler import CustomResourceHandler, CustomResourceResponse + + +class UpdateRecordSortKeyMigration(CustomResourceHandler): + """Migration for migrating update record sort keys to support license upload rollbacks""" + + def on_create(self, properties: dict) -> None: + do_migration(properties) + + def on_update(self, properties: dict) -> None: + """ + No-op on update. + """ + + def on_delete(self, _properties: dict) -> CustomResourceResponse | None: + """ + No-op on delete. + """ + + +on_event = UpdateRecordSortKeyMigration('update-record-sort-keys') + + +def do_migration(_properties: dict) -> None: + """ + This migration performs the following: + - Scans the provider table for all update records + - For each update record, load the records and serialize it again, + so the schema classes will generate the new sort key patterns + - Recreate the records by deleting the update records with the old sort key and storing the migrated records. 
+ """ + logger.info('Starting update record sort key migration') + + # Scan for all update records + update_records = [] + scan_pagination = {} + + while True: + response = config.provider_table.scan( + FilterExpression=Attr('type').eq(ProviderRecordType.LICENSE_UPDATE) + | Attr('type').eq(ProviderRecordType.PROVIDER_UPDATE) + | Attr('type').eq(ProviderRecordType.PRIVILEGE_UPDATE), + **scan_pagination, + ) + + items = response.get('Items', []) + update_records.extend(items) + logger.info(f'Found {len(items)} update records in current scan batch') + + # Check if we need to continue pagination + last_evaluated_key = response.get('LastEvaluatedKey') + if not last_evaluated_key: + break + + scan_pagination = {'ExclusiveStartKey': last_evaluated_key} + + logger.info(f'Found {len(update_records)} total update records to process') + + if not update_records: + logger.info('No update records found, migration complete') + return + + # Process records in batches of 50 (DynamoDB transaction limit is 100 items, + # and each record generates 2 items: 1 update + 1 delete) + batch_size = 50 + + for i in range(0, len(update_records), batch_size): + batch = update_records[i : i + batch_size] + logger.info(f'Processing batch {i // batch_size + 1} with {len(batch)} records') + + _process_batch(batch) + logger.info(f'Processed batch {i // batch_size + 1}') + + +def _generate_delete_transaction_item(pk: str, sk: str) -> dict: + """ + Generate a delete transaction item for a provider record. + :param pk: The primary key of the provider record + :param sk: The sort key of the provider record + :return: Delete transaction item + """ + return { + 'Delete': { + 'TableName': config.provider_table.table_name, + 'Key': { + 'pk': pk, + 'sk': sk, + }, + } + } + + +def _generate_put_transaction_item(item: dict) -> dict: + """ + Generate a put transaction item for a provider record. + :param item: The provider record to put. 
+ :return: Put transaction item + """ + return { + 'Put': { + 'TableName': config.provider_table.table_name, + 'Item': item, + } + } + + +def _generate_transaction_items(original_update_record: dict) -> list[dict]: + """ + In the case of a provider update record, we add a createDate field based on the dateOfUpdate field. + Then we use the ProviderUpdateData class to serialize the record and return the transaction items. + (one to delete the old record and one to create the new record) + + :param original_update_record: The provider update record to process + :return: List of transaction items + """ + # grab the old pk and sk from the object + old_pk = original_update_record['pk'] + old_sk = original_update_record['sk'] + record_type = original_update_record.get('type') + if record_type == ProviderRecordType.PROVIDER_UPDATE: + data_class = ProviderUpdateData + elif record_type == ProviderRecordType.LICENSE_UPDATE: + data_class = LicenseUpdateData + elif record_type == ProviderRecordType.PRIVILEGE_UPDATE: + data_class = PrivilegeUpdateData + else: + logger.error('invalid record type found', record_type=record_type, pk=old_pk, sk=old_sk) + raise CCInternalException('invalid record type found') + + # Performing deserialization/serialization on the record, which will generate + # the new pk/sks values we are migrating to. 
+ + update_data = data_class.from_database_record(original_update_record) + migrated_provider_update_record = update_data.serialize_to_database_record() + # retain original dateOfUpdate value + migrated_provider_update_record['dateOfUpdate'] = original_update_record['dateOfUpdate'] + + logger.info( + 'Prepared update items for create date', + old_pk=old_pk, + old_sk=old_sk, + updated_pk=migrated_provider_update_record['pk'], + updated_sk=migrated_provider_update_record['sk'], + ) + + # delete old record with old pk/sk, and create new one + return [ + _generate_delete_transaction_item(pk=old_pk, sk=old_sk), + _generate_put_transaction_item(migrated_provider_update_record), + ] + + +def _process_batch(update_records: list[dict]) -> None: + """ + Process a batch of update records. + + :param update_records: List of update records to process + """ + transaction_items = [] + + for update_record in update_records: + try: + transaction_items.extend(_generate_transaction_items(update_record)) + except Exception as e: # noqa: BLE001 + logger.error( + 'Error preparing update items for update record, skipping.', + exc_info=e, + pk=update_record.get('pk'), + sk=update_record.get('sk'), + ) + + # Execute the transaction + if transaction_items: + logger.info(f'Executing transaction with {len(transaction_items)} items') + config.provider_table.meta.client.transact_write_items(TransactItems=transaction_items) + logger.info('Transaction completed successfully') + else: + logger.warning('No valid transaction items to process in this batch') diff --git a/backend/compact-connect/lambdas/python/migration/tests/__init__.py b/backend/compact-connect/lambdas/python/migration/tests/__init__.py index aa3942fac..6a10f0d67 100644 --- a/backend/compact-connect/lambdas/python/migration/tests/__init__.py +++ b/backend/compact-connect/lambdas/python/migration/tests/__init__.py @@ -1,3 +1,4 @@ +import json import os from unittest import TestCase from unittest.mock import MagicMock @@ -11,11 +12,89 
@@ def setUpClass(cls): os.environ.update( { # Set to 'true' to enable debug logging - 'DEBUG': 'false', + 'DEBUG': 'true', + 'ALLOWED_ORIGINS': '["https://example.org"]', 'AWS_DEFAULT_REGION': 'us-east-1', - 'COMPACTS': '["aslp", "octp", "coun"]', - 'JURISDICTIONS': '["oh", "ky", "ne"]', + 'EVENT_BUS_NAME': 'license-data-events', + 'PROVIDER_TABLE_NAME': 'provider-table', + 'RATE_LIMITING_TABLE_NAME': 'rate-limiting-table', + 'SSN_TABLE_NAME': 'ssn-table', 'COMPACT_CONFIGURATION_TABLE_NAME': 'compact-configuration-table', + 'ENVIRONMENT_NAME': 'test', + 'PROV_FAM_GIV_MID_INDEX_NAME': 'providerFamGivMid', + 'FAM_GIV_INDEX_NAME': 'famGiv', + 'LICENSE_GSI_NAME': 'licenseGSI', + 'PROV_DATE_OF_UPDATE_INDEX_NAME': 'providerDateOfUpdate', + 'SSN_INDEX_NAME': 'ssnIndex', + 'COMPACTS': '["aslp", "octp", "coun"]', + 'JURISDICTIONS': json.dumps( + [ + 'al', + 'ak', + 'az', + 'ar', + 'ca', + 'co', + 'ct', + 'de', + 'dc', + 'fl', + 'ga', + 'hi', + 'id', + 'il', + 'in', + 'ia', + 'ks', + 'ky', + 'la', + 'me', + 'md', + 'ma', + 'mi', + 'mn', + 'ms', + 'mo', + 'mt', + 'ne', + 'nv', + 'nh', + 'nj', + 'nm', + 'ny', + 'nc', + 'nd', + 'oh', + 'ok', + 'or', + 'pa', + 'pr', + 'ri', + 'sc', + 'sd', + 'tn', + 'tx', + 'ut', + 'vt', + 'va', + 'vi', + 'wa', + 'wv', + 'wi', + 'wy', + ] + ), + 'LICENSE_TYPES': json.dumps( + { + 'aslp': [ + {'name': 'audiologist', 'abbreviation': 'aud'}, + {'name': 'speech-language pathologist', 'abbreviation': 'slp'}, + ], + 'coun': [ + {'name': 'licensed professional counselor', 'abbreviation': 'lpc'}, + ], + }, + ), }, ) # Monkey-patch config object to be sure we have it based diff --git a/backend/compact-connect/lambdas/python/migration/tests/function/__init__.py b/backend/compact-connect/lambdas/python/migration/tests/function/__init__.py new file mode 100644 index 000000000..bebb625a7 --- /dev/null +++ b/backend/compact-connect/lambdas/python/migration/tests/function/__init__.py @@ -0,0 +1,95 @@ +import logging +import os + +import boto3 +from moto 
import mock_aws + +from tests import TstLambdas + +logger = logging.getLogger(__name__) +logging.basicConfig() +logger.setLevel(logging.DEBUG if os.environ.get('DEBUG', 'false') == 'true' else logging.INFO) + + +@mock_aws +class TstFunction(TstLambdas): + """Base class to set up Moto mocking and create mock AWS resources for functional testing""" + + def setUp(self): # noqa: N802 invalid-name + super().setUp() + self.build_resources() + + # these must be imported within the tests, since they import modules which require + # environment variables that are not set until the TstLambdas class is initialized + import cc_common.config + from common_test.test_data_generator import TestDataGenerator + + cc_common.config.config = cc_common.config._Config() # noqa: SLF001 protected-access + self.config = cc_common.config.config + self.test_data_generator = TestDataGenerator + + self.addCleanup(self.delete_resources) + + def build_resources(self): + # the sort key migration operates solely on the provider table, so that is the + # only resource these functional tests need to create + self.create_provider_table() + + def create_provider_table(self): + self._provider_table = boto3.resource('dynamodb').create_table( + AttributeDefinitions=[ + {'AttributeName': 'pk', 'AttributeType': 'S'}, + {'AttributeName': 'sk', 'AttributeType': 'S'}, + {'AttributeName': 'providerFamGivMid', 'AttributeType': 'S'}, + {'AttributeName': 'providerDateOfUpdate', 'AttributeType': 'S'}, + {'AttributeName': 'licenseGSIPK', 'AttributeType': 'S'}, + {'AttributeName': 'licenseGSISK', 'AttributeType': 'S'}, + {'AttributeName': 'licenseUploadDateGSIPK', 'AttributeType': 'S'}, + {'AttributeName': 'licenseUploadDateGSISK', 'AttributeType': 'S'}, + ], + TableName=os.environ['PROVIDER_TABLE_NAME'], + KeySchema=[{'AttributeName': 'pk', 'KeyType': 'HASH'}, {'AttributeName': 'sk', 'KeyType': 'RANGE'}], + BillingMode='PAY_PER_REQUEST', + 
GlobalSecondaryIndexes=[ + { + 'IndexName': os.environ['PROV_FAM_GIV_MID_INDEX_NAME'], + 'KeySchema': [ + {'AttributeName': 'sk', 'KeyType': 'HASH'}, + {'AttributeName': 'providerFamGivMid', 'KeyType': 'RANGE'}, + ], + 'Projection': {'ProjectionType': 'ALL'}, + }, + { + 'IndexName': os.environ['PROV_DATE_OF_UPDATE_INDEX_NAME'], + 'KeySchema': [ + {'AttributeName': 'sk', 'KeyType': 'HASH'}, + {'AttributeName': 'providerDateOfUpdate', 'KeyType': 'RANGE'}, + ], + 'Projection': {'ProjectionType': 'ALL'}, + }, + { + 'IndexName': os.environ['LICENSE_GSI_NAME'], + 'KeySchema': [ + {'AttributeName': 'licenseGSIPK', 'KeyType': 'HASH'}, + {'AttributeName': 'licenseGSISK', 'KeyType': 'RANGE'}, + ], + 'Projection': {'ProjectionType': 'ALL'}, + }, + { + 'IndexName': 'licenseUploadDateGSI', + 'KeySchema': [ + {'AttributeName': 'licenseUploadDateGSIPK', 'KeyType': 'HASH'}, + {'AttributeName': 'licenseUploadDateGSISK', 'KeyType': 'RANGE'}, + ], + 'Projection': { + 'ProjectionType': 'INCLUDE', + 'NonKeyAttributes': [ + 'providerId', + ], + }, + }, + ], + ) + + def delete_resources(self): + self._provider_table.delete() diff --git a/backend/compact-connect/lambdas/python/migration/tests/function/test_migrate_update_sort_keys.py b/backend/compact-connect/lambdas/python/migration/tests/function/test_migrate_update_sort_keys.py new file mode 100644 index 000000000..dff07ff85 --- /dev/null +++ b/backend/compact-connect/lambdas/python/migration/tests/function/test_migrate_update_sort_keys.py @@ -0,0 +1,142 @@ +from datetime import datetime +from unittest.mock import patch + +from common_test.test_constants import ( + DEFAULT_LICENSE_JURISDICTION, + DEFAULT_LICENSE_UPDATE_CREATE_DATE, + DEFAULT_LICENSE_UPDATE_DATETIME, + DEFAULT_PRIVILEGE_JURISDICTION, + DEFAULT_PRIVILEGE_UPDATE_DATETIME, + DEFAULT_PROVIDER_UPDATE_DATETIME, +) +from moto import mock_aws + +from . 
import TstFunction + +MOCK_DATETIME_STRING = '2025-10-23T08:15:00+00:00' +MOCK_COMPACT = 'coun' +MOCK_PROVIDER_ID = '01d67765-76dd-47c8-b39a-8389445bb3b7' + + +@mock_aws +@patch('cc_common.config._Config.current_standard_datetime', datetime.fromisoformat(MOCK_DATETIME_STRING)) +class TestMigrateUpdateSortKeys(TstFunction): + """Test class for migrating update record sort keys.""" + + def test_should_migrate_provider_update_records_to_expected_pattern(self): + from migrate_update_sort_keys.main import do_migration + + old_provider_update_record = self.test_data_generator.generate_default_provider_update( + value_overrides={'compact': MOCK_COMPACT, 'providerId': MOCK_PROVIDER_ID} + ) + serialized_old_record = old_provider_update_record.serialize_to_database_record() + # replace sk with old pattern to simulate old record to be migrated + serialized_old_record['sk'] = 'coun#PROVIDER#UPDATE#1752526787/2f429ccda22d273b1ee4876f2917e27f' + del serialized_old_record['createDate'] + serialized_old_record['dateOfUpdate'] = DEFAULT_PROVIDER_UPDATE_DATETIME + self.config.provider_table.put_item(Item=serialized_old_record) + + # run migration + do_migration({}) + + # verify old record was deleted + old_record_resp = self.config.provider_table.get_item( + Key={'pk': serialized_old_record['pk'], 'sk': serialized_old_record['sk']} + ) + self.assertIsNone(old_record_resp.get('Item')) + + # verify new record was created with expected sk + expected_sk = ( + f'{MOCK_COMPACT}#UPDATE#2#provider/{DEFAULT_PROVIDER_UPDATE_DATETIME}/2f429ccda22d273b1ee4876f2917e27f' + ) + new_record = self.config.provider_table.get_item(Key={'pk': serialized_old_record['pk'], 'sk': expected_sk})[ + 'Item' + ] + + serialized_old_record['sk'] = expected_sk + # as part of migration, the createDate field will be populated with whatever the dateOfUpdate was + # so we expect that here + serialized_old_record['createDate'] = DEFAULT_PROVIDER_UPDATE_DATETIME + # only the sort key and the createDate should have been 
modified + self.assertEqual(serialized_old_record, new_record) + + def test_should_migrate_license_update_records_to_expected_pattern(self): + from migrate_update_sort_keys.main import do_migration + + old_license_update_record = self.test_data_generator.generate_default_license_update( + value_overrides={ + 'compact': MOCK_COMPACT, + 'providerId': MOCK_PROVIDER_ID, + 'licenseType': 'licensed professional counselor', + } + ) + serialized_old_record = old_license_update_record.serialize_to_database_record() + # replace sk with old pattern to simulate old record to be migrated + serialized_old_record['sk'] = ( + f'{MOCK_COMPACT}#PROVIDER#license/{DEFAULT_LICENSE_JURISDICTION}/lpc#UPDATE#1752526787/21554583eb71ccc5f8aa5988c8a50ac2' + ) + serialized_old_record['dateOfUpdate'] = DEFAULT_LICENSE_UPDATE_DATETIME + self.config.provider_table.put_item(Item=serialized_old_record) + + # run migration + do_migration({}) + + # verify old record was deleted + old_record_resp = self.config.provider_table.get_item( + Key={'pk': serialized_old_record['pk'], 'sk': serialized_old_record['sk']} + ) + self.assertIsNone(old_record_resp.get('Item')) + + # verify new record was created with expected sk + expected_sk = ( + f'{MOCK_COMPACT}#UPDATE#3#license/{DEFAULT_LICENSE_JURISDICTION}/lpc' + f'/{DEFAULT_LICENSE_UPDATE_CREATE_DATE}/21554583eb71ccc5f8aa5988c8a50ac2' + ) + new_record = self.config.provider_table.get_item(Key={'pk': serialized_old_record['pk'], 'sk': expected_sk})[ + 'Item' + ] + serialized_old_record['sk'] = expected_sk + # nothing on the record should have changed other than the sort key + self.assertEqual(serialized_old_record, new_record) + + def test_should_migrate_privilege_update_records_to_expected_pattern(self): + from migrate_update_sort_keys.main import do_migration + + mock_create_date = '2025-07-07T07:07:07+00:00' + + old_privilege_update_record = self.test_data_generator.generate_default_privilege_update( + value_overrides={ + 'compact': MOCK_COMPACT, + 
'providerId': MOCK_PROVIDER_ID, + 'licenseType': 'licensed professional counselor', + 'createDate': datetime.fromisoformat(mock_create_date), + } + ) + serialized_old_record = old_privilege_update_record.serialize_to_database_record() + # replace sk with old pattern to simulate old record to be migrated + serialized_old_record['sk'] = ( + f'{MOCK_COMPACT}#PROVIDER#privilege/{DEFAULT_PRIVILEGE_JURISDICTION}/lpc#UPDATE#1752526787/399abde0989ad5e936920a3ba9f0944a' + ) + serialized_old_record['dateOfUpdate'] = DEFAULT_PRIVILEGE_UPDATE_DATETIME + self.config.provider_table.put_item(Item=serialized_old_record) + + # run migration + do_migration({}) + + # verify old record was deleted + old_record_resp = self.config.provider_table.get_item( + Key={'pk': serialized_old_record['pk'], 'sk': serialized_old_record['sk']} + ) + self.assertIsNone(old_record_resp.get('Item')) + + # verify new record was created with expected sk + expected_sk = ( + f'{MOCK_COMPACT}#UPDATE#1#privilege/{DEFAULT_PRIVILEGE_JURISDICTION}/lpc' + f'/{mock_create_date}/399abde0989ad5e936920a3ba9f0944a' + ) + new_record = self.config.provider_table.get_item(Key={'pk': serialized_old_record['pk'], 'sk': expected_sk})[ + 'Item' + ] + serialized_old_record['sk'] = expected_sk + # nothing on the record should have changed other than the sort key + self.assertEqual(serialized_old_record, new_record) diff --git a/backend/compact-connect/lambdas/python/provider-data-v1/handlers/__init__.py b/backend/compact-connect/lambdas/python/provider-data-v1/handlers/__init__.py index 7b3bff40a..8939ee726 100644 --- a/backend/compact-connect/lambdas/python/provider-data-v1/handlers/__init__.py +++ b/backend/compact-connect/lambdas/python/provider-data-v1/handlers/__init__.py @@ -1,4 +1,5 @@ from cc_common.config import config, logger +from cc_common.data_model.update_tier_enum import UpdateTierEnum from cc_common.utils import logger_inject_kwargs @@ -13,5 +14,8 @@ def get_provider_information(compact: str, provider_id: str) 
-> dict: :param provider_id: The provider's unique identifier. :return: Provider profile information. """ - provider_user_records = config.data_client.get_provider_user_records(compact=compact, provider_id=provider_id) + # Collect all main provider records and privilege update records, which are included in tier one. + provider_user_records = config.data_client.get_provider_user_records( + compact=compact, provider_id=provider_id, include_update_tier=UpdateTierEnum.TIER_ONE + ) return provider_user_records.generate_api_response_object() diff --git a/backend/compact-connect/lambdas/python/provider-data-v1/handlers/bulk_upload.py b/backend/compact-connect/lambdas/python/provider-data-v1/handlers/bulk_upload.py index 7c518bac3..81284a585 100644 --- a/backend/compact-connect/lambdas/python/provider-data-v1/handlers/bulk_upload.py +++ b/backend/compact-connect/lambdas/python/provider-data-v1/handlers/bulk_upload.py @@ -148,6 +148,7 @@ def process_bulk_upload_file( failed_validation_count = 0 # track which ssns were included in this file to detect duplicates, # which are not allowed within the same file upload + # We track by (ssn, licenseType) tuple to allow same SSN for different license types ssns_in_file_upload = {} with EventBatchWriter(config.events_client) as event_writer: @@ -158,17 +159,18 @@ def process_bulk_upload_file( # dict() here, because it prevents `compact` and `jurisdiction` from being allowed in the # raw_license validated_license = schema.load(dict(compact=compact, jurisdiction=jurisdiction, **raw_license)) - # verify that this ssn has not been used previously in the same batch - license_ssn = validated_license['ssn'] + # verify that this ssn/licenseType combination has not been used previously in the same batch + ssn_key = (validated_license['ssn'], validated_license['licenseType']) if duplicate_ssn_check_flag_enabled: - matched_ssn_index = ssns_in_file_upload.get(license_ssn) + matched_ssn_index = ssns_in_file_upload.get(ssn_key) if 
matched_ssn_index: raise ValidationError( - message=f'Duplicate License SSN detected. SSN matches with record ' - f'{matched_ssn_index}. Every record must have a unique SSN within the same ' - f'file.' + message=f'Duplicate License SSN detected for license type ' + f'{validated_license["licenseType"]}. SSN matches with record ' + f'{matched_ssn_index}. Every record must have a unique SSN per license type ' + f'within the same file.' ) - ssns_in_file_upload.update({license_ssn: i + 1}) + ssns_in_file_upload.update({ssn_key: i + 1}) except TypeError as e: # This will be raised, if `raw_license` includes compact and/or jurisdiction fields logger.error('License contains unsupported fields', fields=list(raw_license.keys()), exc_info=e) diff --git a/backend/compact-connect/lambdas/python/provider-data-v1/handlers/ingest.py b/backend/compact-connect/lambdas/python/provider-data-v1/handlers/ingest.py index e84ace0b0..0293dd22b 100644 --- a/backend/compact-connect/lambdas/python/provider-data-v1/handlers/ingest.py +++ b/backend/compact-connect/lambdas/python/provider-data-v1/handlers/ingest.py @@ -1,10 +1,12 @@ import json +from copy import deepcopy from boto3.dynamodb.types import TypeSerializer from cc_common.config import config, logger from cc_common.data_model.provider_record_util import ProviderRecordType, ProviderRecordUtility from cc_common.data_model.schema import LicenseRecordSchema from cc_common.data_model.schema.common import ActiveInactiveStatus, UpdateCategory +from cc_common.data_model.schema.license import LicenseData from cc_common.data_model.schema.license.ingest import LicenseIngestSchema from cc_common.data_model.schema.license.record import LicenseUpdateRecordSchema from cc_common.data_model.schema.provider import ProviderData @@ -113,15 +115,7 @@ def ingest_license_message(message: dict): # We fully JSON serialize then load again so that we have a completely independent copy of the data posted_license_record = 
license_record_schema.load(json.loads(dumped_license)) - dynamo_transactions = [ - # Put the posted license - { - 'Put': { - 'TableName': config.provider_table_name, - 'Item': TypeSerializer().serialize(json.loads(dumped_license))['M'], - }, - }, - ] + dynamo_transactions = [] home_jurisdiction = None try: @@ -169,6 +163,26 @@ def ingest_license_message(message: dict): dynamo_transactions=dynamo_transactions, data_events=data_events, ) + # now grab the firstUploadDate from the existing record if available and put it in the posted_license + # for the license upload date GSI + if existing_license.get('firstUploadDate'): + posted_license_record['firstUploadDate'] = existing_license.get('firstUploadDate') + else: + # If this is the first time creating the license record, + # set the firstUploadDate to the current time for license upload date GSI tracking + posted_license_record['firstUploadDate'] = config.current_standard_datetime + + # write the record to the table to reflect the latest values from the upload + license_data = LicenseData.create_new(deepcopy(posted_license_record)) + dynamo_transactions.append( + { + 'Put': { + 'TableName': config.provider_table_name, + 'Item': TypeSerializer().serialize(license_data.serialize_to_database_record())['M'], + } + } + ) + licenses_organized.setdefault(posted_license_record['jurisdiction'], {}) licenses_organized[posted_license_record['jurisdiction']][posted_license_record['licenseType']] = ( posted_license_record @@ -219,7 +233,8 @@ def _process_license_update(*, existing_license: dict, new_license: dict, dynamo :param list dynamo_transactions: The dynamodb transaction array to append records to """ # Remove fields that are calculated at runtime, not stored in the database - dynamic_keys = {'dateOfUpdate', 'status'} + # uploadDate is metadata tracking when the license was first uploaded, not part of the license data + dynamic_keys = {'dateOfUpdate', 'status', 'uploadDate'} updated_values = { key: value for key, value in 
new_license.items() @@ -312,6 +327,7 @@ def _populate_update_record(*, existing_license: dict, updated_values: dict, rem 'licenseType': existing_license['licenseType'], 'createDate': now, 'effectiveDate': now, + 'uploadDate': now, # Track when this update was created during upload 'previous': existing_license, 'updatedValues': updated_values, # We'll only include the removed values field if there are some diff --git a/backend/compact-connect/lambdas/python/provider-data-v1/handlers/licenses.py b/backend/compact-connect/lambdas/python/provider-data-v1/handlers/licenses.py index 405a22b5e..18ef1af65 100644 --- a/backend/compact-connect/lambdas/python/provider-data-v1/handlers/licenses.py +++ b/backend/compact-connect/lambdas/python/provider-data-v1/handlers/licenses.py @@ -72,15 +72,15 @@ def post_licenses(event: dict, context: LambdaContext): # noqa: ARG001 unused-a } ) if duplicate_ssn_check_flag_enabled: - # verify that none of the SSNs are repeats within the same batch - license_ssns = [license_record['ssn'] for license_record in licenses] - if len(set(license_ssns)) < len(license_ssns): + # verify that none of the SSN+LicenseType combinations are repeats within the same batch + license_keys = [(license_record['ssn'], license_record['licenseType']) for license_record in licenses] + if len(set(license_keys)) < len(license_keys): raise CCInvalidRequestCustomResponseException( response_body={ 'message': 'Invalid license records in request. See errors for more detail.', 'errors': { - 'SSN': 'Same SSN detected on multiple rows. ' - 'Every record must have a unique SSN within the same request.' + 'SSN': 'Same SSN for the same license type detected on multiple rows. ' + 'Every record must have a unique SSN per license type within the same request.' 
}, } ) diff --git a/backend/compact-connect/lambdas/python/provider-data-v1/handlers/state_api.py b/backend/compact-connect/lambdas/python/provider-data-v1/handlers/state_api.py index 5d9defd9e..8a6995d4f 100644 --- a/backend/compact-connect/lambdas/python/provider-data-v1/handlers/state_api.py +++ b/backend/compact-connect/lambdas/python/provider-data-v1/handlers/state_api.py @@ -12,6 +12,7 @@ StateProviderDetailGeneralResponseSchema, StateProviderDetailPrivateResponseSchema, ) +from cc_common.data_model.update_tier_enum import UpdateTierEnum from cc_common.exceptions import CCInternalException, CCInvalidRequestException, CCNotFoundException from cc_common.signature_auth import optional_signature_auth, required_signature_auth from cc_common.utils import ( @@ -147,7 +148,10 @@ def get_provider(event: dict, context: LambdaContext): # noqa: ARG001 unused-ar raise CCInvalidRequestException('Missing required field') from e with logger.append_context_keys(compact=compact, provider_id=provider_id, jurisdiction=jurisdiction): - provider_user_records = config.data_client.get_provider_user_records(compact=compact, provider_id=provider_id) + # Collect all main provider records and privilege update records, which are included in tier one. 
+ provider_user_records = config.data_client.get_provider_user_records( + compact=compact, provider_id=provider_id, include_update_tier=UpdateTierEnum.TIER_ONE + ) # Get caller's scopes to determine private data access scopes = get_event_scopes(event) diff --git a/backend/compact-connect/lambdas/python/provider-data-v1/tests/function/__init__.py b/backend/compact-connect/lambdas/python/provider-data-v1/tests/function/__init__.py index fe0ac38f7..ba2061beb 100644 --- a/backend/compact-connect/lambdas/python/provider-data-v1/tests/function/__init__.py +++ b/backend/compact-connect/lambdas/python/provider-data-v1/tests/function/__init__.py @@ -129,6 +129,8 @@ def create_provider_table(self): {'AttributeName': 'providerDateOfUpdate', 'AttributeType': 'S'}, {'AttributeName': 'licenseGSIPK', 'AttributeType': 'S'}, {'AttributeName': 'licenseGSISK', 'AttributeType': 'S'}, + {'AttributeName': 'licenseUploadDateGSIPK', 'AttributeType': 'S'}, + {'AttributeName': 'licenseUploadDateGSISK', 'AttributeType': 'S'}, ], TableName=os.environ['PROVIDER_TABLE_NAME'], KeySchema=[{'AttributeName': 'pk', 'KeyType': 'HASH'}, {'AttributeName': 'sk', 'KeyType': 'RANGE'}], @@ -158,6 +160,17 @@ def create_provider_table(self): ], 'Projection': {'ProjectionType': 'ALL'}, }, + { + 'IndexName': 'licenseUploadDateGSI', + 'KeySchema': [ + {'AttributeName': 'licenseUploadDateGSIPK', 'KeyType': 'HASH'}, + {'AttributeName': 'licenseUploadDateGSISK', 'KeyType': 'RANGE'}, + ], + 'Projection': { + 'ProjectionType': 'INCLUDE', + 'NonKeyAttributes': ['providerId'], + }, + }, ], ) diff --git a/backend/compact-connect/lambdas/python/provider-data-v1/tests/function/test_data_model/test_provider_transformations.py b/backend/compact-connect/lambdas/python/provider-data-v1/tests/function/test_data_model/test_provider_transformations.py index 9567a88cc..d8a1d2577 100644 --- a/backend/compact-connect/lambdas/python/provider-data-v1/tests/function/test_data_model/test_provider_transformations.py +++ 
b/backend/compact-connect/lambdas/python/provider-data-v1/tests/function/test_data_model/test_provider_transformations.py @@ -2,16 +2,18 @@ from datetime import date, datetime from unittest.mock import patch -from boto3.dynamodb.conditions import Key +from cc_common.data_model.update_tier_enum import UpdateTierEnum from moto import mock_aws from .. import TstFunction +MOCK_CURRENT_DATETIME_STRING = '2024-11-08T23:59:59+00:00' + @mock_aws class TestTransformations(TstFunction): # Yes, this is an excessively long method. We're going with it for sake of a single illustrative test. - @patch('cc_common.config._Config.current_standard_datetime', datetime.fromisoformat('2024-11-08T23:59:59+00:00')) + @patch('cc_common.config._Config.current_standard_datetime', datetime.fromisoformat(MOCK_CURRENT_DATETIME_STRING)) @patch('cc_common.config._Config.license_preprocessing_queue') def test_transformations(self, mock_license_preprocessing_queue): """Provider data undergoes several transformations from when a license is first posted, stored into the @@ -98,6 +100,9 @@ def test_transformations(self, mock_license_preprocessing_queue): expected_provider = json.load(f) # this should be set during the registration flow expected_provider['currentHomeJurisdiction'] = 'oh' + # provider should be active and compact eligible + expected_provider['licenseStatus'] = 'active' + expected_provider['compactEligibility'] = 'eligible' # register the provider in the system client.process_registration_values( @@ -133,24 +138,33 @@ def test_transformations(self, mock_license_preprocessing_queue): ], ) - # Get the provider straight from the table, to inspect them - resp = self._provider_table.query( - Select='ALL_ATTRIBUTES', - KeyConditionExpression=Key('pk').eq(f'aslp#PROVIDER#{provider_id}') - & Key('sk').begins_with('aslp#PROVIDER'), + # Get the provider and all update records straight from the table, to inspect them + provider_user_records: ProviderUserRecords = 
self.config.data_client.get_provider_user_records( + compact='aslp', provider_id=provider_id, include_update_tier=UpdateTierEnum.TIER_THREE ) + # One record for each of: provider, providerUpdate, license, # privilege, and militaryAffiliation - self.assertEqual(5, len(resp['Items'])) - records = {item['type']: item for item in resp['Items']} + self.assertEqual(5, len(provider_user_records.provider_records)) + records = {item['type']: item for item in provider_user_records.provider_records} # Convert this to the data type expected from DynamoDB expected_provider['privilegeJurisdictions'] = set(expected_provider['privilegeJurisdictions']) with open('../common/tests/resources/dynamo/license.json') as f: expected_license = json.load(f) + # license should be active and compact eligible + expected_license['licenseStatus'] = 'active' + expected_license['compactEligibility'] = 'eligible' + expected_license['firstUploadDate'] = MOCK_CURRENT_DATETIME_STRING + expected_license['licenseUploadDateGSIPK'] = 'C#aslp#J#oh#D#2024-11' + expected_license['licenseUploadDateGSISK'] = ( + 'TIME#1731110399#LT#slp#PID#89a6377e-c3a5-40e5-bca5-317ec854c570' + ) with open('../common/tests/resources/dynamo/privilege.json') as f: expected_privilege = json.load(f) + # privilege status should be active + expected_privilege['status'] = 'active' with open('../common/tests/resources/dynamo/military-affiliation.json') as f: expected_military_affiliation = json.load(f) # in this case, the status will be initializing, since it is not set to active until diff --git a/backend/compact-connect/lambdas/python/provider-data-v1/tests/function/test_handlers/test_bulk_upload.py b/backend/compact-connect/lambdas/python/provider-data-v1/tests/function/test_handlers/test_bulk_upload.py index 865f912a4..4103dcf15 100644 --- a/backend/compact-connect/lambdas/python/provider-data-v1/tests/function/test_handlers/test_bulk_upload.py +++ 
b/backend/compact-connect/lambdas/python/provider-data-v1/tests/function/test_handlers/test_bulk_upload.py @@ -343,8 +343,8 @@ def test_bulk_upload_prevents_repeated_ssns_within_the_same_file_upload(self): 'dateOfExpiration': '2026-01-01', }, 'errors': [ - 'Duplicate License SSN detected. SSN matches with record 1. ' - 'Every record must have a unique SSN within the same file.' + 'Duplicate License SSN detected for license type audiologist. SSN matches with record 1. ' + 'Every record must have a unique SSN per license type within the same file.' ], } ), @@ -353,6 +353,56 @@ def test_bulk_upload_prevents_repeated_ssns_within_the_same_file_upload(self): self.assertEqual(expected_entry, call_args) + def test_bulk_upload_allows_repeated_ssns_for_different_license_types(self): + """Test that duplicate SSNs within a CSV upload are allowed if the license types are different.""" + from handlers.bulk_upload import parse_bulk_upload_file + + # Create CSV content that includes duplicate SSNs but different license types + csv_content = ( + 'ssn,npi,licenseNumber,givenName,middleName,familyName,suffix,dateOfBirth,dateOfIssuance' + ',dateOfRenewal,dateOfExpiration,licenseStatus,compactEligibility,homeAddressStreet1' + ',homeAddressStreet2,homeAddressCity,homeAddressState,homeAddressPostalCode' + ',emailAddress,phoneNumber,licenseType,licenseStatusName\n' + '123-45-6789,1234567890,LICENSE123,John,Middle,Doe,Jr.,1990-01-01,2020-01-01,2021-01-01,2023-01-01,active,' + 'eligible,123 Main St,Apt 1,Columbus,OH,43215,test@example.com,+15551234567,audiologist,Active\n' + '123-45-6789,1234567890,LICENSE456,John,Middle,Doe,Jr.,1990-01-01,2023-01-01,2025-01-01,2026-01-01,active,' + 'eligible,123 Main St,Apt 1,Columbus,OH,43215,test@example.com,+15551234567,speech-language pathologist,' + 'Active' + ) + + # Upload the CSV content directly to the mock S3 bucket + object_key = f'aslp/oh/{uuid4().hex}' + self._bucket.put_object(Key=object_key, Body=csv_content) + + # Simulate the s3 bucket 
event + with open('../common/tests/resources/put-event.json') as f: + event = json.load(f) + + event['Records'][0]['s3']['bucket'] = { + 'name': self._bucket.name, + 'arn': f'arn:aws:s3:::{self._bucket.name}', + 'ownerIdentity': {'principalId': 'ASDFG123'}, + } + event['Records'][0]['s3']['object']['key'] = object_key + + parse_bulk_upload_file(event, self.mock_context) + + # Verify that both messages were sent to the preprocessing queue + messages = self._license_preprocessing_queue.receive_messages(MaxNumberOfMessages=10) + self.assertEqual(2, len(messages)) + + message_data_1 = json.loads(messages[0].body) + message_data_2 = json.loads(messages[1].body) + + # Verify the license types are correct + # Messages might not be in order, so we check both + license_types = {message_data_1['licenseType'], message_data_2['licenseType']} + self.assertEqual({'audiologist', 'speech-language pathologist'}, license_types) + + # Verify SSNs are the same + self.assertEqual(message_data_1['ssn'], '123-45-6789') + self.assertEqual(message_data_2['ssn'], '123-45-6789') + def test_bulk_upload_handles_bom_character(self): """Test that CSV files with BOM characters are handled correctly.""" from handlers.bulk_upload import parse_bulk_upload_file diff --git a/backend/compact-connect/lambdas/python/provider-data-v1/tests/function/test_handlers/test_encumbrance.py b/backend/compact-connect/lambdas/python/provider-data-v1/tests/function/test_handlers/test_encumbrance.py index ed11c68ea..42fdd2fba 100644 --- a/backend/compact-connect/lambdas/python/provider-data-v1/tests/function/test_handlers/test_encumbrance.py +++ b/backend/compact-connect/lambdas/python/provider-data-v1/tests/function/test_handlers/test_encumbrance.py @@ -139,7 +139,6 @@ def test_privilege_encumbrance_handler_adds_adverse_action_record_in_provider_da @patch('cc_common.feature_flag_client.is_feature_enabled', return_value=True) def 
test_privilege_encumbrance_handler_adds_privilege_update_record_in_provider_data_table(self, mock_flag): # noqa: ARG002 - from cc_common.data_model.schema.privilege import PrivilegeUpdateData from handlers.encumbrance import encumbrance_handler event, test_privilege_record = self._when_testing_privilege_encumbrance() @@ -148,18 +147,13 @@ def test_privilege_encumbrance_handler_adds_privilege_update_record_in_provider_ self.assertEqual(200, response['statusCode'], msg=json.loads(response['body'])) # Verify that the encumbrance record was added to the provider data table - # Perform a query to list all encumbrances for the provider using the starts_with key condition - privilege_update_records = self._provider_table.query( - Select='ALL_ATTRIBUTES', - KeyConditionExpression=Key('pk').eq(test_privilege_record.serialize_to_database_record()['pk']) - & Key('sk').begins_with( - f'{test_privilege_record.compact}#PROVIDER#privilege/{test_privilege_record.jurisdiction}/slp#UPDATE' - ), + privilege_update_records = ( + self.test_data_generator.query_privilege_update_records_for_given_record_from_database( + test_privilege_record + ) ) - self.assertEqual(1, len(privilege_update_records['Items'])) - item = privilege_update_records['Items'][0] - - loaded_privilege_update_data = PrivilegeUpdateData.from_database_record(item) + self.assertEqual(1, len(privilege_update_records)) + loaded_privilege_update_data = privilege_update_records[0] expected_privilege_update_data = self.test_data_generator.generate_default_privilege_update( value_overrides={ @@ -186,7 +180,6 @@ def test_privilege_encumbrance_handler_adds_privilege_update_record_in_provider_ self, mock_flag, # noqa: ARG002 ): - from cc_common.data_model.schema.privilege import PrivilegeUpdateData from handlers.encumbrance import encumbrance_handler event, test_privilege_record = self._when_testing_privilege_encumbrance() @@ -195,18 +188,13 @@ def test_privilege_encumbrance_handler_adds_privilege_update_record_in_provider_ 
self.assertEqual(200, response['statusCode'], msg=json.loads(response['body'])) # Verify that the encumbrance record was added to the provider data table - # Perform a query to list all encumbrances for the provider using the starts_with key condition - privilege_update_records = self._provider_table.query( - Select='ALL_ATTRIBUTES', - KeyConditionExpression=Key('pk').eq(test_privilege_record.serialize_to_database_record()['pk']) - & Key('sk').begins_with( - f'{test_privilege_record.compact}#PROVIDER#privilege/{test_privilege_record.jurisdiction}/slp#UPDATE' - ), + privilege_update_records = ( + self.test_data_generator.query_privilege_update_records_for_given_record_from_database( + test_privilege_record + ) ) - self.assertEqual(1, len(privilege_update_records['Items'])) - item = privilege_update_records['Items'][0] - - loaded_privilege_update_data = PrivilegeUpdateData.from_database_record(item) + self.assertEqual(1, len(privilege_update_records)) + loaded_privilege_update_data = privilege_update_records[0] expected_privilege_update_data = self.test_data_generator.generate_default_privilege_update( value_overrides={ @@ -515,7 +503,6 @@ def test_license_encumbrance_handler_adds_adverse_action_record_in_provider_data ) def test_license_encumbrance_handler_adds_license_update_record_in_provider_data_table(self): - from cc_common.data_model.schema.license import LicenseUpdateData from handlers.encumbrance import encumbrance_handler event, test_license_record = self._when_testing_valid_license_encumbrance() @@ -524,15 +511,11 @@ def test_license_encumbrance_handler_adds_license_update_record_in_provider_data self.assertEqual(200, response['statusCode'], msg=json.loads(response['body'])) # Verify that the update record was added for the license - license_update_records = self._provider_table.query( - Select='ALL_ATTRIBUTES', - KeyConditionExpression=Key('pk').eq(test_license_record.serialize_to_database_record()['pk']) - & Key('sk').begins_with( - 
f'{test_license_record.compact}#PROVIDER#license/{test_license_record.jurisdiction}/slp#UPDATE' - ), + license_update_records = self.test_data_generator.query_license_update_records_for_given_record_from_database( + test_license_record ) - self.assertEqual(1, len(license_update_records['Items'])) - item = license_update_records['Items'][0] + self.assertEqual(1, len(license_update_records)) + loaded_license_update_data = license_update_records[0] expected_license_update_data = self.test_data_generator.generate_default_license_update( value_overrides={ @@ -542,7 +525,6 @@ def test_license_encumbrance_handler_adds_license_update_record_in_provider_data 'effectiveDate': datetime.fromisoformat(TEST_ENCUMBRANCE_EFFECTIVE_DATETIME), } ) - loaded_license_update_data = LicenseUpdateData.from_database_record(item) self.assertEqual( expected_license_update_data.to_dict(), diff --git a/backend/compact-connect/lambdas/python/provider-data-v1/tests/function/test_handlers/test_investigation.py b/backend/compact-connect/lambdas/python/provider-data-v1/tests/function/test_handlers/test_investigation.py index 347361ac0..c89a412d1 100644 --- a/backend/compact-connect/lambdas/python/provider-data-v1/tests/function/test_handlers/test_investigation.py +++ b/backend/compact-connect/lambdas/python/provider-data-v1/tests/function/test_handlers/test_investigation.py @@ -3,6 +3,7 @@ from unittest.mock import patch from uuid import UUID +from cc_common.data_model.update_tier_enum import UpdateTierEnum from common_test.test_constants import ( DEFAULT_AA_SUBMITTING_USER_ID, DEFAULT_DATE_OF_UPDATE_TIMESTAMP, @@ -955,9 +956,11 @@ def test_closing_one_of_multiple_investigations_maintains_investigation_status(s provider_user_records = self.config.data_client.get_provider_user_records( compact=test_privilege_record.compact, provider_id=test_privilege_record.providerId, + include_update_tier=UpdateTierEnum.TIER_THREE, ) update_records = provider_user_records.get_update_records_for_privilege( - 
jurisdiction=test_privilege_record.jurisdiction, license_type=test_privilege_record.licenseType + jurisdiction=test_privilege_record.jurisdiction, + license_type=test_privilege_record.licenseType, ) investigation_update_records = [ @@ -1002,6 +1005,7 @@ def test_closing_one_of_multiple_investigations_maintains_investigation_status(s provider_user_records = self.config.data_client.get_provider_user_records( compact=test_privilege_record.compact, provider_id=test_privilege_record.providerId, + include_update_tier=UpdateTierEnum.TIER_THREE, ) updated_privilege_record = provider_user_records.get_privilege_records()[0] @@ -1020,6 +1024,7 @@ def test_closing_one_of_multiple_investigations_maintains_investigation_status(s provider_user_records = self.config.data_client.get_provider_user_records( compact=test_privilege_record.compact, provider_id=test_privilege_record.providerId, + include_update_tier=UpdateTierEnum.TIER_THREE, ) update_records = provider_user_records.get_update_records_for_privilege( jurisdiction=test_privilege_record.jurisdiction, license_type=test_privilege_record.licenseType @@ -1088,6 +1093,7 @@ def test_closing_one_of_multiple_investigations_maintains_investigation_status(s provider_user_records = self.config.data_client.get_provider_user_records( compact=test_license_record.compact, provider_id=test_license_record.providerId, + include_update_tier=UpdateTierEnum.TIER_THREE, ) investigation_records = provider_user_records.get_investigation_records_for_license( license_jurisdiction=test_license_record.jurisdiction, @@ -1119,6 +1125,7 @@ def test_closing_one_of_multiple_investigations_maintains_investigation_status(s provider_user_records = self.config.data_client.get_provider_user_records( compact=test_license_record.compact, provider_id=test_license_record.providerId, + include_update_tier=UpdateTierEnum.TIER_THREE, ) investigation_records = provider_user_records.get_investigation_records_for_license( 
license_jurisdiction=test_license_record.jurisdiction, @@ -1154,6 +1161,7 @@ def test_closing_one_of_multiple_investigations_maintains_investigation_status(s provider_user_records = self.config.data_client.get_provider_user_records( compact=test_license_record.compact, provider_id=test_license_record.providerId, + include_update_tier=UpdateTierEnum.TIER_THREE, ) updated_license_record = provider_user_records.get_license_records()[0] @@ -1188,6 +1196,7 @@ def test_closing_one_of_multiple_investigations_maintains_investigation_status(s provider_user_records = self.config.data_client.get_provider_user_records( compact=test_license_record.compact, provider_id=test_license_record.providerId, + include_update_tier=UpdateTierEnum.TIER_THREE, ) update_records = provider_user_records.get_update_records_for_license( jurisdiction=test_license_record.jurisdiction, license_type=test_license_record.licenseType @@ -1235,6 +1244,7 @@ def test_closing_one_of_multiple_investigations_maintains_investigation_status(s provider_user_records = self.config.data_client.get_provider_user_records( compact=test_license_record.compact, provider_id=test_license_record.providerId, + include_update_tier=UpdateTierEnum.TIER_THREE, ) updated_license_record = provider_user_records.get_license_records()[0] @@ -1253,6 +1263,7 @@ def test_closing_one_of_multiple_investigations_maintains_investigation_status(s provider_user_records = self.config.data_client.get_provider_user_records( compact=test_license_record.compact, provider_id=test_license_record.providerId, + include_update_tier=UpdateTierEnum.TIER_THREE, ) update_records = provider_user_records.get_update_records_for_license( jurisdiction=test_license_record.jurisdiction, license_type=test_license_record.licenseType diff --git a/backend/compact-connect/lambdas/python/provider-data-v1/tests/function/test_handlers/test_licenses.py b/backend/compact-connect/lambdas/python/provider-data-v1/tests/function/test_handlers/test_licenses.py index 
f338eb0e6..e6ce09d7d 100644 --- a/backend/compact-connect/lambdas/python/provider-data-v1/tests/function/test_handlers/test_licenses.py +++ b/backend/compact-connect/lambdas/python/provider-data-v1/tests/function/test_handlers/test_licenses.py @@ -392,13 +392,44 @@ def test_post_licenses_returns_400_if_repeated_ssns_detected(self): { 'message': 'Invalid license records in request. See errors for more detail.', 'errors': { - 'SSN': 'Same SSN detected on multiple rows. ' - 'Every record must have a unique SSN within the same request.', + 'SSN': 'Same SSN for the same license type detected on multiple rows. ' + 'Every record must have a unique SSN per license type within the same request.', }, }, json.loads(resp['body']), ) + def test_post_licenses_succeeds_with_same_ssn_different_license_types(self): + from handlers.licenses import post_licenses + + with open('../common/tests/resources/api-event.json') as f: + event = json.load(f) + + # The user has write permission for aslp/oh + event['requestContext']['authorizer']['claims']['scope'] = 'openid email aslp/readGeneral oh/aslp.write' + event['pathParameters'] = {'compact': 'aslp', 'jurisdiction': 'oh'} + + with open('../common/tests/resources/api/license-post.json') as f: + license_data_1 = json.load(f) + + # Create second license with same SSN but different license type + license_data_2 = license_data_1.copy() + license_data_1['licenseType'] = 'audiologist' + license_data_2['licenseType'] = 'speech-language pathologist' + + event['body'] = json.dumps([license_data_1, license_data_2]) + + # Add signature authentication headers + event = self._create_signed_event(event) + + resp = post_licenses(event, self.mock_context) + + self.assertEqual(200, resp['statusCode']) + + # assert that the messages were sent to the preprocessing queue + queue_messages = self._license_preprocessing_queue.receive_messages(MaxNumberOfMessages=10) + self.assertEqual(2, len(queue_messages)) + def 
test_post_licenses_strips_whitespace_from_string_fields(self): """Test that whitespace is stripped from all string fields in license data.""" from handlers.licenses import post_licenses diff --git a/backend/compact-connect/lambdas/python/provider-data-v1/tests/function/test_handlers/test_registration.py b/backend/compact-connect/lambdas/python/provider-data-v1/tests/function/test_handlers/test_registration.py index 506983a8e..f0e4ed6ae 100644 --- a/backend/compact-connect/lambdas/python/provider-data-v1/tests/function/test_handlers/test_registration.py +++ b/backend/compact-connect/lambdas/python/provider-data-v1/tests/function/test_handlers/test_registration.py @@ -822,6 +822,7 @@ def test_registration_creates_provider_update_record(self, mock_verify_recaptcha self.assertEqual( { 'compact': provider_data.compact, + 'createDate': update_data.createDate, 'dateOfUpdate': datetime.fromisoformat(DEFAULT_DATE_OF_UPDATE_TIMESTAMP), 'previous': { 'compact': provider_data.compact, diff --git a/backend/compact-connect/pipeline/backend_stage.py b/backend/compact-connect/pipeline/backend_stage.py index 318dfc887..6a83de8a9 100644 --- a/backend/compact-connect/pipeline/backend_stage.py +++ b/backend/compact-connect/pipeline/backend_stage.py @@ -4,6 +4,7 @@ from stacks.api_lambda_stack import ApiLambdaStack from stacks.api_stack import ApiStack +from stacks.data_migration_stack import DataMigrationStack from stacks.disaster_recovery_stack import DisasterRecoveryStack from stacks.event_listener_stack import EventListenerStack from stacks.event_state_stack import EventStateStack @@ -203,3 +204,19 @@ def __init__( environment_context=environment_context, standard_tags=standard_tags, ) + + # Stack to house data migration custom resources + # This stack depends on the API and event listener stacks to ensure + # all core infrastructure is in place before migrations run + self.data_migration_stack = DataMigrationStack( + self, + 'DataMigrationStack', + env=environment, + 
class DataMigrationStack(AppStack):
    """
    Stack to house data migration custom resources that run scripts to perform data migrations.

    This stack should be deployed after other infrastructure stacks are in place, so
    migrations only run once the tables and keys they touch already exist.
    """

    def __init__(
        self,
        scope: Construct,
        construct_id: str,
        *,
        environment_name: str,
        environment_context: dict,
        persistent_stack: ps.PersistentStack,
        **kwargs,
    ):
        super().__init__(
            scope, construct_id, environment_context=environment_context, environment_name=environment_name, **kwargs
        )

        # Custom resource that migrates provider-table update record sort keys.
        update_sort_keys_migration = DataMigration(
            self,
            'MigrateUpdateSortKeys',
            migration_dir='migrate_update_sort_keys',
            lambda_environment={
                'PROVIDER_TABLE_NAME': persistent_stack.provider_table.table_name,
                **self.common_env_vars,
            },
        )
        # The migration lambda reads and rewrites provider records, so it needs both
        # table access and the shared key used to encrypt the table.
        persistent_stack.shared_encryption_key.grant_encrypt_decrypt(update_sort_keys_migration)
        persistent_stack.provider_table.grant_read_write_data(update_sort_keys_migration)
        NagSuppressions.add_resource_suppressions_by_path(
            self,
            f'{update_sort_keys_migration.migration_function.role.node.path}/DefaultPolicy/Resource',
            suppressions=[
                {
                    'id': 'AwsSolutions-IAM5',
                    # NOTE: trailing space added before 'migration.' — the original implicit string
                    # concatenation rendered as 'perform themigration.'
                    'reason': 'This policy contains wild-carded actions and resources but they are scoped to the '
                    'specific actions, Table and Key that this lambda needs access to in order to perform the '
                    'migration.',
                },
            ],
        )
stacks.disaster_recovery_stack.license_upload_rollback_step_function import ( + LicenseUploadRollbackStepFunctionConstruct, +) from stacks.disaster_recovery_stack.restore_dynamo_db_table_step_function import ( RestoreDynamoDbTableStepFunctionConstruct, ) @@ -56,6 +61,35 @@ def __init__( ) ) + # Create S3 bucket for license upload rollback results + stack = Stack.of(self) + self.disaster_recovery_results_bucket = Bucket( + self, + 'DisasterRecoveryResultsBucket', + encryption=BucketEncryption.KMS, + encryption_key=self.dr_shared_encryption_key, + removal_policy=removal_policy, + auto_delete_objects=removal_policy == RemovalPolicy.DESTROY, + versioned=True, + enforce_ssl=True, + block_public_access=BlockPublicAccess.BLOCK_ALL, + object_ownership=ObjectOwnership.BUCKET_OWNER_ENFORCED, + server_access_logs_bucket=persistent_stack.access_logs_bucket, + server_access_logs_prefix=f'_logs/{stack.account}/{stack.region}/{self.node.path}/DisasterRecoveryResultsBucket/', + ) + + # Suppress replication requirement - replication to a logs archive account may be added as a future enhancement + NagSuppressions.add_resource_suppressions( + self.disaster_recovery_results_bucket, + suppressions=[ + { + 'id': 'HIPAA.Security-S3BucketReplicationEnabled', + 'reason': 'This bucket is for generating one time' + ' results of the rollback workflow and is not intended to be replicated.', + }, + ], + ) + # Create Step Functions for restoring DynamoDB tables self.dr_workflows = {} @@ -77,6 +111,15 @@ def __init__( ssn_table=persistent_stack.ssn_table ) + # Create License Upload Rollback workflow + self.license_upload_rollback_workflow = LicenseUploadRollbackStepFunctionConstruct( + self, + 'LicenseUploadRollback', + persistent_stack=persistent_stack, + rollback_results_bucket=self.disaster_recovery_results_bucket, + dr_shared_encryption_key=self.dr_shared_encryption_key, + ) + def _create_dynamodb_table_dr_recovery_workflow(self, table: Table, shared_persistent_stack_key: Key): """Create the 
class LicenseUploadRollbackStepFunctionConstruct(Construct):
    """
    Step Function construct for rolling back invalid license uploads.

    This construct creates a Lambda function to process the rollback and a Step Function
    state machine to orchestrate the process with pagination support: the Lambda reports
    a `rollbackStatus` of IN_PROGRESS or COMPLETE and the state machine loops until done.
    """

    def __init__(
        self,
        scope: Construct,
        construct_id: str,
        *,
        persistent_stack: ps.PersistentStack,
        rollback_results_bucket: Bucket,
        dr_shared_encryption_key: Key,
        **kwargs,
    ):
        super().__init__(scope, construct_id, **kwargs)

        stack = Stack.of(self)
        # We explicitly get the event bus arn from parameter store, to avoid issues with cross stack updates
        data_event_bus = SSMParameterUtility.load_data_event_bus_from_ssm_parameter(self)

        # Create Lambda function for rollback processing
        self._create_rollback_function(
            stack=stack,
            persistent_stack=persistent_stack,
            rollback_results_bucket=rollback_results_bucket,
            data_event_bus=data_event_bus,
        )

        # Build Step Function definition
        definition = self._build_rollback_state_machine_definition()

        # Create log group for state machine
        state_machine_log_group = LogGroup(
            self,
            'LicenseUploadRollbackStateMachineLogs',
            # this state machine will hopefully not be run often, so we will not automatically clear these logs
            retention=RetentionDays.INFINITE,
            encryption_key=dr_shared_encryption_key,
        )

        # Suppress retention period requirement - we are deliberately retaining logs indefinitely
        NagSuppressions.add_resource_suppressions(
            state_machine_log_group,
            suppressions=[
                {
                    'id': 'HIPAA.Security-CloudWatchLogGroupRetentionPeriod',
                    'reason': 'This system will be used infrequently.'
                    ' We are deliberately retaining logs indefinitely here.',
                },
            ],
        )

        # Create state machine
        self.state_machine = StateMachine(
            self,
            'LicenseUploadRollbackStateMachine',
            definition_body=DefinitionBody.from_chainable(definition),
            timeout=Duration.hours(8),  # Long timeout for processing many providers
            logs=LogOptions(
                destination=state_machine_log_group,
                level=LogLevel.ALL,
                include_execution_data=True,
            ),
            tracing_enabled=True,
        )

        # Grant state machine permission to invoke the Lambda
        self.rollback_function.grant_invoke(self.state_machine)

        NagSuppressions.add_resource_suppressions_by_path(
            stack=stack,
            path=f'{self.state_machine.node.path}/Role/DefaultPolicy/Resource',
            suppressions=[
                {
                    'id': 'AwsSolutions-IAM5',
                    'reason': """
                    This policy contains wild-carded actions and resources but they are scoped to the specific
                    Lambda function that this state machine needs access to.
                    """,
                },
            ],
        )

    def _create_rollback_function(
        self,
        stack: Stack,
        persistent_stack: ps.PersistentStack,
        rollback_results_bucket: Bucket,
        data_event_bus: EventBus,
    ):
        """Create the Lambda function for processing license upload rollback."""
        self.rollback_function = PythonFunction(
            self,
            'LicenseUploadRollbackFunction',
            description='Rollback invalid license uploads for a compact/jurisdiction/time window',
            lambda_dir='disaster-recovery',
            index=os.path.join('handlers', 'rollback_license_upload.py'),
            handler='rollback_license_upload',
            timeout=Duration.minutes(15),
            memory_size=3008,  # for managing potentially large results files
            environment={
                **stack.common_env_vars,
                'PROVIDER_TABLE_NAME': persistent_stack.provider_table.table_name,
                'DISASTER_RECOVERY_RESULTS_BUCKET_NAME': rollback_results_bucket.bucket_name,
                'LICENSE_UPLOAD_DATE_INDEX_NAME': persistent_stack.provider_table.license_upload_date_gsi_name,
                'EVENT_BUS_NAME': data_event_bus.event_bus_name,
            },
        )

        # Grant permissions to read/write provider table
        persistent_stack.shared_encryption_key.grant_decrypt(self.rollback_function)
        persistent_stack.provider_table.grant_read_write_data(self.rollback_function)

        # Grant S3 permissions for results bucket
        rollback_results_bucket.grant_read_write(self.rollback_function)

        # Grant EventBridge permissions to publish events
        data_event_bus.grant_put_events_to(self.rollback_function)

        NagSuppressions.add_resource_suppressions_by_path(
            stack=stack,
            path=f'{self.rollback_function.role.node.path}/DefaultPolicy/Resource',
            suppressions=[
                {
                    'id': 'AwsSolutions-IAM5',
                    'reason': """
                    This policy contains wild-carded actions and resources but they are scoped to the
                    specific table, S3 bucket, and event bus that this lambda needs access to.
                    """,
                },
            ],
        )

    def _build_rollback_state_machine_definition(self) -> IChainable:
        """
        Build the Step Function definition for license upload rollback.

        Flow:
        1. Initialize - Set up execution parameters including the execution name
           (fix: previous docs referenced an 'executionId' field that is never set)
        2. RollbackLicenses (Lambda) - Process providers and rollback
        3. CheckStatus - Check if complete or needs continuation
           - IN_PROGRESS: Loop back to RollbackLicenses
           - COMPLETE: Success
           - default: Fail
        """

        # Initialize state - prepare input and record the state machine execution name,
        # which the Lambda uses to key its paginated progress/results.
        initialize_rollback = Pass(
            self,
            'InitializeRollback',
            parameters={
                'compact.$': '$.compact',
                'jurisdiction.$': '$.jurisdiction',
                'startDateTime.$': '$.startDateTime',
                'endDateTime.$': '$.endDateTime',
                'rollbackReason.$': '$.rollbackReason',
                'executionName.$': '$$.Execution.Name',
                'providersProcessed': 0,
            },
            comment='Initialize rollback parameters with execution ID',
            result_path='$',
        )

        # Rollback licenses Lambda task
        rollback_licenses_task = LambdaInvoke(
            self,
            'RollbackLicenses',
            lambda_function=self.rollback_function,
            comment='Process license upload rollback for affected providers',
            payload_response_only=True,
            result_path='$',
            retry_on_service_exceptions=True,
        )

        # Check rollback status
        rollback_status_choice = Choice(
            self,
            'CheckRollbackStatus',
            comment='Check if rollback is complete or needs continuation',
        )

        # Rollback failed state
        rollback_failed = Fail(
            self,
            'RollbackFailed',
            comment='Rollback operation failed',
            cause='Rollback operation encountered an error',
            error='RollbackError',
        )

        # Success state
        rollback_complete = Succeed(
            self,
            'RollbackComplete',
            comment='License upload rollback completed successfully',
        )

        # Define flow logic
        initialize_rollback.next(rollback_licenses_task)
        rollback_licenses_task.next(rollback_status_choice)

        # Rollback status flow
        rollback_status_choice.when(
            Condition.string_equals('$.rollbackStatus', 'COMPLETE'),
            rollback_complete,
        ).when(
            Condition.string_equals('$.rollbackStatus', 'IN_PROGRESS'),
            rollback_licenses_task,  # Loop back to continue processing
        ).otherwise(rollback_failed)

        # Start with initialization
        return initialize_rollback
a/backend/compact-connect/stacks/persistent_stack/provider_table.py b/backend/compact-connect/stacks/persistent_stack/provider_table.py index 5761f76f0..9ca0f49e4 100644 --- a/backend/compact-connect/stacks/persistent_stack/provider_table.py +++ b/backend/compact-connect/stacks/persistent_stack/provider_table.py @@ -48,6 +48,7 @@ def __init__( self.provider_date_of_update_index_name = 'providerDateOfUpdate' self.license_gsi_name = 'licenseGSI' self.compact_transaction_gsi_name = 'compactTransactionIdGSI' + self.license_upload_date_gsi_name = 'licenseUploadDateGSI' self.add_global_secondary_index( index_name=self.provider_fam_giv_mid_index_name, @@ -83,6 +84,17 @@ def __init__( 'providerId', ], ) + # in this case, we only need to include the provider id since this GSI is used to + # determine which providers were associated with a particular license upload time + self.add_global_secondary_index( + index_name=self.license_upload_date_gsi_name, + partition_key=Attribute(name='licenseUploadDateGSIPK', type=AttributeType.STRING), + sort_key=Attribute(name='licenseUploadDateGSISK', type=AttributeType.STRING), + projection_type=ProjectionType.INCLUDE, + non_key_attributes=[ + 'providerId', + ], + ) # Set up backup plan backup_enabled = environment_context['backup_enabled'] if backup_enabled and backup_infrastructure_stack is not None: diff --git a/backend/compact-connect/stacks/provider_users/provider_users.py b/backend/compact-connect/stacks/provider_users/provider_users.py index 449ac9a8b..1c8fd04e1 100644 --- a/backend/compact-connect/stacks/provider_users/provider_users.py +++ b/backend/compact-connect/stacks/provider_users/provider_users.py @@ -69,9 +69,7 @@ def __init__( if persistent_stack.hosted_zone: self.add_custom_app_client_domain( - app_client_domain_prefix='Licensee', - scope=self, - hosted_zone=persistent_stack.hosted_zone + app_client_domain_prefix='Licensee', scope=self, hosted_zone=persistent_stack.hosted_zone ) else: provider_prefix = 
def upload_test_license_batch(
    auth_headers: dict, batch_start_index: int, batch_size: int, street_address: str = '123 Test Street'
):
    """
    Upload a batch of test license records.

    :param auth_headers: Authentication headers for app client
    :param batch_start_index: Starting index for this batch
    :param batch_size: Number of licenses to upload in this batch
    :param street_address: Street address to use
    :return: List of license records that were uploaded
    """

    def _build_license(idx: int) -> dict:
        # Each record gets a unique license number, given name, and SSN derived from its index.
        # The family name stays constant so all providers can later be found with one
        # exact-match query on 'RollbackTest'.
        return {
            'licenseNumber': f'ROLLBACK-TEST-{idx:04d}',
            'homeAddressPostalCode': '68001',
            'givenName': f'TestProvider{idx:04d}',
            'familyName': 'RollbackTest',
            'homeAddressStreet1': street_address,
            'dateOfBirth': '1985-01-01',
            'dateOfIssuance': '2020-01-01',
            'ssn': f'999-50-{idx:04d}',
            'licenseType': LICENSE_TYPE,
            'dateOfExpiration': '2050-12-10',
            'homeAddressState': 'NE',
            'homeAddressCity': 'Omaha',
            'compactEligibility': 'eligible',
            'licenseStatus': 'active',
        }

    batch_records = [_build_license(idx) for idx in range(batch_start_index, batch_start_index + batch_size)]

    logger.info(
        f'Uploading batch of {len(batch_records)} licenses'
        f' (indices {batch_start_index}-{batch_start_index + batch_size - 1})'
    )

    response = requests.post(
        url=f'{config.state_api_base_url}/v1/compacts/{COMPACT}/jurisdictions/{JURISDICTION}/licenses',
        headers=auth_headers,
        json=batch_records,
        timeout=60,  # Longer timeout for batch uploads
    )

    if response.status_code != 200:
        raise SmokeTestFailureException(
            f'Failed to upload license batch {batch_start_index}. Response: {response.json()}'
        )

    logger.info(f'Successfully uploaded batch {batch_start_index}-{batch_start_index + batch_size - 1}')
    return batch_records
def wait_for_all_providers_created(staff_headers: dict, expected_count: int, max_wait_time: int = 120):
    """
    Wait for all provider records to be created from uploaded licenses.

    Repeatedly pages through the provider query API (exact match on the shared
    'RollbackTest' family name) until the expected number of distinct provider IDs
    has been observed or the timeout elapses.

    :param staff_headers: Authentication headers for staff user
    :param expected_count: Expected number of providers to be created
    :param max_wait_time: Maximum time to wait in seconds (default: 120)
    :return: List of all distinct provider IDs that were found
    :raises SmokeTestFailureException: If the expected count is not reached in time
    """
    logger.info(f'Waiting for {expected_count} provider records to be created...')

    start_time = time.time()
    check_interval = 5

    # Provider IDs are accumulated across attempts; a set keeps them distinct.
    all_provider_ids: set[str] = set()
    while time.time() - start_time < max_wait_time:
        # Fix: pagination state is reset for every attempt. Previously last_key/page_num
        # were initialized once outside this loop, so a failed or partial pass would
        # resume from a stale lastKey on the next attempt.
        last_key = None
        page_num = 1

        # Collect all providers across all pages
        while True:
            query_body = {
                'query': {'familyName': 'RollbackTest'},
                'pagination': {'pageSize': 100},
            }
            if last_key:
                query_body['pagination']['lastKey'] = last_key

            query_response = requests.post(
                url=f'{get_api_base_url()}/v1/compacts/{COMPACT}/providers/query',
                headers=staff_headers,
                json=query_body,
                timeout=30,
            )

            if query_response.status_code != 200:
                # Abandon this pass; the outer loop retries from a fresh first page.
                logger.warning(
                    f'Query failed with status {query_response.status_code}: {query_response.json()} Retrying...'
                )
                break

            response_data = query_response.json()
            providers = response_data.get('providers', [])
            pagination = response_data.get('pagination', {})

            # Collect provider IDs from this page and add to set
            page_provider_ids = [p['providerId'] for p in providers]
            all_provider_ids.update(page_provider_ids)

            logger.info(
                f'Page {page_num}: Found {len(page_provider_ids)} providers '
                f'(total: {len(all_provider_ids)}/{expected_count})'
            )

            # Check if there are more pages
            last_key = pagination.get('lastKey')
            if not last_key:
                # No more pages
                break

            page_num += 1

        num_found = len(all_provider_ids)
        logger.info(
            f'Found {num_found}/{expected_count} providers with family name "RollbackTest" (across {page_num} pages)'
        )

        if num_found >= expected_count:
            logger.info(f'All {expected_count} providers found!')
            # Return every distinct provider ID observed so far
            return list(all_provider_ids)

        elapsed = time.time() - start_time
        if elapsed < max_wait_time:
            logger.info(f'Waiting {check_interval}s for remaining providers... (elapsed: {elapsed:.1f}s)')
            time.sleep(check_interval)

    # Timeout reached without finding the expected number of providers
    raise SmokeTestFailureException(f'Timeout reached waiting for providers after {max_wait_time}s.')
def wait_for_step_function_completion(execution_arn: str, max_wait_time: int = 3600):
    """
    Poll the step function until it completes.

    :param execution_arn: ARN of the step function execution
    :param max_wait_time: Maximum time to wait in seconds (default: 3600 = 1 hour)
    :return: Final execution status and output
    :raises SmokeTestFailureException: On FAILED/TIMED_OUT/ABORTED status or timeout
    """
    sfn_client = boto3.client('stepfunctions')

    logger.info('Waiting for step function to complete...')
    started_at = time.time()
    poll_interval = 30

    while time.time() - started_at < max_wait_time:
        describe_response = sfn_client.describe_execution(executionArn=execution_arn)
        status = describe_response['status']
        logger.info(f'Step function status: {status}')

        if status == 'SUCCEEDED':
            output = json.loads(describe_response['output'])
            elapsed = time.time() - started_at
            logger.info(f'Step function completed successfully after {elapsed:.1f}s')
            return status, output
        if status in ('FAILED', 'TIMED_OUT', 'ABORTED'):
            # 'error'/'cause' are only present on failed executions, hence the defaults.
            raise SmokeTestFailureException(
                f'Step function execution failed with status: {status}. '
                f'Error: {describe_response.get("error", "N/A")}, Cause: {describe_response.get("cause", "N/A")}'
            )

        # Execution still running; wait before the next describe call.
        time.sleep(poll_interval)

    raise SmokeTestFailureException(f'Step function did not complete within {max_wait_time}s timeout')
+ + :param provider_id: The provider ID to create privilege for + :param compact: The compact abbreviation + """ + from datetime import date + + # Create a privilege record for a different jurisdiction (e.g., 'co' for Colorado) + privilege_jurisdiction = 'co' + license_type_abbr = 'lpc' + + privilege_record = { + 'pk': f'{compact}#PROVIDER#{provider_id}', + 'sk': f'{compact}#PROVIDER#privilege/{privilege_jurisdiction}/{license_type_abbr}#', + 'type': 'privilege', + 'providerId': provider_id, + 'compact': compact, + 'jurisdiction': privilege_jurisdiction, + 'licenseJurisdiction': JURISDICTION, + 'licenseType': LICENSE_TYPE, + 'dateOfIssuance': datetime.now(tz=UTC).isoformat(), + 'dateOfRenewal': datetime.now(tz=UTC).isoformat(), + 'dateOfExpiration': date(2050, 12, 10).isoformat(), + 'dateOfUpdate': datetime.now(tz=UTC).isoformat(), + 'privilegeId': f'{license_type_abbr.upper()}-{privilege_jurisdiction.upper()}-12345', + 'administratorSetStatus': 'active', + 'compactTransactionId': 'test-transaction-12345', + 'compactTransactionIdGSIPK': f'COMPACT#{compact}#TX#test-transaction-12345#', + 'attestations': [], + } + + config.provider_user_dynamodb_table.put_item(Item=privilege_record) + logger.info(f'Created privilege record for provider {provider_id}') + + +def create_encumbrance_update_for_provider(provider_id: str, compact: str, license_jurisdiction: str): + """ + Manually create a license encumbrance update record to test skip conditions. 
+ + :param provider_id: The provider ID + :param compact: The compact abbreviation + :param license_jurisdiction: The jurisdiction of the license + """ + + license_type_abbr = 'lpc' + # Use current time or specified time + now = datetime.now(tz=UTC) + + # First, query the actual license record to get the previous state + license_sk = f'{compact}#PROVIDER#license/{license_jurisdiction}/{license_type_abbr}#' + + try: + response = config.provider_user_dynamodb_table.get_item( + Key={'pk': f'{compact}#PROVIDER#{provider_id}', 'sk': license_sk} + ) + license_record_item = response.get('Item') + + if not license_record_item: + raise SmokeTestFailureException(f'License record not found for provider {provider_id}') + + # Load the license record using the schema to get properly typed data + license_record = LicenseData.from_database_record(license_record_item) + + except Exception as e: + logger.error(f'Failed to retrieve license record for provider {provider_id}: {str(e)}') + raise + + # Create a license encumbrance update record using LicenseUpdateData + # This ensures proper schema validation and field generation (including SK hash) + update_data = LicenseUpdateData.create_new( + { + 'type': 'licenseUpdate', + 'updateType': 'encumbrance', + 'providerId': provider_id, + 'compact': compact, + 'jurisdiction': license_jurisdiction, + 'licenseType': LICENSE_TYPE, + 'createDate': now, + 'effectiveDate': now, + 'previous': license_record.to_dict(), + 'updatedValues': { + 'encumberedStatus': 'encumbered', + }, + } + ) + + # Serialize to database record format + update_record = update_data.serialize_to_database_record() + + config.provider_user_dynamodb_table.put_item(Item=update_record) + logger.info(f'Created encumbrance update record for provider {provider_id} with createDate {now.isoformat()}') + + +def delete_all_provider_records(provider_ids: list[str], compact: str): + """ + Delete all records for the given provider IDs. 
+ + :param provider_ids: List of provider IDs to delete + :param compact: The compact abbreviation + """ + logger.info(f'Starting cleanup of {len(provider_ids)} provider records...') + + for i, provider_id in enumerate(provider_ids): + if i % 100 == 0: + logger.info(f'Cleaned up {i}/{len(provider_ids)} provider records') + + try: + # Query all records for this provider + response = config.provider_user_dynamodb_table.query( + KeyConditionExpression='pk = :pk', + ExpressionAttributeValues={':pk': f'{compact}#PROVIDER#{provider_id}'}, + ) + + # Delete all records in batches + with config.provider_user_dynamodb_table.batch_writer() as batch: + for item in response.get('Items', []): + batch.delete_item(Key={'pk': item['pk'], 'sk': item['sk']}) + except Exception as e: # noqa: BLE001 + logger.warning(f'Failed to delete records for provider {provider_id}: {str(e)}') + + logger.info(f'✅ Completed cleanup of {len(provider_ids)} provider records') + + +def verify_rollback_results(results: dict, expected_provider_count: int, expected_skipped_count: int = 0): + """ + Verify the rollback results match expected format and counts. 
+ + :param results: Rollback results from S3 + :param expected_provider_count: Expected number of providers rolled back (reverted) + :param expected_skipped_count: Expected number of providers that should be skipped + """ + logger.info('Verifying rollback results...') + + # Verify structure + required_keys = ['revertedProviderSummaries', 'skippedProviderDetails', 'failedProviderDetails'] + for key in required_keys: + if key not in results: + raise SmokeTestFailureException(f'Missing required key in results: {key}') + + # Check counts + reverted = results['revertedProviderSummaries'] + skipped = results['skippedProviderDetails'] + failed = results['failedProviderDetails'] + + num_reverted = len(reverted) + num_skipped = len(skipped) + num_failed = len(failed) + + logger.info('Rollback summary:') + logger.info(f' - Reverted: {num_reverted}') + logger.info(f' - Skipped: {num_skipped}') + logger.info(f' - Failed: {num_failed}') + + # Verify skipped count matches expectation + if num_skipped != expected_skipped_count: + logger.error(f'Found {num_skipped} skipped providers, expected {expected_skipped_count}:') + for detail in skipped[:5]: # Show first 5 + logger.error(f'Details for skipped provider: {detail["providerId"]}', skipped=detail) + raise SmokeTestFailureException(f'Expected {expected_skipped_count} skipped providers but found {num_skipped}') + + if num_failed > 0: + logger.error(f'Found {num_failed} failed providers:') + for detail in failed[:5]: # Show first 5 + logger.error(f'Details for failed provider: {detail["providerId"]}', failed=detail) + raise SmokeTestFailureException(f'Expected 0 failed providers but found {num_failed}') + + # Verify we got the expected number of reverted providers + if num_reverted != expected_provider_count: + logger.warning(f'Expected {expected_provider_count} reverted providers but found {num_reverted}') + + # Verify the reverted provider has the expected structure + for i, summary in enumerate(reverted): + if 'providerId' not 
in summary: + raise SmokeTestFailureException(f'Reverted provider summary {i} missing providerId') + if 'licensesReverted' not in summary: + raise SmokeTestFailureException(f'Reverted provider summary {i} missing licensesReverted') + if 'updatesDeleted' not in summary: + raise SmokeTestFailureException(f'Reverted provider summary {i} missing updatesDeleted') + + # Verify each license was deleted (not reverted to previous state) + licenses_reverted = summary['licensesReverted'] + if len(licenses_reverted) != 1: + raise SmokeTestFailureException( + f'Expected 1 license reverted for provider {summary["providerId"]}, found {len(licenses_reverted)}' + ) + + license_action = licenses_reverted[0]['action'] + if license_action != 'DELETE': + raise SmokeTestFailureException( + f'Expected license action "DELETE" but found "{license_action}" for provider {summary["providerId"]}' + ) + + # Verify that update records were deleted (should have at least 1 from the re-upload) + updates_deleted = summary['updatesDeleted'] + if len(updates_deleted) < 1: + raise SmokeTestFailureException( + f'Expected at least 1 update record deleted for provider {summary["providerId"]}, ' + f'found {len(updates_deleted)}' + ) + + logger.info('✅ Rollback results verification passed') + + +def verify_providers_deleted_from_database(results: dict, compact: str): + """ + Verify that all provider records were actually deleted from DynamoDB. 
+ + :param results: Rollback results containing provider IDs + :param compact: Compact abbreviation + """ + logger.info('Verifying providers were deleted from database...') + + reverted_summaries = results['revertedProviderSummaries'] + + for i, summary in enumerate(reverted_summaries): + if i % 100 == 0: + logger.info(f'Verified deletion for {i}/{len(reverted_summaries)} providers') + + provider_id = summary['providerId'] + + # Try to get provider records - should return empty or raise exception + provider_user_records = get_provider_user_records(compact, provider_id) + + # Check if any records exist + all_records = provider_user_records.provider_records + if all_records: + raise SmokeTestFailureException( + f'Provider {provider_id} still has {len(all_records)} records in database after rollback' + ) + + logger.info(f'✅ Verified {len(reverted_summaries)} providers were deleted from database') + + +def rollback_license_upload_smoke_test(): + """ + Main smoke test for license upload rollback functionality. + + Steps: + 1. Upload test license records (first time) + 2. Upload test license records again with different address (creates update records) + 3. Wait for all providers to be created AND verify license update records exist in DynamoDB + 4. Store all provider IDs for cleanup + 5. Create privilege for first provider (should be skipped) + 6. Create encumbrance update for second provider (should be skipped) + 7. Start rollback step function + 8. Wait for step function completion + 9. Retrieve and verify results from S3 + 10. Verify providers were deleted from database (except 2 skipped) + 11. 
Clean up remaining test records + """ + global ALL_PROVIDER_IDS + + # Get environment configuration + step_function_arn = config.license_upload_rollback_step_function_arn + + if not step_function_arn: + raise SmokeTestFailureException('CC_TEST_ROLLBACK_STEP_FUNCTION_ARN environment variable not set') + + # staff user to query providers + staff_headers = get_staff_user_auth_headers(TEST_STAFF_USER_EMAIL) + + # Create test app client for authentication + client_credentials = create_test_app_client(TEST_APP_CLIENT_NAME, COMPACT, JURISDICTION) + client_id = client_credentials['client_id'] + client_secret = client_credentials['client_secret'] + + skipped_provider_ids = [] + + try: + # Get authentication headers using app client + auth_headers = get_client_auth_headers(client_id, client_secret, COMPACT, JURISDICTION) + + # Step 1: Upload test licenses (first time) + logger.info('=' * 80) + logger.info('STEP 1: Uploading test licenses (first time)') + logger.info('=' * 80) + + first_upload_start_time = datetime.now(tz=UTC) + uploaded_licenses = upload_test_licenses( + auth_headers, + NUM_LICENSES_TO_UPLOAD, + BATCH_SIZE, + street_address='123 Test Street', + ) + first_upload_end_time = datetime.now(tz=UTC) + logger.info( + f'First upload time window: {first_upload_start_time.isoformat()} to {first_upload_end_time.isoformat()}' + ) + + # Wait for first upload's license records to be created before second upload + logger.info('=' * 80) + logger.info('Waiting for first upload providers and license records to be created...') + logger.info('=' * 80) + time.sleep(10) + wait_for_all_providers_created(staff_headers, len(uploaded_licenses)) + logger.info('✅ All first upload license records have been created') + + # Step 2: Upload test licenses again with different address to create update records + logger.info('=' * 80) + logger.info('STEP 2: Uploading test licenses again with different address (creates update records)') + logger.info('=' * 80) + + upload_test_licenses( + 
auth_headers, + NUM_LICENSES_TO_UPLOAD, + BATCH_SIZE, + street_address='456 Updated Street', + ) + + logger.info('Second upload completed - update records should be created') + + # Step 3: Wait for providers to be created and update records to propagate + logger.info('=' * 80) + logger.info('STEP 3: Waiting for provider records and update records to be created') + logger.info('=' * 80) + + provider_ids = wait_for_all_providers_created(staff_headers, len(uploaded_licenses)) + + # Store all provider IDs globally for cleanup + ALL_PROVIDER_IDS = provider_ids.copy() + + logger.info('Checking for license update records.') + verify_license_update_records_created(provider_ids) + # Capture end time after verifying update records exist + second_upload_end_time = datetime.now(tz=UTC) + + logger.info(f'Found {len(provider_ids)} provider records') + + # Step 4: Create privilege for first provider (should be skipped in rollback) + logger.info('=' * 80) + logger.info('STEP 4: Creating privilege for first provider to test skip condition') + logger.info('=' * 80) + + first_provider_id = provider_ids[0] + create_privilege_for_provider(first_provider_id, COMPACT) + skipped_provider_ids.append(first_provider_id) + logger.info(f'Created privilege for provider {first_provider_id} - should be skipped in rollback') + + # Step 5: Create encumbrance update for second provider (should be skipped in rollback) + logger.info('=' * 80) + logger.info('STEP 5: Creating encumbrance update for second provider to test skip condition') + logger.info('=' * 80) + + second_provider_id = provider_ids[1] + create_encumbrance_update_for_provider(second_provider_id, COMPACT, JURISDICTION) + skipped_provider_ids.append(second_provider_id) + logger.info(f'Created encumbrance update for provider {second_provider_id} - should be skipped in rollback') + + # Brief wait to ensure the manually created records are written + logger.info('Waiting briefly for test records to propagate...') + time.sleep(5) + + # Step 6: 
Start rollback step function + logger.info('=' * 80) + logger.info('STEP 6: Starting rollback step function') + logger.info('=' * 80) + + rollback_start = first_upload_start_time + # Add buffer to end time window to ensure we catch all uploads + rollback_end = second_upload_end_time + timedelta(minutes=5) + + execution_arn = start_rollback_step_function( + step_function_arn=step_function_arn, + compact=COMPACT, + jurisdiction=JURISDICTION, + start_datetime=rollback_start, + end_datetime=rollback_end, + ) + + # Step 7: Wait for step function completion + logger.info('=' * 80) + logger.info('STEP 7: Waiting for step function to complete') + logger.info('=' * 80) + + status, output = wait_for_step_function_completion(execution_arn) + + logger.info(f'Step function output: {json.dumps(output, indent=2)}') + + # Step 8: Retrieve and verify results from S3 + logger.info('=' * 80) + logger.info('STEP 8: Retrieving and verifying results from S3') + logger.info('=' * 80) + + results_s3_key = output.get('resultsS3Key') + if not results_s3_key: + raise SmokeTestFailureException('No resultsS3Key in step function output') + + results = get_rollback_results_from_s3(results_s3_key) + + # Expect all providers reverted except for the 2 skipped + expected_reverted = NUM_LICENSES_TO_UPLOAD - 2 + expected_skipped = 2 + verify_rollback_results(results, expected_reverted, expected_skipped) + + # Step 9: Verify providers deleted from database (except the 2 skipped ones) + logger.info('=' * 80) + logger.info('STEP 9: Verifying providers were deleted from database') + logger.info('=' * 80) + + verify_providers_deleted_from_database(results, COMPACT) + + # Step 10: Clean up the 2 skipped provider records + logger.info('=' * 80) + logger.info('STEP 10: Cleaning up skipped provider records') + logger.info('=' * 80) + + delete_all_provider_records(skipped_provider_ids, COMPACT) + + logger.info('=' * 80) + logger.info('✅ ALL TESTS PASSED') + logger.info('=' * 80) + except Exception as e: + 
logger.error(f'Test failed: {str(e)}') + # If test failed, we need to clean up all provider records + if ALL_PROVIDER_IDS: + logger.info('=' * 80) + logger.info('CLEANUP: Test failed, cleaning up all provider records') + logger.info('=' * 80) + delete_all_provider_records(ALL_PROVIDER_IDS, COMPACT) + raise + finally: + # Clean up the test app client + delete_test_app_client(client_id) + + +if __name__ == '__main__': + load_smoke_test_env() + + # Create staff user with permission to upload licenses and run rollback + test_user_sub = create_test_staff_user( + email=TEST_STAFF_USER_EMAIL, + compact=COMPACT, + jurisdiction=JURISDICTION, + permissions={'actions': {'admin'}, 'jurisdictions': {JURISDICTION: {'write', 'admin'}}}, + ) + + try: + rollback_license_upload_smoke_test() + logger.info('🎉 License upload rollback smoke test completed successfully!') + except SmokeTestFailureException as e: + logger.error(f'❌ License upload rollback smoke test failed: {str(e)}') + raise + except Exception as e: + logger.error(f'❌ Unexpected error during smoke test: {str(e)}', exc_info=True) + raise + finally: + # Clean up the test staff user + delete_test_staff_user(TEST_STAFF_USER_EMAIL, user_sub=test_user_sub, compact=COMPACT) diff --git a/backend/compact-connect/tests/smoke/smoke_common.py b/backend/compact-connect/tests/smoke/smoke_common.py index 9caded8c4..aa9b5cc09 100644 --- a/backend/compact-connect/tests/smoke/smoke_common.py +++ b/backend/compact-connect/tests/smoke/smoke_common.py @@ -33,7 +33,10 @@ def __init__(self, message): os.environ['LICENSE_TYPES'] = json.dumps(LICENSE_TYPES) # We have to import this after we've added the common lib to our path and environment -from cc_common.data_model.provider_record_util import ProviderUserRecords # noqa: E402 +from cc_common.data_model.provider_record_util import ProviderUserRecords # noqa: E402 F401 + +# importing this here so it can be easily referenced in the rollback upload tests +from cc_common.data_model.schema.license 
import LicenseData, LicenseUpdateData # noqa: E402 F401 from cc_common.data_model.schema.user.record import UserRecordSchema # noqa: E402 _TEST_STAFF_USER_PASSWORD = 'TestPass123!' # noqa: S105 test credential for test staff user @@ -257,11 +260,25 @@ def get_provider_user_records(compact: str, provider_id: str) -> ProviderUserRec :return: ProviderUserRecords instance containing all records for this provider """ # Query the provider database for all records - query_result = config.provider_user_dynamodb_table.query( - KeyConditionExpression=Key('pk').eq(f'{compact}#PROVIDER#{provider_id}') - ) + resp = {'Items': []} + last_evaluated_key = None + while True: + pagination = {'ExclusiveStartKey': last_evaluated_key} if last_evaluated_key else {} + # Grab all records under the provider partition + query_resp = config.provider_user_dynamodb_table.query( + Select='ALL_ATTRIBUTES', + KeyConditionExpression=Key('pk').eq(f'{compact}#PROVIDER#{provider_id}'), + ConsistentRead=True, + **pagination, + ) + + resp['Items'].extend(query_resp.get('Items', [])) - return ProviderUserRecords(query_result['Items']) + last_evaluated_key = query_resp.get('LastEvaluatedKey') + if not last_evaluated_key: + break + + return ProviderUserRecords(resp['Items']) def upload_license_record(staff_headers: dict, compact: str, jurisdiction: str, data_overrides: dict = None): @@ -470,6 +487,119 @@ def cleanup_test_provider_records(provider_id: str, compact: str): logger.warning(f'Error during cleanup: {str(e)}') +def create_test_app_client(client_name: str, compact: str, jurisdiction: str): + """ + Create a test app client in Cognito for authentication testing. 
+ + :param client_name: Name for the test app client + :param compact: Compact abbreviation + :param jurisdiction: Jurisdiction abbreviation + :return: Dictionary containing client_id and client_secret + """ + logger.info(f'Creating test app client: {client_name}') + + try: + cognito_client = boto3.client('cognito-idp') + + # Create the user pool client + response = cognito_client.create_user_pool_client( + UserPoolId=config.cognito_state_auth_user_pool_id, + ClientName=client_name, + PreventUserExistenceErrors='ENABLED', + GenerateSecret=True, + TokenValidityUnits={'AccessToken': 'minutes'}, + AccessTokenValidity=15, + AllowedOAuthFlowsUserPoolClient=True, + AllowedOAuthFlows=['client_credentials'], + AllowedOAuthScopes=[f'{compact}/readGeneral', f'{jurisdiction}/{compact}.write'], + ) + + user_pool_client = response.get('UserPoolClient', {}) + client_id = user_pool_client.get('ClientId') + client_secret = user_pool_client.get('ClientSecret') + + if not client_id or not client_secret: + raise SmokeTestFailureException('Failed to extract client ID or secret from AWS response') + + logger.info(f'Successfully created test app client with ID: {client_id}') + return {'client_id': client_id, 'client_secret': client_secret} + + except ClientError as e: + error_code = e.response['Error']['Code'] + error_message = e.response['Error']['Message'] + logger.error(f'Failed to create app client: {error_code} - {error_message}') + raise SmokeTestFailureException(f'Failed to create app client: {error_code} - {error_message}') from e + + +def delete_test_app_client(client_id: str): + """Delete the test app client from Cognito.""" + try: + cognito_client = boto3.client('cognito-idp') + cognito_client.delete_user_pool_client(UserPoolId=config.cognito_state_auth_user_pool_id, ClientId=client_id) + logger.info(f'Successfully deleted test app client: {client_id}') + except ClientError as e: + logger.error(f'Failed to delete app client {client_id}: {str(e)}') + # Don't raise here as this 
is cleanup + + +def get_client_credentials_token(client_id: str, client_secret: str, compact: str, jurisdiction: str): + """ + Get an access token using client credentials flow. + + :param client_id: The client ID + :param client_secret: The client secret + :param compact: Compact abbreviation + :param jurisdiction: Jurisdiction abbreviation + :return: Access token + """ + try: + auth_url = config.state_auth_url + + # Prepare the request data for client credentials flow + data = { + 'grant_type': 'client_credentials', + 'client_id': client_id, + 'client_secret': client_secret, + 'scope': f'{compact}/readGeneral {jurisdiction}/{compact}.write', + } + + headers = {'Content-Type': 'application/x-www-form-urlencoded', 'Accept': 'application/json'} + + response = requests.post(auth_url, data=data, headers=headers, timeout=10) + + if response.status_code != 200: + raise SmokeTestFailureException( + f'Failed to get access token. Status: {response.status_code}, Response: {response.text}' + ) + + token_data = response.json() + access_token = token_data.get('access_token') + + if not access_token: + raise SmokeTestFailureException('No access token in response') + + logger.info('Successfully obtained access token using client credentials') + return access_token + + except requests.RequestException as e: + logger.error(f'Failed to get client credentials token: {str(e)}') + raise SmokeTestFailureException(f'Failed to get client credentials token: {str(e)}') from e + + +def get_client_auth_headers(client_id: str, client_secret: str, compact: str, jurisdiction: str): + """ + Get authentication headers for client credentials flow. 
+ + :param client_id: The client ID + :param client_secret: The client secret + :param compact: Compact abbreviation + :param jurisdiction: Jurisdiction abbreviation + :return: Headers dictionary with Authorization header + """ + access_token = get_client_credentials_token(client_id, client_secret, compact, jurisdiction) + return {'Authorization': f'Bearer {access_token}'} + + def generate_opaque_data(card_number: str): """Generate a payment nonce using Authorize.Net's Secure Payment Container API. diff --git a/backend/compact-connect/tests/smoke/smoke_tests_env_example.json b/backend/compact-connect/tests/smoke/smoke_tests_env_example.json index 4dbfbb5c9..cec0fdade 100644 --- a/backend/compact-connect/tests/smoke/smoke_tests_env_example.json +++ b/backend/compact-connect/tests/smoke/smoke_tests_env_example.json @@ -18,5 +18,6 @@ "CC_TEST_PROVIDER_USER_PASSWORD": "examplePassword", "ENVIRONMENT_NAME": "sandboxEnvironmentNamePlaceholder", "SANDBOX_AUTHORIZE_NET_API_LOGIN_ID": "your_sandbox_api_login_id", - "SANDBOX_AUTHORIZE_NET_TRANSACTION_KEY": "your_sandbox_transaction_key" + "SANDBOX_AUTHORIZE_NET_TRANSACTION_KEY": "your_sandbox_transaction_key", + "CC_TEST_ROLLBACK_STEP_FUNCTION_ARN": "arn:aws:states:us-east-1:123456789012:stateMachine:Sandbox-DisasterRecoveryStack-LicenseUploadRollbackStateMachine" }