Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changes/next-release/enhancement-s3-67811.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"type": "enhancement",
"category": "``s3``",
"description": "Automatically calculate and validate full object checksums during multipart downloads, when available."
}
6 changes: 6 additions & 0 deletions awscli/botocore/httpchecksum.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,10 @@ def __init__(self, raw_stream, content_length, checksum, expected):
self._checksum = checksum
self._expected = expected

@property
def checksum(self):
    # Expose the running checksum accumulator (updated by read()) so
    # callers -- e.g. multipart download code combining per-part CRCs --
    # can read the digest of the bytes consumed so far.
    return self._checksum

def read(self, amt=None):
chunk = super().read(amt=amt)
self._checksum.update(chunk)
Expand All @@ -284,6 +288,8 @@ def read(self, amt=None):
return chunk

def _validate_checksum(self):
    # No per-stream validation when no expected value was supplied.
    # Callers doing full-object checksum combining pass expected=None and
    # validate at the whole-object level instead of per part.
    if self._expected is None:
        return
    # Compare the locally computed digest to the service-provided
    # base64-encoded checksum; a mismatch indicates payload corruption.
    if self._checksum.digest() != base64.b64decode(self._expected):
        error_msg = f"Expected checksum {self._expected} did not match calculated checksum: {self._checksum.b64digest()}"
        raise FlexibleChecksumError(error_msg=error_msg)
Expand Down
5 changes: 5 additions & 0 deletions awscli/customizations/s3/filegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,11 @@ def _list_single_object(self, s3_path):
try:
params = {'Bucket': bucket, 'Key': key}
params.update(self.request_parameters.get('HeadObject', {}))
if (
self._client.meta.config.response_checksum_validation
== 'when_supported'
):
params.setdefault('ChecksumMode', 'ENABLED')
response = self._client.head_object(**params)
except ClientError as e:
# We want to try to give a more helpful error message.
Expand Down
10 changes: 10 additions & 0 deletions awscli/customizations/s3/s3handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import logging
import os

from s3transfer.checksums import resolve_full_object_checksum
from s3transfer.manager import TransferManager

from awscli.compat import get_binary_stdin
Expand All @@ -39,6 +40,7 @@
DeleteSourceObjectSubscriber,
DirectoryCreatorSubscriber,
ProvideETagSubscriber,
ProvideFullObjectChecksumSubscriber,
ProvideLastModifiedTimeSubscriber,
ProvideSizeSubscriber,
ProvideUploadContentTypeSubscriber,
Expand Down Expand Up @@ -433,6 +435,14 @@ def _add_additional_subscribers(self, subscribers, fileinfo):
fileinfo.case_conflict_key,
)
)
if fileinfo.associated_response_data:
checksum_info = resolve_full_object_checksum(
fileinfo.associated_response_data
)
if checksum_info is not None:
subscribers.append(
ProvideFullObjectChecksumSubscriber(checksum_info)
)

def _submit_transfer_request(self, fileinfo, extra_args, subscribers):
bucket, key = find_bucket_key(fileinfo.src)
Expand Down
19 changes: 19 additions & 0 deletions awscli/customizations/s3/subscribers.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,25 @@ def on_queued(self, future, **kwargs):
)


class ProvideFullObjectChecksumSubscriber(BaseSubscriber):
    """Hands a stored full object checksum to a queued transfer future.

    If the future's meta object does not support receiving a full object
    checksum, the subscriber logs and does nothing.
    """

    def __init__(self, full_object_checksum):
        self.full_object_checksum = full_object_checksum

    def on_queued(self, future, **kwargs):
        # Guard clause: bail out (with a debug note) when the future's
        # meta cannot accept a full object checksum.
        if not hasattr(future.meta, 'provide_full_object_checksum'):
            LOGGER.debug(
                'Not providing full object checksum. Future: '
                f'{future} does not offer the capability to notify '
                'the full object checksum',
            )
            return
        future.meta.provide_full_object_checksum(self.full_object_checksum)


class CaseConflictCleanupSubscriber(BaseSubscriber):
"""
A subscriber which removes object compare key from case conflict set
Expand Down
1 change: 1 addition & 0 deletions awscli/customizations/s3/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,7 @@ def map_head_object_params(cls, request_params, cli_params):
"""Map CLI params to HeadObject request params"""
cls._set_sse_c_request_params(request_params, cli_params)
cls._set_request_payer_param(request_params, cli_params)
cls._set_checksum_mode_param(request_params, cli_params)

@classmethod
def map_head_object_params_with_copy_source_sse(
Expand Down
105 changes: 105 additions & 0 deletions awscli/s3transfer/checksums.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import base64
import logging
from collections import namedtuple

from awscrt import checksums as crt_checksums
from botocore.httpchecksum import _CHECKSUM_CLS
from s3transfer.exceptions import S3DownloadChecksumError

logger = logging.getLogger(__name__)


# Pairs a CRT CRC-combine function with the byte width of its digest.
CrcCombineInfo = namedtuple('CrcCombineInfo', ['combine_fn', 'byte_length'])


# The CRC of a single downloaded part (as a big-endian int) and the
# number of bytes the part covered.
PartChecksum = namedtuple('PartChecksum', ['crc_int', 'data_length'])


# A full object checksum reported by S3: algorithm name plus the
# expected base64-encoded digest.
FullObjectChecksum = namedtuple(
    'FullObjectChecksum', ['algorithm', 'expected_b64']
)


# algorithm name -> CrcCombineInfo(awscrt combine function, digest bytes).
# Only CRC-family algorithms appear here; they are the ones whose
# per-part digests can be combined into a full object digest.
_CRC_COMBINE_FUNCTIONS = {
    'crc32': CrcCombineInfo(crt_checksums.combine_crc32, 4),
    'crc32c': CrcCombineInfo(crt_checksums.combine_crc32c, 4),
    'crc64nvme': CrcCombineInfo(crt_checksums.combine_crc64nvme, 8),
}


# Response keys (HeadObject/GetObject) mapped to combinable algorithms.
_CHECKSUM_KEY_TO_ALGORITHM = {
    'ChecksumCRC32': 'crc32',
    'ChecksumCRC32C': 'crc32c',
    'ChecksumCRC64NVME': 'crc64nvme',
}


def resolve_full_object_checksum(response):
    """Extract a combinable full object checksum from an S3 response.

    :param response: A HeadObject/GetObject response dict.
    :returns: A ``FullObjectChecksum`` for the first supported CRC value
        found, or ``None`` when the response's checksum type is not
        FULL_OBJECT or no supported checksum value is present.
    """
    # Use `or ''` so a ChecksumType key explicitly set to None does not
    # raise AttributeError on .upper(); only full-object checksums can
    # be combined across parts (composite checksums cannot).
    checksum_type = response.get('ChecksumType') or ''
    if checksum_type.upper() != 'FULL_OBJECT':
        return None
    for key, algorithm in _CHECKSUM_KEY_TO_ALGORITHM.items():
        value = response.get(key)
        if value:
            return FullObjectChecksum(algorithm=algorithm, expected_b64=value)
    return None


def create_checksum_for_algorithm(algorithm):
    """Return a fresh botocore checksum object for ``algorithm``.

    Returns ``None`` when botocore has no checksum class registered for
    the algorithm name.
    """
    checksum_cls = _CHECKSUM_CLS.get(algorithm)
    if checksum_cls is None:
        return None
    return checksum_cls()


class FullObjectChecksumCombiner:
    """Combines per-part CRC checksums into a full object checksum.

    Each ranged download registers the CRC and length of its part; once
    all parts are registered, the CRCs are combined using awscrt's
    combine functions and optionally validated against the expected
    base64 value reported by S3.
    """

    def __init__(self, algorithm, num_parts, expected_b64=None):
        """
        :param algorithm: CRC algorithm name; must be a key of
            ``_CRC_COMBINE_FUNCTIONS`` ('crc32', 'crc32c', 'crc64nvme').
        :param num_parts: Total number of parts in the ranged download.
        :param expected_b64: Base64-encoded full object checksum to
            validate against in :meth:`combine_and_validate`.
        """
        self._algorithm = algorithm
        self._expected_b64 = expected_b64
        self._num_parts = num_parts
        # Fails fast (KeyError) for unsupported algorithms.
        self._combine_info = _CRC_COMBINE_FUNCTIONS[algorithm]
        # part_index -> PartChecksum; filled in by register_part().
        self._parts = {}
        # Cached combined digest bytes, computed lazily once.
        self._combined_bytes = None

    @property
    def algorithm(self):
        """The CRC algorithm name used for this object."""
        return self._algorithm

    def register_part(self, part_index, checksum, data_length):
        """Record the checksum and byte length of one downloaded part.

        :param part_index: Zero-based index of the part.
        :param checksum: Checksum object exposing ``digest()`` returning
            big-endian CRC bytes.
        :param data_length: Number of bytes covered by the part.
        """
        crc_int = int.from_bytes(checksum.digest(), byteorder='big')
        self._parts[part_index] = PartChecksum(crc_int, data_length)

    def combine_and_validate(self):
        """Combine all part checksums and compare to the expected value.

        :raises S3DownloadChecksumError: If any part checksum is missing
            or the combined checksum does not match the expected one.
        """
        combined_bytes = self._get_combined_bytes()
        combined_b64 = base64.b64encode(combined_bytes).decode('ascii')
        expected_bytes = base64.b64decode(self._expected_b64)
        if combined_bytes != expected_bytes:
            raise S3DownloadChecksumError(
                f'Expected full object checksum '
                f'({self._algorithm}) {self._expected_b64} did not match '
                f'combined checksum: {combined_b64}'
            )
        logger.debug(
            'Full object %s checksum validated: %s',
            self._algorithm,
            combined_b64,
        )

    @property
    def combined_b64(self):
        """Base64-encoded combined checksum of all registered parts."""
        combined_bytes = self._get_combined_bytes()
        return base64.b64encode(combined_bytes).decode('ascii')

    def _get_combined_bytes(self):
        # Lazily compute and cache the combined CRC digest bytes.
        if self._combined_bytes is not None:
            return self._combined_bytes
        # Guard against combining with missing parts (e.g. a part
        # download failed before registering its checksum); without this
        # the index lookups below raise an opaque KeyError.
        if len(self._parts) != self._num_parts:
            raise S3DownloadChecksumError(
                f'Cannot combine full object checksum: only '
                f'{len(self._parts)} of {self._num_parts} part checksums '
                f'were registered'
            )
        crc = self._parts[0].crc_int
        for i in range(1, self._num_parts):
            part = self._parts[i]
            crc = self._combine_info.combine_fn(
                crc, part.crc_int, part.data_length
            )
        self._combined_bytes = crc.to_bytes(
            self._combine_info.byte_length, byteorder='big'
        )
        return self._combined_bytes
92 changes: 87 additions & 5 deletions awscli/s3transfer/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
import threading

from botocore.exceptions import ClientError
from botocore.httpchecksum import StreamingChecksumBody
from s3transfer.checksums import (
FullObjectChecksumCombiner,
create_checksum_for_algorithm,
)
from s3transfer.compat import seekable
from s3transfer.exceptions import (
RetriesExceededError,
Expand Down Expand Up @@ -477,11 +482,27 @@ def _submit_ranged_download_request(
# Get any associated tags for the get object task.
get_object_tag = download_output_manager.get_download_task_tag()

checksum_combiner = self._create_checksum_combiner(
client.meta.config,
transfer_future,
num_parts,
)

# Callback invoker to submit the final io task once all downloads
# are complete.
finalize_callback = self._get_final_io_task_submission_callback(
download_output_manager, io_executor
)
pre_finalize_callbacks = []
if checksum_combiner is not None:
pre_finalize_callbacks.append(
checksum_combiner.combine_and_validate
)
finalize_download_invoker = CountCallbackInvoker(
self._get_final_io_task_submission_callback(
download_output_manager, io_executor
FunctionContainer(
self._finalize_download,
pre_finalize_callbacks,
finalize_callback,
)
)
for i in range(num_parts):
Expand Down Expand Up @@ -512,16 +533,28 @@ def _submit_ranged_download_request(
'callbacks': progress_callbacks,
'max_attempts': config.num_download_attempts,
'start_index': i * part_size,
'part_index': i,
'download_output_manager': download_output_manager,
'io_chunksize': config.io_chunksize,
'bandwidth_limiter': bandwidth_limiter,
'checksum_combiner': checksum_combiner,
},
done_callbacks=[finalize_download_invoker.decrement],
),
tag=get_object_tag,
)
finalize_download_invoker.finalize()

def _finalize_download(self, pre_finalize_callbacks, finalize_callback):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like when a part download fails or the transfer is cancelld, the counter still hits zero and _finalize_download runs. If some parts don't have their checksums registered with the combiner, combine_and_validate hits a KeyError on the missing part. Should we check self._transfer_coordinator.exception at the top and bail out early if the transfer already failed?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a secondary guard, _get_combined_bytes could validate that len(self._parts) == self._num_parts before iterating.

for callback in pre_finalize_callbacks:
try:
callback()
except Exception as e:
self._transfer_coordinator.set_exception(e)
self._transfer_coordinator.announce_done()
return
finalize_callback()

def _get_final_io_task_submission_callback(
self, download_manager, io_executor
):
Expand All @@ -530,6 +563,25 @@ def _get_final_io_task_submission_callback(
self._transfer_coordinator.submit, io_executor, final_task
)

def _create_checksum_combiner(
self, client_config, transfer_future, num_parts
):
checksum_info = transfer_future.meta.full_object_checksum
if checksum_info is None:
return None
extra_args = transfer_future.meta.call_args.extra_args
auto_enabled = (
client_config.response_checksum_validation == 'when_supported'
)
explicitly_enabled = extra_args.get('ChecksumMode') == 'ENABLED'
if not auto_enabled and not explicitly_enabled:
return None
return FullObjectChecksumCombiner(
algorithm=checksum_info.algorithm,
num_parts=num_parts,
expected_b64=checksum_info.expected_b64,
)

def _calculate_range_param(self, part_size, part_index, num_parts):
# Used to calculate the Range parameter
start_range = part_index * part_size
Expand All @@ -554,7 +606,9 @@ def _main(
download_output_manager,
io_chunksize,
start_index=0,
part_index=0,
bandwidth_limiter=None,
checksum_combiner=None,
):
"""Downloads an object and places content into io queue

Expand All @@ -571,8 +625,11 @@ def _main(
download stream and queue in the io queue.
:param start_index: The location in the file to start writing the
content of the key to.
:param part_index: The part number for this ranged download.
:param bandwidth_limiter: The bandwidth limiter to use when throttling
the downloading of data in streams.
:param checksum_combiner: Optional FullObjectChecksumCombiner for
full object checksum validation on multipart downloads.
"""
last_exception = None
for i in range(max_attempts):
Expand All @@ -585,22 +642,40 @@ def _main(
extra_args.get('Range'),
response.get('ContentRange'),
)
streaming_body = StreamReaderProgress(
response['Body'], callbacks
)
# When doing full object checksum combining and botocore
# hasn't already wrapped the body with a checksum
# calculator, wrap it in StreamingChecksumBody ourselves
# so the CRC is computed as data is read. We pass
# expected=None since we validate at the full object
# level, not per-part.
body = response['Body']
if checksum_combiner is not None and not hasattr(
body, 'checksum'
):
body = StreamingChecksumBody(
body,
response.get('ContentLength'),
create_checksum_for_algorithm(
checksum_combiner.algorithm
),
expected=None,
)
streaming_body = StreamReaderProgress(body, callbacks)
if bandwidth_limiter:
streaming_body = (
bandwidth_limiter.get_bandwith_limited_stream(
streaming_body, self._transfer_coordinator
)
)

part_length = 0
chunks = DownloadChunkIterator(streaming_body, io_chunksize)
for chunk in chunks:
# If the transfer is done because of a cancellation
# or error somewhere else, stop trying to submit more
# data to be written and break out of the download.
if not self._transfer_coordinator.done():
part_length += len(chunk)
self._handle_io(
download_output_manager,
fileobj,
Expand All @@ -610,6 +685,13 @@ def _main(
current_index += len(chunk)
else:
return

if checksum_combiner is not None:
checksum_combiner.register_part(
part_index,
body.checksum,
part_length,
)
return
except ClientError as e:
error_code = e.response.get('Error', {}).get('Code')
Expand Down
4 changes: 4 additions & 0 deletions awscli/s3transfer/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,7 @@ class FatalError(CancelledError):

class S3ValidationError(Exception):
    # NOTE(review): no raise site is visible in this view; by its name it
    # signals invalid transfer input/state -- confirm against call sites.
    pass


class S3DownloadChecksumError(Exception):
    """Raised when a full object checksum computed during a multipart
    download does not match the checksum reported by S3."""

    pass
Loading
Loading