diff --git a/.changes/next-release/enhancement-s3-67811.json b/.changes/next-release/enhancement-s3-67811.json new file mode 100644 index 000000000000..9fb0cbb577cc --- /dev/null +++ b/.changes/next-release/enhancement-s3-67811.json @@ -0,0 +1,5 @@ +{ + "type": "enhancement", + "category": "``s3``", + "description": "Automatically calculate and validate full object checksums during multipart downloads, when available." +} diff --git a/awscli/botocore/httpchecksum.py b/awscli/botocore/httpchecksum.py index cddb7f0379ff..ff077c001ae0 100644 --- a/awscli/botocore/httpchecksum.py +++ b/awscli/botocore/httpchecksum.py @@ -276,6 +276,10 @@ def __init__(self, raw_stream, content_length, checksum, expected): self._checksum = checksum self._expected = expected + @property + def checksum(self): + return self._checksum + def read(self, amt=None): chunk = super().read(amt=amt) self._checksum.update(chunk) @@ -284,6 +288,8 @@ def read(self, amt=None): return chunk def _validate_checksum(self): + if self._expected is None: + return if self._checksum.digest() != base64.b64decode(self._expected): error_msg = f"Expected checksum {self._expected} did not match calculated checksum: {self._checksum.b64digest()}" raise FlexibleChecksumError(error_msg=error_msg) diff --git a/awscli/customizations/s3/filegenerator.py b/awscli/customizations/s3/filegenerator.py index 5e635193c471..088c0b7381eb 100644 --- a/awscli/customizations/s3/filegenerator.py +++ b/awscli/customizations/s3/filegenerator.py @@ -393,6 +393,11 @@ def _list_single_object(self, s3_path): try: params = {'Bucket': bucket, 'Key': key} params.update(self.request_parameters.get('HeadObject', {})) + if ( + self._client.meta.config.response_checksum_validation + == 'when_supported' + ): + params.setdefault('ChecksumMode', 'ENABLED') response = self._client.head_object(**params) except ClientError as e: # We want to try to give a more helpful error message. diff --git a/awscli/customizations/s3/s3handler.py b/awscli/customizations/s3/s3handler.py index 2b83fd951876..2867c04b7528 100644 --- a/awscli/customizations/s3/s3handler.py +++ b/awscli/customizations/s3/s3handler.py @@ -13,6 +13,7 @@ import logging import os +from s3transfer.checksums import resolve_full_object_checksum from s3transfer.manager import TransferManager from awscli.compat import get_binary_stdin @@ -39,6 +40,7 @@ DeleteSourceObjectSubscriber, DirectoryCreatorSubscriber, ProvideETagSubscriber, + ProvideFullObjectChecksumSubscriber, ProvideLastModifiedTimeSubscriber, ProvideSizeSubscriber, ProvideUploadContentTypeSubscriber, @@ -433,6 +435,14 @@ def _add_additional_subscribers(self, subscribers, fileinfo): fileinfo.case_conflict_key, ) ) + if fileinfo.associated_response_data: + checksum_info = resolve_full_object_checksum( + fileinfo.associated_response_data + ) + if checksum_info is not None: + subscribers.append( + ProvideFullObjectChecksumSubscriber(checksum_info) + ) def _submit_transfer_request(self, fileinfo, extra_args, subscribers): bucket, key = find_bucket_key(fileinfo.src) diff --git a/awscli/customizations/s3/subscribers.py b/awscli/customizations/s3/subscribers.py index 242c9cb59b6b..2ccdac714755 100644 --- a/awscli/customizations/s3/subscribers.py +++ b/awscli/customizations/s3/subscribers.py @@ -99,6 +99,25 @@ def on_queued(self, future, **kwargs): ) +class ProvideFullObjectChecksumSubscriber(BaseSubscriber): + """ + A subscriber which provides the stored full object checksum value. + """ + + def __init__(self, full_object_checksum): + self.full_object_checksum = full_object_checksum + + def on_queued(self, future, **kwargs): + if hasattr(future.meta, 'provide_full_object_checksum'): + future.meta.provide_full_object_checksum(self.full_object_checksum) + else: + LOGGER.debug( + 'Not providing full object checksum. Future: ' + f'{future} does not offer the capability to notify ' + 'the full object checksum', + ) + + class CaseConflictCleanupSubscriber(BaseSubscriber): """ A subscriber which removes object compare key from case conflict set diff --git a/awscli/customizations/s3/utils.py b/awscli/customizations/s3/utils.py index c90c421e8337..27c3501f600a 100644 --- a/awscli/customizations/s3/utils.py +++ b/awscli/customizations/s3/utils.py @@ -528,6 +528,7 @@ def map_head_object_params(cls, request_params, cli_params): """Map CLI params to HeadObject request params""" cls._set_sse_c_request_params(request_params, cli_params) cls._set_request_payer_param(request_params, cli_params) + cls._set_checksum_mode_param(request_params, cli_params) @classmethod def map_head_object_params_with_copy_source_sse( diff --git a/awscli/s3transfer/checksums.py b/awscli/s3transfer/checksums.py new file mode 100644 index 000000000000..dd58dfbbae12 --- /dev/null +++ b/awscli/s3transfer/checksums.py @@ -0,0 +1,105 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import base64 +import logging +from collections import namedtuple + +from awscrt import checksums as crt_checksums +from botocore.httpchecksum import _CHECKSUM_CLS +from s3transfer.exceptions import S3DownloadChecksumError + +logger = logging.getLogger(__name__) + + +CrcCombineInfo = namedtuple('CrcCombineInfo', ['combine_fn', 'byte_length']) + + +PartChecksum = namedtuple('PartChecksum', ['crc_int', 'data_length']) + + +FullObjectChecksum = namedtuple( + 'FullObjectChecksum', ['algorithm', 'expected_b64'] +) + + +_CRC_COMBINE_FUNCTIONS = { + 'crc32': CrcCombineInfo(crt_checksums.combine_crc32, 4), + 'crc32c': CrcCombineInfo(crt_checksums.combine_crc32c, 4), + 'crc64nvme': CrcCombineInfo(crt_checksums.combine_crc64nvme, 8), +} + + +_CHECKSUM_KEY_TO_ALGORITHM = { + 'ChecksumCRC32': 'crc32', + 'ChecksumCRC32C': 'crc32c', + 'ChecksumCRC64NVME': 'crc64nvme', +} + + +def resolve_full_object_checksum(response): + if response.get('ChecksumType', '').upper() != 'FULL_OBJECT': + return None + for key, algorithm in _CHECKSUM_KEY_TO_ALGORITHM.items(): + value = response.get(key) + if value: + return FullObjectChecksum(algorithm=algorithm, expected_b64=value) + return None + + +def create_checksum_for_algorithm(algorithm): + if checksum_cls := _CHECKSUM_CLS.get(algorithm): + return checksum_cls() + return None + + +class FullObjectChecksumCombiner: + def __init__(self, algorithm, num_parts, expected_b64=None): + self._algorithm = algorithm + self._expected_b64 = expected_b64 + self._num_parts = num_parts + self._combine_info = _CRC_COMBINE_FUNCTIONS[algorithm] + self._parts = {} + self._combined_bytes = None + + @property + def algorithm(self): + return self._algorithm + + def register_part(self, part_index, checksum, data_length): + crc_int = int.from_bytes(checksum.digest(), byteorder='big') + self._parts[part_index] = PartChecksum(crc_int, data_length) + + def combine_and_validate(self): + combined_bytes = self._get_combined_bytes() + combined_b64 = base64.b64encode(combined_bytes).decode('ascii') + expected_bytes = base64.b64decode(self._expected_b64) + if combined_bytes != expected_bytes: + raise S3DownloadChecksumError( + f'Expected full object checksum ' + f'({self._algorithm}) {self._expected_b64} did not match ' + f'combined checksum: {combined_b64}' + ) + logger.debug( + 'Full object %s checksum validated: %s', + self._algorithm, + combined_b64, + ) + + @property + def combined_b64(self): + combined_bytes = self._get_combined_bytes() + return base64.b64encode(combined_bytes).decode('ascii') + + def _get_combined_bytes(self): + if self._combined_bytes is not None: + return self._combined_bytes + crc = self._parts[0].crc_int + for i in range(1, self._num_parts): + part = self._parts[i] + crc = self._combine_info.combine_fn( + crc, part.crc_int, part.data_length + ) + self._combined_bytes = crc.to_bytes( + self._combine_info.byte_length, byteorder='big' + ) + return self._combined_bytes diff --git a/awscli/s3transfer/download.py b/awscli/s3transfer/download.py index 9307e48fa551..02b49461aeed 100644 --- a/awscli/s3transfer/download.py +++ b/awscli/s3transfer/download.py @@ -15,6 +15,11 @@ import threading from botocore.exceptions import ClientError +from botocore.httpchecksum import StreamingChecksumBody +from s3transfer.checksums import ( + FullObjectChecksumCombiner, + create_checksum_for_algorithm, +) from s3transfer.compat import seekable from s3transfer.exceptions import ( RetriesExceededError, @@ -477,11 +482,27 @@ def _submit_ranged_download_request( # Get any associated tags for the get object task. get_object_tag = download_output_manager.get_download_task_tag() + checksum_combiner = self._create_checksum_combiner( + client.meta.config, + transfer_future, + num_parts, + ) + # Callback invoker to submit the final io task once all downloads # are complete. + finalize_callback = self._get_final_io_task_submission_callback( + download_output_manager, io_executor + ) + pre_finalize_callbacks = [] + if checksum_combiner is not None: + pre_finalize_callbacks.append( + checksum_combiner.combine_and_validate + ) finalize_download_invoker = CountCallbackInvoker( - self._get_final_io_task_submission_callback( - download_output_manager, io_executor + FunctionContainer( + self._finalize_download, + pre_finalize_callbacks, + finalize_callback, ) ) for i in range(num_parts): @@ -512,9 +533,11 @@ def _submit_ranged_download_request( 'callbacks': progress_callbacks, 'max_attempts': config.num_download_attempts, 'start_index': i * part_size, + 'part_index': i, 'download_output_manager': download_output_manager, 'io_chunksize': config.io_chunksize, 'bandwidth_limiter': bandwidth_limiter, + 'checksum_combiner': checksum_combiner, }, done_callbacks=[finalize_download_invoker.decrement], ), @@ -522,6 +545,16 @@ def _submit_ranged_download_request( ) finalize_download_invoker.finalize() + def _finalize_download(self, pre_finalize_callbacks, finalize_callback): + for callback in pre_finalize_callbacks: + try: + callback() + except Exception as e: + self._transfer_coordinator.set_exception(e) + self._transfer_coordinator.announce_done() + return + finalize_callback() + def _get_final_io_task_submission_callback( self, download_manager, io_executor ): @@ -530,6 +563,25 @@ def _get_final_io_task_submission_callback( self._transfer_coordinator.submit, io_executor, final_task ) + def _create_checksum_combiner( + self, client_config, transfer_future, num_parts + ): + checksum_info = transfer_future.meta.full_object_checksum + if checksum_info is None: + return None + extra_args = transfer_future.meta.call_args.extra_args + auto_enabled = ( + client_config.response_checksum_validation == 'when_supported' + ) + explicitly_enabled = extra_args.get('ChecksumMode') == 'ENABLED' + if not auto_enabled and not explicitly_enabled: + return None + return FullObjectChecksumCombiner( + algorithm=checksum_info.algorithm, + num_parts=num_parts, + expected_b64=checksum_info.expected_b64, + ) + def _calculate_range_param(self, part_size, part_index, num_parts): # Used to calculate the Range parameter start_range = part_index * part_size @@ -554,7 +606,9 @@ def _main( download_output_manager, io_chunksize, start_index=0, + part_index=0, bandwidth_limiter=None, + checksum_combiner=None, ): """Downloads an object and places content into io queue @@ -571,8 +625,11 @@ def _main( download stream and queue in the io queue. :param start_index: The location in the file to start writing the content of the key to. + :param part_index: The part number for this ranged download. :param bandwidth_limiter: The bandwidth limiter to use when throttling the downloading of data in streams. + :param checksum_combiner: Optional FullObjectChecksumCombiner for + full object checksum validation on multipart downloads. """ last_exception = None for i in range(max_attempts): @@ -585,9 +642,25 @@ def _main( extra_args.get('Range'), response.get('ContentRange'), ) - streaming_body = StreamReaderProgress( - response['Body'], callbacks - ) + # When doing full object checksum combining and botocore + # hasn't already wrapped the body with a checksum + # calculator, wrap it in StreamingChecksumBody ourselves + # so the CRC is computed as data is read. We pass + # expected=None since we validate at the full object + # level, not per-part. + body = response['Body'] + if checksum_combiner is not None and not hasattr( + body, 'checksum' + ): + body = StreamingChecksumBody( + body, + response.get('ContentLength'), + create_checksum_for_algorithm( + checksum_combiner.algorithm + ), + expected=None, + ) + streaming_body = StreamReaderProgress(body, callbacks) if bandwidth_limiter: streaming_body = ( bandwidth_limiter.get_bandwith_limited_stream( @@ -595,12 +668,14 @@ def _main( ) ) + part_length = 0 chunks = DownloadChunkIterator(streaming_body, io_chunksize) for chunk in chunks: # If the transfer is done because of a cancellation # or error somewhere else, stop trying to submit more # data to be written and break out of the download. if not self._transfer_coordinator.done(): + part_length += len(chunk) self._handle_io( download_output_manager, fileobj, @@ -610,6 +685,13 @@ def _main( current_index += len(chunk) else: return + + if checksum_combiner is not None: + checksum_combiner.register_part( + part_index, + body.checksum, + part_length, + ) return except ClientError as e: error_code = e.response.get('Error', {}).get('Code') diff --git a/awscli/s3transfer/exceptions.py b/awscli/s3transfer/exceptions.py index 57ca0f55c5f4..db4fa19b9030 100644 --- a/awscli/s3transfer/exceptions.py +++ b/awscli/s3transfer/exceptions.py @@ -47,3 +47,7 @@ class FatalError(CancelledError): class S3ValidationError(Exception): pass + + +class S3DownloadChecksumError(Exception): + pass diff --git a/awscli/s3transfer/futures.py b/awscli/s3transfer/futures.py index 6222a42baba8..25e49f0d4a77 100644 --- a/awscli/s3transfer/futures.py +++ b/awscli/s3transfer/futures.py @@ -128,6 +128,7 @@ def __init__(self, call_args=None, transfer_id=None): self._size = None self._user_context = {} self._etag = None + self._full_object_checksum = None @property def call_args(self): @@ -172,6 +173,15 @@ def provide_object_etag(self, etag): """ self._etag = etag + @property + def full_object_checksum(self): + """The full object checksum info""" + return self._full_object_checksum + + def provide_full_object_checksum(self, full_object_checksum): + """A method to provide the full object checksum""" + self._full_object_checksum = full_object_checksum + class TransferCoordinator: """A helper class for managing TransferFuture""" diff --git a/tests/__init__.py b/tests/__init__.py index 6e7f26fc8157..9b7842a2e674 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -93,6 +93,7 @@ RecordingSubscriber, FileSizeProvider, ETagProvider, + FullObjectChecksumProvider, RecordingOSUtils, RecordingExecutor, TransferCoordinatorWithInterrupt, diff --git a/tests/functional/s3/test_cp_command.py b/tests/functional/s3/test_cp_command.py index dd7d68e223f6..809433ad7b01 100644 --- a/tests/functional/s3/test_cp_command.py +++ b/tests/functional/s3/test_cp_command.py @@ -536,6 +536,7 @@ def test_dryrun_download(self): { 'Bucket': 'bucket', 'Key': 'key.txt', + 'ChecksumMode': 'ENABLED', }, ) ] @@ -577,6 +578,7 @@ def test_dryrun_copy(self): { 'Bucket': 'bucket', 'Key': 'key.txt', + 'ChecksumMode': 'ENABLED', }, ) ] @@ -985,6 +987,7 @@ def test_cp_with_sse_c_copy_source_fileb(self): expected_head_args = { 'Bucket': 'bucket-one', 'Key': 'key.txt', + 'ChecksumMode': 'ENABLED', 'SSECustomerAlgorithm': 'AES256', 'SSECustomerKey': key_contents, } @@ -1016,6 +1019,7 @@ def test_s3s3_cp_with_destination_sse_c(self): expected_head_args = { 'Bucket': 'bucket-one', 'Key': 'key.txt', + 'ChecksumMode': 'ENABLED', # don't expect to see SSE-c params for the source } self.assertDictEqual(self.operations_called[0][1], expected_head_args) @@ -1047,6 +1051,7 @@ def test_s3s3_cp_with_different_sse_c_keys(self): expected_head_args = { 'Bucket': 'bucket-one', 'Key': 'key.txt', + 'ChecksumMode': 'ENABLED', 'SSECustomerAlgorithm': 'AES256', 'SSECustomerKey': 'foo', } @@ -1083,6 +1088,7 @@ def test_s3s3_cp_with_destination_sse_c_multipart(self): self.head_object_request( 'bucket-one', 'key.txt', + ChecksumMode='ENABLED', # no SSE-C params — source is unencrypted ), ('GetObjectTagging', mock.ANY), @@ -1134,6 +1140,7 @@ def test_s3s3_cp_with_different_sse_c_keys_multipart(self): expected_head_args = { 'Bucket': 'bucket-one', 'Key': 'key.txt', + 'ChecksumMode': 'ENABLED', 'SSECustomerAlgorithm': 'AES256', 'SSECustomerKey': 'source-key', } @@ -1143,6 +1150,7 @@ def test_s3s3_cp_with_different_sse_c_keys_multipart(self): self.head_object_request( 'bucket-one', 'key.txt', + ChecksumMode='ENABLED', SSECustomerAlgorithm='AES256', SSECustomerKey='source-key', ), @@ -1744,7 +1752,10 @@ def test_single_download(self): self.assert_operations_called( [ self.head_object_request( - 'mybucket', 'mykey', RequestPayer='requester' + 'mybucket', + 'mykey', + RequestPayer='requester', + ChecksumMode='ENABLED', ), self.get_object_request( 'mybucket', 'mykey', RequestPayer='requester' @@ -1767,7 +1778,10 @@ def test_ranged_download(self): self.assert_operations_called( [ self.head_object_request( - 'mybucket', 'mykey', RequestPayer='requester' + 'mybucket', + 'mykey', + RequestPayer='requester', + ChecksumMode='ENABLED', ), self.get_object_request( 'mybucket', @@ -1819,7 +1833,10 @@ def test_single_copy(self): self.assert_operations_called( [ self.head_object_request( - 'sourcebucket', 'sourcekey', RequestPayer='requester' + 'sourcebucket', + 'sourcekey', + RequestPayer='requester', + ChecksumMode='ENABLED', ), self.copy_object_request( 'sourcebucket', @@ -1848,7 +1865,10 @@ def test_multipart_copy(self): self.assert_operations_called( [ self.head_object_request( - 'sourcebucket', 'sourcekey', RequestPayer='requester' + 'sourcebucket', + 'sourcekey', + RequestPayer='requester', + ChecksumMode='ENABLED', ), self.create_mpu_request( 'mybucket', 'mykey', RequestPayer='requester' @@ -2017,7 +2037,11 @@ def test_download(self): self.run_cmd(cmdline, expected_rc=0) self.assert_operations_called( [ - self.head_object_request(self.accesspoint_arn, 'mykey'), + self.head_object_request( + self.accesspoint_arn, + 'mykey', + ChecksumMode='ENABLED', + ), self.get_object_request(self.accesspoint_arn, 'mykey'), ] ) @@ -2051,7 +2075,11 @@ def test_copy(self): self.run_cmd(cmdline, expected_rc=0) self.assert_operations_called( [ - self.head_object_request(self.accesspoint_arn, 'mykey'), + self.head_object_request( + self.accesspoint_arn, + 'mykey', + ChecksumMode='ENABLED', + ), self.copy_object_request( self.accesspoint_arn, 'mykey', diff --git a/tests/functional/s3/test_mv_command.py b/tests/functional/s3/test_mv_command.py index d4fcb1e6100a..365fc1f5ad9e 100644 --- a/tests/functional/s3/test_mv_command.py +++ b/tests/functional/s3/test_mv_command.py @@ -52,6 +52,7 @@ def test_dryrun_move(self): { 'Bucket': 'bucket', 'Key': 'key.txt', + 'ChecksumMode': 'ENABLED', }, ) ] @@ -148,6 +149,7 @@ def test_download_move_with_request_payer(self): { 'Bucket': 'mybucket', 'Key': 'mykey', + 'ChecksumMode': 'ENABLED', 'RequestPayer': 'requester', }, ), @@ -184,7 +186,10 @@ def test_copy_move_with_request_payer(self): self.assert_operations_called( [ self.head_object_request( - 'sourcebucket', 'sourcekey', RequestPayer='requester' + 'sourcebucket', + 'sourcekey', + RequestPayer='requester', + ChecksumMode='ENABLED', ), self.copy_object_request( 'sourcebucket', @@ -221,7 +226,9 @@ def test_with_copy_props(self): self.run_cmd(cmdline, expected_rc=0) self.assert_operations_called( [ - self.head_object_request('sourcebucket', 'sourcekey'), + self.head_object_request( + 'sourcebucket', 'sourcekey', ChecksumMode='ENABLED' + ), self.get_object_tagging_request('sourcebucket', 'sourcekey'), self.create_mpu_request('bucket', 'key', Metadata=metadata), self.upload_part_copy_request( @@ -275,7 +282,9 @@ def test_mv_does_not_delete_source_on_failed_put_tagging(self): self.run_cmd(cmdline, expected_rc=1) self.assert_operations_called( [ - self.head_object_request('sourcebucket', 'sourcekey'), + self.head_object_request( + 'sourcebucket', 'sourcekey', ChecksumMode='ENABLED' + ), self.get_object_tagging_request('sourcebucket', 'sourcekey'), self.create_mpu_request('bucket', 'key', Metadata=metadata), self.upload_part_copy_request( @@ -431,14 +440,10 @@ def test_mv_no_overwrite_flag_on_copy_when_small_object_does_not_exist_on_target # Verify all multipart copy operations were called self.assertEqual(len(self.operations_called), 3) self.assertEqual(self.operations_called[0][0].name, 'HeadObject') - self.assertEqual( - self.operations_called[1][0].name, 'CopyObject' - ) + self.assertEqual(self.operations_called[1][0].name, 'CopyObject') self.assertEqual(self.operations_called[1][1]['IfNoneMatch'], '*') - self.assertEqual( - self.operations_called[2][0].name, 'DeleteObject' - ) + self.assertEqual(self.operations_called[2][0].name, 'DeleteObject') def test_mv_no_overwrite_flag_on_copy_when_small_object_exists_on_target( self, @@ -455,9 +460,7 @@ def test_mv_no_overwrite_flag_on_copy_when_small_object_exists_on_target( # Verify all copy operations were called self.assertEqual(len(self.operations_called), 2) self.assertEqual(self.operations_called[0][0].name, 'HeadObject') - self.assertEqual( - self.operations_called[1][0].name, 'CopyObject' - ) + self.assertEqual(self.operations_called[1][0].name, 'CopyObject') # Verify the IfNoneMatch condition was set in the CopyObject request self.assertEqual(self.operations_called[1][1]['IfNoneMatch'], '*') diff --git a/tests/functional/s3transfer/test_download.py b/tests/functional/s3transfer/test_download.py index 66976117a145..fdc7f5d53bb5 100644 --- a/tests/functional/s3transfer/test_download.py +++ b/tests/functional/s3transfer/test_download.py @@ -19,9 +19,11 @@ from io import BytesIO from botocore.exceptions import ClientError +from s3transfer.checksums import FullObjectChecksum from s3transfer.compat import SOCKET_ERROR from s3transfer.exceptions import ( RetriesExceededError, + S3DownloadChecksumError, S3DownloadFailedError, S3ValidationError, ) @@ -31,6 +33,7 @@ BaseGeneralInterfaceTest, ETagProvider, FileSizeProvider, + FullObjectChecksumProvider, NonSeekableWriter, RecordingOSUtils, RecordingSubscriber, @@ -106,10 +109,12 @@ def create_expected_progress_callback_info(self): # that the stream is done. return [{'bytes_transferred': 10}] - def add_head_object_response(self, expected_params=None): + def add_head_object_response(self, expected_params=None, extras=None): head_response = self.create_stubbed_responses()[0] if expected_params: head_response['expected_params'] = expected_params + if extras: + head_response['service_response'].update(extras) self.stubber.add_response(**head_response) def add_successful_get_object_responses( @@ -647,3 +652,72 @@ def test_download_without_etag(self): # Ensure that the contents are correct with open(self.filename, 'rb') as f: self.assertEqual(self.content, f.read()) + + def test_ranged_download_full_object_checksum_validation(self): + checksum_crc32 = 'AUwfuQ==' + expected_params = { + 'Bucket': self.bucket, + 'Key': self.key, + } + expected_ranges = ['bytes=0-3', 'bytes=4-7', 'bytes=8-'] + stubbed_ranges = ['bytes 0-3/10', 'bytes 4-7/10', 'bytes 8-9/10'] + self.add_head_object_response( + expected_params, + extras={ + 'ChecksumCRC32': checksum_crc32, + 'ChecksumType': 'FULL_OBJECT', + }, + ) + self.add_successful_get_object_responses( + {**expected_params, 'IfMatch': self.etag}, + expected_ranges, + [{'ContentRange': r} for r in stubbed_ranges], + ) + future = self.manager.download( + self.bucket, + self.key, + self.filename, + self.extra_args, + [ + FullObjectChecksumProvider( + FullObjectChecksum('crc32', checksum_crc32) + ), + ], + ) + future.result() + with open(self.filename, 'rb') as f: + self.assertEqual(self.content, f.read()) + + def test_ranged_download_full_object_checksum_mismatch_raises(self): + expected_params = { + 'Bucket': self.bucket, + 'Key': self.key, + } + expected_ranges = ['bytes=0-3', 'bytes=4-7', 'bytes=8-'] + stubbed_ranges = ['bytes 0-3/10', 'bytes 4-7/10', 'bytes 8-9/10'] + self.add_head_object_response( + expected_params, + extras={ + 'ChecksumCRC32': 'AAAABB==', + 'ChecksumType': 'FULL_OBJECT', + }, + ) + self.add_successful_get_object_responses( + {**expected_params, 'IfMatch': self.etag}, + expected_ranges, + [{'ContentRange': r} for r in stubbed_ranges], + ) + future = self.manager.download( + self.bucket, + self.key, + self.filename, + self.extra_args, + [ + FullObjectChecksumProvider( + FullObjectChecksum('crc32', 'AAAABB==') + ), + ], + ) + with self.assertRaises(S3DownloadChecksumError): + future.result() + self.assertFalse(os.path.exists(self.filename)) diff --git a/tests/unit/customizations/s3/test_filegenerator.py b/tests/unit/customizations/s3/test_filegenerator.py index 8d0ed640b141..1b962ba8b5da 100644 --- a/tests/unit/customizations/s3/test_filegenerator.py +++ b/tests/unit/customizations/s3/test_filegenerator.py @@ -630,6 +630,69 @@ def test_s3_single_file_delete(self): ) self.client.head_object.assert_not_called() + def test_s3_single_file_sets_checksum_mode_when_supported(self): + input_s3_file = { + 'src': {'path': self.file1, 'type': 's3'}, + 'dest': {'path': 'text1.txt', 'type': 'local'}, + 'dir_op': False, + 'use_src_name': False, + } + self.client = mock.Mock() + self.client.meta.config.response_checksum_validation = 'when_supported' + self.client.head_object.return_value = { + 'ContentLength': 100, + 'LastModified': '2014-01-09T20:45:49.000Z', + 'ETag': '"abcd"', + } + file_gen = FileGenerator(self.client, '') + list(file_gen.call(input_s3_file)) + call_kwargs = self.client.head_object.call_args[1] + self.assertEqual(call_kwargs['ChecksumMode'], 'ENABLED') + + def test_s3_single_file_no_checksum_mode_when_required(self): + input_s3_file = { + 'src': {'path': self.file1, 'type': 's3'}, + 'dest': {'path': 'text1.txt', 'type': 'local'}, + 'dir_op': False, + 'use_src_name': False, + } + self.client = mock.Mock() + self.client.meta.config.response_checksum_validation = 'when_required' + self.client.head_object.return_value = { + 'ContentLength': 100, + 'LastModified': '2014-01-09T20:45:49.000Z', + 'ETag': '"abcd"', + } + file_gen = FileGenerator(self.client, '') + list(file_gen.call(input_s3_file)) + call_kwargs = self.client.head_object.call_args[1] + self.assertNotIn('ChecksumMode', call_kwargs) + + def test_s3_single_file_explicit_checksum_mode_overrides(self): + input_s3_file = { + 'src': {'path': self.file1, 'type': 's3'}, + 'dest': {'path': 'text1.txt', 'type': 'local'}, + 'dir_op': False, + 'use_src_name': False, + } + self.client = mock.Mock() + self.client.meta.config.response_checksum_validation = 'when_required' + self.client.head_object.return_value = { + 'ContentLength': 100, + 'LastModified': '2014-01-09T20:45:49.000Z', + 'ETag': '"abcd"', + } + file_gen = FileGenerator( + self.client, + '', + request_parameters={ + 'HeadObject': {'ChecksumMode': 'ENABLED'}, + }, + ) + list(file_gen.call(input_s3_file)) + call_kwargs = self.client.head_object.call_args[1] + self.assertEqual(call_kwargs['ChecksumMode'], 'ENABLED') + def test_s3_directory(self): """ Generates s3 files under a common prefix. Also it ensures that diff --git a/tests/unit/customizations/s3/test_s3handler.py b/tests/unit/customizations/s3/test_s3handler.py index 464b46c6000d..ec57fcdd6d4a 100644 --- a/tests/unit/customizations/s3/test_s3handler.py +++ b/tests/unit/customizations/s3/test_s3handler.py @@ -45,6 +45,7 @@ DeleteSourceObjectSubscriber, DirectoryCreatorSubscriber, ProvideETagSubscriber, + ProvideFullObjectChecksumSubscriber, ProvideLastModifiedTimeSubscriber, ProvideSizeSubscriber, ProvideUploadContentTypeSubscriber, @@ -712,6 +713,42 @@ def test_warn_if_file_exists_without_no_overwrite_flag(self): # And download should have happened self.assertEqual(len(self.transfer_manager.download.call_args_list), 1) + def test_submit_with_full_object_checksum(self): + fileinfo = self.create_file_info( + self.key, + associated_response_data={ + 'ChecksumType': 'FULL_OBJECT', + 'ChecksumCRC32': 'abc123==', + }, + ) + self.transfer_request_submitter.submit(fileinfo) + download_call_kwargs = self.transfer_manager.download.call_args[1] + actual_subscribers = download_call_kwargs['subscribers'] + subscriber_types = [type(s) for s in actual_subscribers] + self.assertIn(ProvideFullObjectChecksumSubscriber, subscriber_types) + + def test_submit_without_full_object_checksum(self): + fileinfo = self.create_file_info(self.key) + self.transfer_request_submitter.submit(fileinfo) + download_call_kwargs = self.transfer_manager.download.call_args[1] + actual_subscribers = download_call_kwargs['subscribers'] + subscriber_types = [type(s) for s in actual_subscribers] + self.assertNotIn(ProvideFullObjectChecksumSubscriber, subscriber_types) + + def test_submit_with_composite_checksum_does_not_add_subscriber(self): + fileinfo = self.create_file_info( + self.key, + associated_response_data={ + 'ChecksumType': 'COMPOSITE', + 'ChecksumCRC32': 'abc123==-5', + }, + ) + self.transfer_request_submitter.submit(fileinfo) + download_call_kwargs = self.transfer_manager.download.call_args[1] + actual_subscribers = download_call_kwargs['subscribers'] + subscriber_types = [type(s) for s in actual_subscribers] + self.assertNotIn(ProvideFullObjectChecksumSubscriber, subscriber_types) + class TestCopyRequestSubmitter(BaseTransferRequestSubmitterTest): def setUp(self): diff --git a/tests/unit/customizations/s3/test_subscribers.py b/tests/unit/customizations/s3/test_subscribers.py index 8c4e4f419579..90df9341ce36 100644 --- a/tests/unit/customizations/s3/test_subscribers.py +++ b/tests/unit/customizations/s3/test_subscribers.py @@ -19,6 +19,7 @@ import pytest from dateutil.tz import tzlocal +from s3transfer.checksums import FullObjectChecksum from s3transfer.crt import CRTTransferFuture, CRTTransferMeta from s3transfer.futures import TransferFuture, TransferMeta from s3transfer.manager import TransferConfig @@ -37,6 +38,7 @@ DirectoryCreatorSubscriber, OnDoneFilteredSubscriber, ProvideETagSubscriber, + ProvideFullObjectChecksumSubscriber, ProvideLastModifiedTimeSubscriber, ProvideSizeSubscriber, ProvideUploadContentTypeSubscriber, @@ -102,6 +104,30 @@ def test_does_not_try_to_set_etag_on_crt_transfer_future(self, caplog): assert "Not providing object ETag." in caplog.text +class TestProvideFullObjectChecksumSubscriber: + def test_checksum_set(self): + transfer_meta = TransferMeta() + transfer_future = mock.Mock(spec=TransferFuture) + transfer_future.meta = transfer_meta + checksum_info = FullObjectChecksum('crc32', 'abc123==') + + subscriber = ProvideFullObjectChecksumSubscriber(checksum_info) + subscriber.on_queued(transfer_future) + assert transfer_meta.full_object_checksum == checksum_info + + def test_does_not_try_to_set_on_crt_transfer_future(self, caplog): + caplog.set_level(logging.DEBUG) + crt_transfer_future = mock.Mock(spec=CRTTransferFuture) + crt_transfer_future.meta = CRTTransferMeta() + + subscriber = ProvideFullObjectChecksumSubscriber( + FullObjectChecksum('crc32', 'abc123==') + ) + subscriber.on_queued(crt_transfer_future) + assert not hasattr(crt_transfer_future.meta, 'full_object_checksum') + assert "Not providing full object checksum." in caplog.text + + class OnDoneFilteredRecordingSubscriber(OnDoneFilteredSubscriber): def __init__(self): self.on_success_calls = [] diff --git a/tests/unit/customizations/s3/test_utils.py b/tests/unit/customizations/s3/test_utils.py index 0c79dc0ea913..a1740b43b986 100644 --- a/tests/unit/customizations/s3/test_utils.py +++ b/tests/unit/customizations/s3/test_utils.py @@ -765,6 +765,18 @@ def test_get_object_no_checksums(self, cli_params_no_checksum): ) assert 'ChecksumMode' not in request_params + def test_head_object(self, cli_params): + request_params = {} + RequestParamsMapper.map_head_object_params(request_params, cli_params) + assert request_params == {'ChecksumMode': 'ENABLED'} + + def test_head_object_no_checksums(self, cli_params_no_checksum): + request_params = {} + RequestParamsMapper.map_head_object_params( + request_params, cli_params_no_checksum + ) + assert 'ChecksumMode' not in request_params + class TestRequestParamsMapperRequestPayer(unittest.TestCase): def setUp(self): diff --git a/tests/unit/s3transfer/test_checksums.py b/tests/unit/s3transfer/test_checksums.py new file mode 100644 index 000000000000..a68b83d79b75 --- /dev/null +++ b/tests/unit/s3transfer/test_checksums.py @@ -0,0 +1,182 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import base64 + +import pytest +from awscrt import checksums as crt_checksums +from s3transfer.checksums import ( + FullObjectChecksum, + FullObjectChecksumCombiner, + create_checksum_for_algorithm, + resolve_full_object_checksum, +) +from s3transfer.exceptions import S3DownloadChecksumError + + +def _compute_expected_b64(data, algorithm): + crc_fns = { + 'crc32': (crt_checksums.crc32, 4), + 'crc32c': (crt_checksums.crc32c, 4), + 'crc64nvme': (crt_checksums.crc64nvme, 8), + } + fn, byte_length = crc_fns[algorithm] + crc = fn(data) + return base64.b64encode(crc.to_bytes(byte_length, byteorder='big')).decode( + 'ascii' + ) + + +def _register_parts(combiner, parts, algorithm): + for i, part_data in enumerate(parts): + checksum = create_checksum_for_algorithm(algorithm) + checksum.update(part_data) + combiner.register_part(i, checksum, len(part_data)) + + +class TestResolveFullObjectChecksum: + def test_full_object_crc32(self): + response = { + 'ChecksumType': 'FULL_OBJECT', + 'ChecksumCRC32': 'abc123==', + } + result = resolve_full_object_checksum(response) + assert result == FullObjectChecksum( + algorithm='crc32', expected_b64='abc123==' + ) + + def test_full_object_crc32c(self): + response = { + 'ChecksumType': 'FULL_OBJECT', + 'ChecksumCRC32C': 'xyz789==', + } + result = resolve_full_object_checksum(response) + assert result == FullObjectChecksum( + algorithm='crc32c', expected_b64='xyz789==' + ) + + def test_full_object_crc64nvme(self): + response = { + 'ChecksumType': 'FULL_OBJECT', + 'ChecksumCRC64NVME': 'nvme64==', + } + result = resolve_full_object_checksum(response) + assert result == FullObjectChecksum( + algorithm='crc64nvme', expected_b64='nvme64==' + ) + + def test_missing_checksum_type(self): + assert resolve_full_object_checksum({}) is None + + def test_composite_checksum(self): + response = { + 'ChecksumType': 'COMPOSITE', + 'ChecksumCRC32': 'abc123==', + } + assert resolve_full_object_checksum(response) is None + + def test_full_object_sha_only(self): + response = { + 'ChecksumType': 'FULL_OBJECT', + 'ChecksumSHA256': 'sha256value==', + } + assert resolve_full_object_checksum(response) is None + + def test_case_insensitive_checksum_type(self): + response = { + 'ChecksumType': 'full_object', + 'ChecksumCRC32': 'abc123==', + } + result = resolve_full_object_checksum(response) + assert result is not None + assert result.algorithm == 'crc32' + + +class TestCreateChecksumForAlgorithm: + @pytest.mark.parametrize('algorithm', ['crc32', 'crc32c', 'crc64nvme']) + def test_known_algorithm(self, algorithm): + checksum = create_checksum_for_algorithm(algorithm) + assert checksum is not None + assert hasattr(checksum, 'update') + assert hasattr(checksum, 'digest') + + def test_unknown_algorithm(self): + assert create_checksum_for_algorithm('unknown') is None + + +class TestFullObjectChecksumCombiner: + def test_combine_and_validate_crc32(self): + data = b'hello world, this is a test of CRC combining' + parts = [data[:15], data[15:30], data[30:]] + expected = _compute_expected_b64(data, 'crc32') + + combiner = FullObjectChecksumCombiner( + 'crc32', len(parts), expected_b64=expected + ) + _register_parts(combiner, parts, 'crc32') + combiner.combine_and_validate() + + def test_combine_and_validate_crc32c(self): + data = b'testing crc32c combining across parts' + parts = [data[:10], data[10:]] + expected = _compute_expected_b64(data, 'crc32c') + + combiner = FullObjectChecksumCombiner( + 'crc32c', len(parts), expected_b64=expected + ) + _register_parts(combiner, parts, 'crc32c') + combiner.combine_and_validate() + + def test_combine_and_validate_crc64nvme(self): + data = b'testing crc64nvme combining across parts' + parts = [data[:10], data[10:20], data[20:]] + expected = _compute_expected_b64(data, 'crc64nvme') + + combiner = FullObjectChecksumCombiner( + 'crc64nvme', len(parts), expected_b64=expected + ) + _register_parts(combiner, parts, 'crc64nvme') + combiner.combine_and_validate() + + def test_checksum_mismatch_raises(self): + combiner = FullObjectChecksumCombiner( + 'crc32', 1, expected_b64='AAAABB==' + ) + checksum = create_checksum_for_algorithm('crc32') + checksum.update(b'some data') + combiner.register_part(0, checksum, 9) + + with pytest.raises(S3DownloadChecksumError, match='did not match'): + combiner.combine_and_validate() + + def test_combined_b64_without_expected(self): + data = b'upload use case' + parts = [data[:5], data[5:]] + expected = _compute_expected_b64(data, 'crc32') + + combiner = FullObjectChecksumCombiner('crc32', len(parts)) + _register_parts(combiner, parts, 'crc32') + assert combiner.combined_b64 == expected + + def test_combined_bytes_are_cached(self): + combiner = FullObjectChecksumCombiner('crc32', 1) + checksum = create_checksum_for_algorithm('crc32') + checksum.update(b'cache test') + combiner.register_part(0, checksum, 10) + + assert combiner.combined_b64 == combiner.combined_b64 + + def test_chunked_update_matches_single_update(self): + data = b'streaming chunk test data for verification' + expected = _compute_expected_b64(data, 'crc32') + + combiner = FullObjectChecksumCombiner( + 'crc32', 1, expected_b64=expected + ) + checksum = create_checksum_for_algorithm('crc32') + length = 0 + for i in range(0, len(data), 5): + chunk = data[i : i + 5] + checksum.update(chunk) + length += len(chunk) + combiner.register_part(0, checksum, length) + combiner.combine_and_validate() diff --git a/tests/unit/s3transfer/test_download.py b/tests/unit/s3transfer/test_download.py index 57d6418d54eb..4a11a2426700 100644 --- a/tests/unit/s3transfer/test_download.py +++ b/tests/unit/s3transfer/test_download.py @@ -16,7 +16,13 @@ import tempfile from io import BytesIO +from botocore.config import Config from s3transfer.bandwidth import BandwidthLimiter +from s3transfer.checksums import ( + FullObjectChecksum, + FullObjectChecksumCombiner, + create_checksum_for_algorithm, +) from s3transfer.compat import SOCKET_ERROR from s3transfer.download import ( CompleteDownloadNOOPTask, @@ -568,6 +574,64 @@ def tests_submits_tag_for_ranged_get_object_nonseekable_fileobj(self): # to that task submission. self.assert_tag_for_get_object(IN_MEMORY_DOWNLOAD_TAG) + def test_ranged_download_creates_combiner_when_auto_enabled(self): + self.configure_for_ranged_get() + self.add_head_object_response() + self.add_get_responses() + expected_b64 = self._compute_content_crc32_b64() + checksum_info = FullObjectChecksum('crc32', expected_b64) + self.transfer_future.meta.provide_full_object_checksum(checksum_info) + self.submission_task = self.get_download_submission_task() + self.wait_and_assert_completed_successfully(self.submission_task) + + def test_ranged_download_creates_combiner_when_explicitly_enabled(self): + self.reset_stubber_with_new_client( + {'config': Config(response_checksum_validation='when_required')} + ) + self.configure_for_ranged_get() + self.add_head_object_response() + self.add_get_responses() + expected_b64 = self._compute_content_crc32_b64() + self.extra_args['ChecksumMode'] = 'ENABLED' + self.call_args = self.get_call_args() + self.transfer_future = self.get_transfer_future(self.call_args) + checksum_info = FullObjectChecksum('crc32', expected_b64) + self.transfer_future.meta.provide_full_object_checksum(checksum_info) + self.submission_main_kwargs['client'] = self.client + self.submission_main_kwargs['transfer_future'] = self.transfer_future + self.submission_task = self.get_download_submission_task() + self.wait_and_assert_completed_successfully(self.submission_task) + + def test_ranged_download_no_combiner_without_checksum_info(self): + self.configure_for_ranged_get() + self.add_head_object_response() + self.add_get_responses() + self.submission_task = self.get_download_submission_task() + self.wait_and_assert_completed_successfully(self.submission_task) + + def test_ranged_download_no_combiner_when_validation_disabled(self): + self.reset_stubber_with_new_client( + {'config': Config(response_checksum_validation='when_required')} + ) + self.configure_for_ranged_get() + self.add_head_object_response() + self.add_get_responses() + checksum_info = FullObjectChecksum('crc32', 'wrong==') + self.transfer_future.meta.provide_full_object_checksum(checksum_info) + self.submission_main_kwargs['client'] = self.client + self.submission_task = self.get_download_submission_task() + self.wait_and_assert_completed_successfully(self.submission_task) + + def _compute_content_crc32_b64(self): + import base64 + + from awscrt import checksums as crt_checksums + + crc = crt_checksums.crc32(self.content) + return base64.b64encode(crc.to_bytes(4, byteorder='big')).decode( + 'ascii' + ) + class TestGetObjectTask(BaseTaskTest): def setUp(self): @@ -789,6 +853,50 @@ def test_handles_callback_on_initial_error(self): with self.assertRaises(RetriesExceededError): self.transfer_coordinator.result() + def test_checksum_combiner_self_computes_when_body_has_no_checksum(self): + self.stubber.add_response( + 'get_object', + service_response={'Body': self.stream}, + expected_params={'Bucket': self.bucket, 'Key': self.key}, + ) + combiner = FullObjectChecksumCombiner('crc32', 1) + task = self.get_download_task(checksum_combiner=combiner, part_index=0) + task() + assert combiner.combined_b64 is not None + + def test_checksum_combiner_reuses_botocore_checksum(self): + botocore_checksum = create_checksum_for_algorithm('crc32') + botocore_checksum.update(self.content) + body = BytesIO(self.content) + body.checksum = botocore_checksum + + self.stubber.add_response( + 'get_object', + service_response={'Body': body}, + expected_params={'Bucket': self.bucket, 'Key': self.key}, + ) + combiner = FullObjectChecksumCombiner('crc32', 1) + task = self.get_download_task(checksum_combiner=combiner, part_index=0) + task() + expected_checksum = create_checksum_for_algorithm('crc32') + expected_checksum.update(self.content) + expected_combiner = FullObjectChecksumCombiner('crc32', 1) + expected_combiner.register_part( + 0, expected_checksum, len(self.content) + ) + assert combiner.combined_b64 == expected_combiner.combined_b64 + + def test_no_checksum_computation_without_combiner(self): + self.stubber.add_response( + 'get_object', + service_response={'Body': self.stream}, + expected_params={'Bucket': self.bucket, 'Key': self.key}, + ) + task = self.get_download_task() + task() + self.stubber.assert_no_pending_responses() + self.assert_io_writes([(0, self.content)]) + class TestImmediatelyWriteIOGetObjectTask(TestGetObjectTask): def setUp(self): diff --git a/tests/unit/s3transfer/test_futures.py b/tests/unit/s3transfer/test_futures.py index a4c015b10cae..09c9da5e3b0d 100644 --- a/tests/unit/s3transfer/test_futures.py +++ b/tests/unit/s3transfer/test_futures.py @@ -169,6 +169,12 @@ def test_user_context(self): self.transfer_meta.user_context['foo'] = 'bar' self.assertEqual(self.transfer_meta.user_context, {'foo': 'bar'}) + def test_full_object_checksum(self): + self.assertIsNone(self.transfer_meta.full_object_checksum) + checksum = object() + self.transfer_meta.provide_full_object_checksum(checksum) + self.assertIs(self.transfer_meta.full_object_checksum, checksum) + class TestTransferCoordinator(unittest.TestCase): def setUp(self): diff --git a/tests/utils/s3transfer/__init__.py b/tests/utils/s3transfer/__init__.py index 2c1262ee2ccc..c84aa44a8169 100644 --- a/tests/utils/s3transfer/__init__.py +++ b/tests/utils/s3transfer/__init__.py @@ -180,6 +180,14 @@ def on_queued(self, future, **kwargs): future.meta.provide_object_etag(self.etag) +class FullObjectChecksumProvider: + def __init__(self, full_object_checksum): + self.full_object_checksum = full_object_checksum + + def on_queued(self, future, **kwargs): + future.meta.provide_full_object_checksum(self.full_object_checksum) + + class FileCreator: def __init__(self): self.rootdir = tempfile.mkdtemp()