Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,6 @@ def _add_resource_servers(self, stack: ps.PersistentStack):
scope_name='readGeneral',
scope_description='Read access for generally available data (not private) in the compact',
)
self.compact_read_ssn_scope = ResourceServerScope(
scope_name='readSSN',
scope_description='Read access for SSNs in the compact',
)

active_compacts = stack.get_list_of_compact_abbreviations()
self.compact_resource_servers = {}
Expand All @@ -49,7 +45,6 @@ def _add_resource_servers(self, stack: ps.PersistentStack):
self.compact_admin_scope,
self.compact_write_scope,
self.compact_read_scope,
self.compact_read_ssn_scope,
],
)
# we define the jurisdiction level scopes, which will be used by every
Expand Down Expand Up @@ -90,8 +85,4 @@ def _generate_resource_server_scopes_list_for_compact(self, compact: str):
scope_name=f'{compact}.readPrivate',
scope_description=f'Read access for SSNs in the {compact} compact within the jurisdiction',
),
ResourceServerScope(
scope_name=f'{compact}.readSSN',
scope_description=f'Read access for SSNs in the {compact} compact within the jurisdiction',
),
]
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,6 @@ class CCPermissionsAction(StrEnum):
ADMIN = 'admin'
READ_GENERAL = 'readGeneral'
READ_PRIVATE = 'readPrivate'
READ_SSN = 'readSSN'


class S3PresignedPostSchema(Schema):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
Compact,
CompactEligibility,
Jurisdiction,
SocialSecurityNumber,
)
from cc_common.data_model.schema.license.api import (
LicenseGeneralResponseSchema,
Expand Down Expand Up @@ -56,20 +55,6 @@ def _validate_no_cross_index_keys(obj, path: str = 'query') -> None:
# Scalar values (str, int, bool, None) are safe - we only check keys


class ProviderSSNResponseSchema(ForgivingSchema):
"""
Schema for provider SSN API responses.

This schema validates the response from the provider SSN endpoint,
ensuring the SSN is properly formatted.

Serialization direction:
Python -> load() -> API
"""

ssn = SocialSecurityNumber(required=True, allow_none=False)


class ProviderReadPrivateResponseSchema(ForgivingSchema):
"""
Provider object fields that are sanitized for users with the 'readPrivate' permission.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -758,12 +758,6 @@ def _user_has_read_private_access_for_provider(compact: str, provider_informatio
)


def user_has_read_ssn_access_for_provider(compact: str, provider_information: dict, scopes: set[str]) -> bool:
return _user_has_permission_for_action_on_user(
action=CCPermissionsAction.READ_SSN, compact=compact, provider_information=provider_information, scopes=scopes
)


def _user_has_permission_for_action_on_user(
action: str, compact: str, provider_information: dict, scopes: set[str]
) -> bool:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,33 +1,21 @@
from datetime import timedelta

from aws_lambda_powertools.utilities.typing import LambdaContext
from botocore.exceptions import ClientError
from cc_common.config import config, logger, metrics
from cc_common.config import config, logger
from cc_common.data_model.schema.common import CCPermissionsAction
from cc_common.data_model.schema.provider.api import (
ProviderGeneralResponseSchema,
ProviderSSNResponseSchema,
QueryProvidersRequestSchema,
)
from cc_common.exceptions import (
CCAccessDeniedException,
CCAwsServiceException,
CCInvalidRequestException,
CCRateLimitingException,
)
from cc_common.exceptions import CCInvalidRequestException
from cc_common.utils import (
api_handler,
authorize_compact,
get_event_scopes,
sanitize_provider_data_based_on_caller_scopes,
user_has_read_ssn_access_for_provider,
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
from marshmallow import ValidationError

from . import get_provider_information

SSN_RATE_LIMITING_PK = 'READ_SSN_REQUESTS'


@api_handler
@authorize_compact(action=CCPermissionsAction.READ_GENERAL)
Expand Down Expand Up @@ -136,162 +124,3 @@ def get_provider(event: dict, context: LambdaContext): # noqa: ARG001 unused-ar
return sanitize_provider_data_based_on_caller_scopes(
compact=compact, provider=provider_information, scopes=get_event_scopes(event)
)


@api_handler
@authorize_compact(action=CCPermissionsAction.READ_SSN)
def get_provider_ssn(event: dict, context: LambdaContext): # noqa: ARG001 unused-argument
"""
Return one provider's SSN
:param event: Standard API Gateway event, API schema documented in the CDK ApiStack
:param LambdaContext context:
"""
compact = event['pathParameters']['compact']
provider_id = event['pathParameters']['providerId']
user_id = event['requestContext']['authorizer']['claims']['sub']

with logger.append_context_keys(compact=compact, provider_id=provider_id, user_id=user_id):
logger.info('Processing provider SSN request')

# Check if the user has exceeded the rate limit
if _ssn_rate_limit_exceeded(context=context, user_id=user_id, provider_id=provider_id, compact=compact):
metrics.add_metric(name='rate-limited-ssn-access', value=1, unit='Count')
logger.warning('Rate limited SSN access attempt')
raise CCRateLimitingException(
'Rate limit exceeded. You have reached the maximum number of SSN requests allowed in a 24-hour period.'
)

provider_information = get_provider_information(compact=compact, provider_id=provider_id)

# Inspect the caller's scopes to determine if they have readSSN permission for this provider
if not user_has_read_ssn_access_for_provider(
compact=compact,
provider_information=provider_information,
scopes=get_event_scopes(event),
):
metrics.add_metric(name='unauthorized-ssn-access', value=1, unit='Count')
logger.warning('Unauthorized SSN access attempt')
raise CCAccessDeniedException(
f'User does not have {CCPermissionsAction.READ_SSN} permission for this provider'
)

# Query the provider's SSN from the database
ssn = config.data_client.get_ssn_by_provider_id(compact=compact, provider_id=provider_id)

metrics.add_metric(name='read-ssn', value=1, unit='Count')

# Apply schema validation
response_schema = ProviderSSNResponseSchema()
return response_schema.load({'ssn': ssn})


def _ssn_rate_limit_exceeded(context: LambdaContext, user_id: str, provider_id: str, compact: str) -> bool:
"""Check if the user has exceeded the SSN rate limit.

:param context: The lambda context, used to get the unique request id and lambda name
:param user_id: The Cognito user ID of the staff user
:param provider_id: The provider ID being accessed
:param compact: The compact being accessed
:return: True if rate limit is exceeded, False otherwise
"""
now = config.current_standard_datetime
window_start = now - timedelta(hours=24)
window_start_timestamp = window_start.timestamp()
now_timestamp = now.timestamp()

# Append the unique AWS request id for this request
# This ensures every request is recorded, even for
# requests within the same second
request_sk = f'TIME#{now_timestamp}#REQUEST#{context.aws_request_id}'

logger.info('Recording request in rate limiting table', request_sk=request_sk)

try:
# First, record this request in the rate limiting table
config.rate_limiting_table.put_item(
Item={
'pk': SSN_RATE_LIMITING_PK,
'sk': request_sk,
'compact': compact,
'provider_id': provider_id,
'staffUserId': user_id,
'ttl': int(now_timestamp) + 86400, # 24 hours in seconds
}
)

# Check if the global rate limit has been exceeded (more than 15 requests in 24 hours)
all_requests = config.rate_limiting_table.query(
KeyConditionExpression='pk = :pk AND sk BETWEEN :start_sk AND :end_sk',
ExpressionAttributeValues={
':pk': SSN_RATE_LIMITING_PK,
':start_sk': f'TIME#{window_start_timestamp}',
# Add 1 second to ensure we include all records at the current timestamp
':end_sk': f'TIME#{now_timestamp + 1}',
},
ConsistentRead=True,
)

global_request_count = len(all_requests['Items'])
logger.info(f'Global SSN request count in last 24 hours: {global_request_count}')

# If there are more than 15 requests globally in the last 24 hours, throttle the entire endpoint
if global_request_count > 15:
logger.critical(
'Global SSN rate limit exceeded, throttling endpoint',
global_request_count=global_request_count,
current_request_user_id=user_id,
current_request_compact=compact,
)

# Set the lambda's reserved concurrency to 0 to throttle the endpoint
try:
config.lambda_client.put_function_concurrency(
FunctionName=context.function_name, ReservedConcurrentExecutions=0
)
logger.critical('Lambda concurrency set to 0 due to excessive SSN requests')
metrics.add_metric(name='ssn-endpoint-disabled', value=1, unit='Count')
except ClientError as e:
logger.error('Failed to set lambda concurrency', error=str(e))

return True

# Count how many requests were made by this user
user_request_count = 0
for item in all_requests.get('Items', []):
if item.get('staffUserId') == user_id:
user_request_count += 1

logger.info(f'User SSN request count: {user_request_count}', user_id=user_id)

# If there are more than 7 requests by this user in the window, deactivate the user's account
if user_request_count >= 7:
logger.warning(
'User exceeded SSN rate limit multiple times, deactivating account',
user_id=user_id,
request_count=user_request_count,
)

# Deactivate the user's account
try:
config.cognito_client.admin_disable_user(UserPoolId=config.user_pool_id, Username=user_id)
logger.warning('User account deactivated due to excessive SSN requests', user_id=user_id)
except ClientError as e:
logger.error('Failed to deactivate user account', error=str(e), user_id=user_id)

return True

# If there are 5 or more requests by this user in the window, rate limit is exceeded
if user_request_count >= 6:
logger.warning('SSN rate limit exceeded for user', user_id=user_id, request_count=user_request_count)
return True

logger.info(
'Rate limit has not been exceeded, proceeding with request',
user_request_count=user_request_count,
staff_user_id=user_id,
provider_id=provider_id,
)
return False
except ClientError as e:
logger.error('Failed to check SSN rate limit', error=str(e))
raise CCAwsServiceException('Failed to check SSN rate limit') from e
Loading
Loading