From 8b12b60d1d2e81a74a8635b5323d84a4fcbe16de Mon Sep 17 00:00:00 2001 From: Thinhinane Ihadadene Date: Mon, 6 Oct 2025 11:33:34 +0100 Subject: [PATCH] Handle trailer headers and unsigned aws chunked payloads --- .../aws/proxy/spi/rest/RequestContent.java | 10 + .../server/rest/AwsChunkedInputStream.java | 95 ++++-- .../aws/proxy/server/rest/RequestBuilder.java | 9 +- .../server/rest/RequestHeadersBuilder.java | 36 ++- .../proxy/server/rest/TrinoS3ProxyClient.java | 12 +- .../proxy/server/TestGenericRestRequests.java | 2 +- .../aws/proxy/server/TestHttpChunked.java | 240 ++++++++++++-- .../rest/TestAwsChunkedInputStream.java | 304 ++++++++++++++++-- .../rest/TestRequestHeadersBuilder.java | 61 +++- .../signing/TestingChunkSigningSession.java | 15 +- 10 files changed, 688 insertions(+), 96 deletions(-) diff --git a/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/rest/RequestContent.java b/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/rest/RequestContent.java index 597e58d6..b9c35233 100644 --- a/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/rest/RequestContent.java +++ b/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/rest/RequestContent.java @@ -13,7 +13,10 @@ */ package io.trino.aws.proxy.spi.rest; +import com.google.common.collect.ImmutableList; + import java.io.InputStream; +import java.util.List; import java.util.Optional; @FunctionalInterface @@ -27,7 +30,9 @@ enum ContentType STANDARD, W3C_CHUNKED, AWS_CHUNKED, + AWS_CHUNKED_UNSIGNED, AWS_CHUNKED_IN_W3C_CHUNKED, + AWS_CHUNKED_IN_W3C_CHUNKED_UNSIGNED, } default ContentType contentType() @@ -48,5 +53,10 @@ default Optional contentLength() return Optional.empty(); } + default List trailerHeaders() + { + return ImmutableList.of(); + } + Optional inputStream(); } diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/AwsChunkedInputStream.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/AwsChunkedInputStream.java index 160df380..9b2af098 100644 --- a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/AwsChunkedInputStream.java +++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/AwsChunkedInputStream.java @@ -14,11 +14,13 @@ package io.trino.aws.proxy.server.rest; import com.google.common.base.Splitter; +import com.google.common.collect.ImmutableList; import io.trino.aws.proxy.spi.signing.ChunkSigningSession; import jakarta.ws.rs.WebApplicationException; import java.io.IOException; import java.io.InputStream; +import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; import java.util.Optional; @@ -30,7 +32,7 @@ class AwsChunkedInputStream extends InputStream { private final InputStream delegate; - private final ChunkSigningSession chunkSigningSession; + private final Optional chunkSigningSession; private enum State { @@ -44,12 +46,14 @@ private enum State private int bytesRemainingInChunk; private int bytesAccountedFor; private final int decodedContentLength; + private final List trailerHeaders; - AwsChunkedInputStream(InputStream delegate, ChunkSigningSession chunkSigningSession, int decodedContentLength) + AwsChunkedInputStream(InputStream delegate, Optional chunkSigningSession, int decodedContentLength, List trailerHeaders) { this.delegate = requireNonNull(delegate, "delegate is null"); this.chunkSigningSession = requireNonNull(chunkSigningSession, "chunkSigningSession is null"); this.decodedContentLength = decodedContentLength; + this.trailerHeaders = requireNonNull(ImmutableList.copyOf(trailerHeaders), "trailerHeaders is null"); } @Override @@ -65,7 +69,7 @@ public int read() throw new WebApplicationException("Unexpected end of stream", BAD_REQUEST); } - chunkSigningSession.write((byte) (i & 0xff)); + chunkSigningSession.ifPresent(chunkSigningSession -> chunkSigningSession.write((byte) (i & 0xff))); updateBytesRemaining(1); return i; @@ -86,7 +90,7 @@ public int read(byte[] b, int off, int len) throw new WebApplicationException("Unexpected end of stream", BAD_REQUEST); } - chunkSigningSession.write(b, off, count); + chunkSigningSession.ifPresent(chunkSigningSession -> chunkSigningSession.write(b, off, count)); updateBytesRemaining(count); return count; @@ -155,9 +159,6 @@ private void nextChunk() boolean success = false; do { List parts = Splitter.on(';').trimResults().limit(2).splitToList(header); - if (parts.size() != 2) { - break; - } int chunkSize; try { @@ -170,23 +171,41 @@ private void nextChunk() break; } - Optional chunkSignature = Splitter.on(';').trimResults().withKeyValueSeparator('=').split(parts.get(1)) - .entrySet() - .stream() - .filter(entry -> entry.getKey().equalsIgnoreCase("chunk-signature")) - .map(Map.Entry::getValue) - .findFirst(); + if (chunkSigningSession.isPresent()) { + if (parts.size() != 2) { + break; + } + + Optional chunkSignature = Splitter.on(';').trimResults().withKeyValueSeparator('=').split(parts.get(1)) + .entrySet() + .stream() + .filter(entry -> entry.getKey().equalsIgnoreCase("chunk-signature")) + .map(Map.Entry::getValue) + .findFirst(); - if (chunkSignature.isEmpty()) { - break; + if (chunkSignature.isEmpty()) { + break; + } + + chunkSigningSession.get().startChunk(chunkSignature.get()); + } + else { + if (parts.size() != 1) { + break; + } } - chunkSigningSession.startChunk(chunkSignature.get()); bytesRemainingInChunk = chunkSize; if (chunkSize == 0) { - readEmptyLine(); - chunkSigningSession.complete(); + if (trailerHeaders.isEmpty()) { + readEmptyLine(); + chunkSigningSession.ifPresent(ChunkSigningSession::complete); + } + else { + readTrailingHeaders(); + readEmptyLine(); + } state = State.LAST_CHUNK; } bytesAccountedFor += chunkSize; @@ -236,4 +255,44 @@ private String readLine() return line.toString(); } + + private TrailerHeaderChunk readTrailingHeadersChunk() + throws IOException + { + Optional signature = Optional.empty(); + StringBuilder trailerHeadersChunkBuilder = new StringBuilder(); + for (int i = 0; i < this.trailerHeaders.size(); i++) { + String trailerHeaders = readLine(); + List trailerHeadersValues = Splitter.on(":").trimResults().limit(2).splitToList(trailerHeaders); + String trailerHeaderName = trailerHeadersValues.getFirst(); + if ((trailerHeadersValues.size() != 2) || !this.trailerHeaders.contains(trailerHeaderName)) { + throw new WebApplicationException("Trailer header is invalid: " + trailerHeaders, BAD_REQUEST); + } + if (trailerHeaderName.equals("x-amz-trailer-signature")) { + signature = Optional.of(trailerHeadersValues.getLast()); + break; + } + else { + trailerHeadersChunkBuilder.append(trailerHeaders); + } + } + return new TrailerHeaderChunk(trailerHeadersChunkBuilder.toString(), signature); + } + + private void readTrailingHeaders() + throws IOException + { + TrailerHeaderChunk trailerHeaderChunk = readTrailingHeadersChunk(); + chunkSigningSession.ifPresent(chunkSigningSession -> { + if (trailerHeaderChunk.signature.isEmpty()) { + throw new WebApplicationException("Expected x-amz-trailer-signature, none found", BAD_REQUEST); + } + chunkSigningSession.startChunk(trailerHeaderChunk.signature.get()); + byte[] trailerHeaderContent = trailerHeaderChunk.trailerHeaders.getBytes(StandardCharsets.UTF_8); + chunkSigningSession.write(trailerHeaderContent, 0, trailerHeaderContent.length); + chunkSigningSession.complete(); + }); + } + + private record TrailerHeaderChunk(String trailerHeaders, Optional signature) {} } diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/RequestBuilder.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/RequestBuilder.java index 61757d87..b54ad190 100644 --- a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/RequestBuilder.java +++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/RequestBuilder.java @@ -15,6 +15,7 @@ import com.google.common.base.Splitter; import com.google.common.base.Suppliers; +import com.google.common.collect.ImmutableList; import io.airlift.log.Logger; import io.trino.aws.proxy.server.rest.RequestHeadersBuilder.InternalRequestHeaders; import io.trino.aws.proxy.server.signing.SigningQueryParameters; @@ -152,7 +153,7 @@ private static RequestContent buildRequestContent(InputStream requestEntityStrea // AWS does not mandate x-amz-decoded-content length is required for chunked transfer encoding // But we require it for simplicity (Content-Length is needed since we don't do chunking on outbound requests) - case AWS_CHUNKED, W3C_CHUNKED, AWS_CHUNKED_IN_W3C_CHUNKED -> () -> { + case AWS_CHUNKED, AWS_CHUNKED_UNSIGNED, W3C_CHUNKED, AWS_CHUNKED_IN_W3C_CHUNKED, AWS_CHUNKED_IN_W3C_CHUNKED_UNSIGNED -> () -> { int contentLength = requestHeaders.decodedContentLength() .orElseThrow(() -> new WebApplicationException(BAD_REQUEST)); return Optional.of(contentLength); @@ -188,6 +189,12 @@ public Optional inputStream() .map(bytes -> (InputStream) new ByteArrayInputStream(bytes)) .or(() -> Optional.of(requestEntityStream)); } + + @Override + public List trailerHeaders() + { + return ImmutableList.copyOf(requestHeaders.requestHeaders().unmodifiedHeaders().get("x-amz-trailer")); + } }; } } diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/RequestHeadersBuilder.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/RequestHeadersBuilder.java index bf907921..be3f36c4 100644 --- a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/RequestHeadersBuilder.java +++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/RequestHeadersBuilder.java @@ -50,7 +50,8 @@ private RequestHeadersBuilder() {} "connection", "amz-sdk-invocation-id", "amx-sdk-request", - "host"); + "host", + "x-amz-trailer"); record InternalRequestHeaders( RequestHeaders requestHeaders, @@ -75,7 +76,9 @@ static InternalRequestHeaders parseHeaders(MultiMap allRequestHeaders) Builder builder = new Builder(); allRequestHeaders.forEach((headerName, headerValues) -> { switch (headerName) { - case "authorization", "x-amz-security-token" -> {} // these get handled separately + case "authorization", "x-amz-security-token" -> { + // these get handled separately + } case "content-length" -> builder.contentLength(headerValues); case "x-amz-decoded-content-length" -> builder.decodedContentLength(headerValues); case "content-encoding" -> builder.contentEncoding(headerValues); @@ -104,6 +107,13 @@ private static class Builder private Optional contentSha256 = Optional.empty(); private Set seenRequestPayloadContentTypes = new HashSet<>(); + private static final Set CHUNKED_CONTENT_TYPES = Set.of( + ContentType.AWS_CHUNKED, + ContentType.AWS_CHUNKED_IN_W3C_CHUNKED, + ContentType.W3C_CHUNKED, + ContentType.AWS_CHUNKED_UNSIGNED, + ContentType.AWS_CHUNKED_IN_W3C_CHUNKED_UNSIGNED); + private Builder() {} private static Optional parseHeaderValuesAsSingle(List allValues) @@ -193,14 +203,18 @@ private void addPassthroughHeader(String headerName, List headerValues) passthroughHeadersBuilder.addAll(headerName, headerValues); } + private String requiredContentSha256() + { + return contentSha256.orElseThrow(() -> new WebApplicationException(BAD_REQUEST)); + } + private void assertContentTypeValid(ContentType actualContentType) { - if (actualContentType == ContentType.AWS_CHUNKED || actualContentType == ContentType.AWS_CHUNKED_IN_W3C_CHUNKED || actualContentType == ContentType.W3C_CHUNKED) { + if (CHUNKED_CONTENT_TYPES.contains(actualContentType)) { if (decodedContentLength.isEmpty()) { throw new WebApplicationException(LENGTH_REQUIRED); } - String sha256 = contentSha256.orElseThrow(() -> new WebApplicationException(BAD_REQUEST)); - if (actualContentType != ContentType.W3C_CHUNKED && !sha256.startsWith("STREAMING-")) { + if (actualContentType != ContentType.W3C_CHUNKED && !requiredContentSha256().startsWith("STREAMING-")) { throw new WebApplicationException(BAD_REQUEST); } } @@ -209,11 +223,21 @@ private void assertContentTypeValid(ContentType actualContentType) private InternalRequestHeaders build(MultiMap allHeaders) { Optional applicableContentType = switch (seenRequestPayloadContentTypes.size()) { - case 0, 1 -> seenRequestPayloadContentTypes.stream().findFirst(); + case 0 -> Optional.empty(); + case 1 -> { + Optional contentType = seenRequestPayloadContentTypes.stream().findFirst(); + if (contentType.get().equals(ContentType.AWS_CHUNKED) && requiredContentSha256().startsWith("STREAMING-UNSIGNED-PAYLOAD")) { + yield Optional.of(ContentType.AWS_CHUNKED_UNSIGNED); + } + yield contentType; + } case 2 -> { if (!seenRequestPayloadContentTypes.containsAll(ImmutableSet.of(ContentType.AWS_CHUNKED, ContentType.W3C_CHUNKED))) { throw new WebApplicationException(BAD_REQUEST); } + if (requiredContentSha256().startsWith("STREAMING-UNSIGNED-PAYLOAD")) { + yield Optional.of(ContentType.AWS_CHUNKED_IN_W3C_CHUNKED_UNSIGNED); + } yield Optional.of(ContentType.AWS_CHUNKED_IN_W3C_CHUNKED); } default -> throw new WebApplicationException(BAD_REQUEST); diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/TrinoS3ProxyClient.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/TrinoS3ProxyClient.java index d74dde6a..e07d80fc 100644 --- a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/TrinoS3ProxyClient.java +++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/TrinoS3ProxyClient.java @@ -212,8 +212,16 @@ private Optional contentInputStream(RequestContent requestContent, { return switch (requestContent.contentType()) { case AWS_CHUNKED, AWS_CHUNKED_IN_W3C_CHUNKED -> requestContent.inputStream() - .map(inputStream -> new AwsChunkedInputStream(limitStreamController.wrap(inputStream), signingMetadata.requiredSigningContext().chunkSigningSession(), requestContent.contentLength().orElseThrow())); - + .map(inputStream -> new AwsChunkedInputStream(limitStreamController.wrap(inputStream), + Optional.of(signingMetadata.requiredSigningContext().chunkSigningSession()), + requestContent.contentLength().orElseThrow(), + requestContent.trailerHeaders())); + case AWS_CHUNKED_UNSIGNED, AWS_CHUNKED_IN_W3C_CHUNKED_UNSIGNED -> requestContent.inputStream() + .map(inputStream -> new AwsChunkedInputStream( + limitStreamController.wrap(inputStream), + Optional.empty(), + requestContent.contentLength().orElseThrow(), + requestContent.trailerHeaders())); case STANDARD, W3C_CHUNKED -> requestContent.inputStream().map(inputStream -> { SigningContext signingContext = signingMetadata.requiredSigningContext(); return signingContext.contentHash() diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestGenericRestRequests.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestGenericRestRequests.java index 417436d2..fbebfc89 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestGenericRestRequests.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestGenericRestRequests.java @@ -248,7 +248,7 @@ private StatusResponse doAwsChunkedUpload( URI requestUri = UriBuilder.fromUri(baseUri).path(bucket).path(key).build(); RequestAuthorization requestAuthorization = signRequest(requestSigningCredential, requestUri, requestDate, "PUT", requestHeaderBuilder.build()); - String chunkedContent = chunkedPayloadMutator.apply(TestingChunkSigningSession.build(chunkSigningCredential, requestAuthorization.signature(), requestDate).generateChunkedStream(contentToUpload, partitionCount)); + String chunkedContent = chunkedPayloadMutator.apply(TestingChunkSigningSession.build(chunkSigningCredential, requestAuthorization.signature(), requestDate).generateChunkedStream(contentToUpload, partitionCount, Optional.empty())); Request.Builder requestBuilder = preparePut().setUri(requestUri); requestHeaderBuilder.add("Authorization", requestAuthorization.authorization()); diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestHttpChunked.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestHttpChunked.java index f295c440..f9d6a880 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestHttpChunked.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestHttpChunked.java @@ -13,6 +13,7 @@ */ package io.trino.aws.proxy.server; +import com.google.common.collect.ImmutableMap; import com.google.inject.Inject; import io.airlift.http.client.HttpClient; import io.airlift.http.client.Request; @@ -49,6 +50,8 @@ import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.LinkedList; +import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Queue; import java.util.UUID; @@ -76,7 +79,6 @@ public class TestHttpChunked private final HttpClient httpClient; private final IdentityCredential testingCredentials; private final S3Client storageClient; - private static final String TEST_CONTENT_TYPE = "text/plain;charset=utf-8"; private static final Credential VALID_CREDENTIAL = new Credential(UUID.randomUUID().toString(), UUID.randomUUID().toString()); @@ -164,6 +166,8 @@ public void testHttpChunked() { String bucket = "test-http-chunked"; String bucketTwo = "test-http-chunked-two"; + String bucketThree = "test-http-chunked-three"; + storageClient.createBucket(r -> r.bucket(bucket).build()); testHttpChunked(bucket, LOREM_IPSUM, "UNSIGNED-PAYLOAD", 1); testHttpChunked(bucket, LOREM_IPSUM, "UNSIGNED-PAYLOAD", 3); @@ -173,39 +177,88 @@ public void testHttpChunked() testHttpChunked(bucketTwo, LOREM_IPSUM, sha256(LOREM_IPSUM), 1); testHttpChunked(bucketTwo, LOREM_IPSUM, sha256(LOREM_IPSUM), 3); testHttpChunked(bucketTwo, LOREM_IPSUM, sha256(LOREM_IPSUM), 5); + + storageClient.createBucket(r -> r.bucket(bucketThree).build()); + testHttpChunked(bucketThree, LOREM_IPSUM, "STREAMING-UNSIGNED-PAYLOAD-TRAILER", 1); + testHttpChunked(bucketThree, LOREM_IPSUM, "STREAMING-UNSIGNED-PAYLOAD-TRAILER", 3); + testHttpChunked(bucketThree, LOREM_IPSUM, "STREAMING-UNSIGNED-PAYLOAD-TRAILER", 5); } - private void testHttpChunked(String bucket, String content, String sha256, int partitionCount) + @Test + public void testHttpChunkedWithAwsChunkedEncodingUnsignedPayload() throws IOException { - assertThat(doHttpChunkedUpload( + Map trailerHeaders = ImmutableMap.of("x-amz-checksum-crc32", "4ODU/w=="); + testHttpChunkedWithAwsChunkedEncodingUnsignedPayloadSuccess(ImmutableMultiMap.empty(), Optional.of(trailerHeaders), Optional.empty()); + } + + @Test + public void testHttpChunkedWithAwsChunkedEncodingUnsignedPayloadMultipleEncodings() + throws IOException + { + Map trailerHeaders = ImmutableMap.of("x-amz-checksum-crc32", "4ODU/w=="); + ImmutableMultiMap.Builder extraHeaders = ImmutableMultiMap.builder(false) + .add("Content-Encoding", "aws-chunked").add("Content-Encoding", "gzip,compress"); + testHttpChunkedWithAwsChunkedEncodingUnsignedPayloadSuccess(extraHeaders.build(), Optional.of(trailerHeaders), Optional.of("gzip,compress")); + } + + @Test + public void testHttpChunkedWithAwsChunkedEncodingUnsignedPayloadNoTrailerHeaders() + throws IOException + { + ImmutableMultiMap.Builder extraHeaders = ImmutableMultiMap.builder(false).add("Content-Encoding", "gzip,compress"); + testHttpChunkedWithAwsChunkedEncodingUnsignedPayloadSuccess(extraHeaders.build(), Optional.empty(), Optional.of("gzip,compress")); + } + + @Test + public void testHttpChunkedWithAwsChunkedEncodingUnsignedPayloadWithMultipleTrailerHeaders() + throws IOException + { + Map trailerHeaders = ImmutableMap.of("x-amz-checksum-something-else", "4ODU/w==", + "x-amz-checksum-something-else2", "10C"); + testHttpChunkedWithAwsChunkedEncodingUnsignedPayloadSuccess(ImmutableMultiMap.empty(), Optional.of(trailerHeaders), Optional.empty()); + } + + @Test + public void testHttpChunkedWithAwsChunkedEncodingUnsignedPayloadWithInvalidTrailerHeadersPassed() + { + String bucket = "test-http-chunked-aws-chunked-header"; + storageClient.createBucket(r -> r.bucket(bucket).build()); + ImmutableMultiMap.Builder headersBuilder = ImmutableMultiMap.builder(false) + .add("X-Amz-Content-Sha256", "STREAMING-UNSIGNED-PAYLOAD-TRAILER") + .add("Content-Encoding", "aws-chunked") + .add("x-amz-trailer", "x-different-trailer"); + + assertThat(doCustomHttpChunkedUpload( bucket, "basic-upload", - content, - partitionCount, - ImmutableMultiMap.builder(false) - .add("X-Amz-Content-Sha256", sha256).build())).isEqualTo(200); - assertThat(getFileFromStorage(storageClient, bucket, "basic-upload")).isEqualTo(content); - HeadObjectResponse basicUpload = headObjectInStorage(storageClient, bucket, "basic-upload"); - assertThat(basicUpload.contentEncoding()).isNullOrEmpty(); - assertThat(basicUpload.metadata()).isEmpty(); + 3, + headersBuilder.build(), + LOREM_IPSUM.length(), + signature -> generateUnsignedChunkedStream(LOREM_IPSUM, 3, + Optional.of(buildTrailerChunk(ImmutableMap.of("x-amz-unknown", "bla")))) + )).isEqualTo(400); + } - assertThat(doHttpChunkedUpload( + @Test + public void testHttpChunkedWithAwsChunkedEncodingUnsignedPayloadRaisesWhenNoValueForHeader() + { + String bucket = "test-http-chunked-aws-chunked-header"; + storageClient.createBucket(r -> r.bucket(bucket).build()); + ImmutableMultiMap.Builder headersBuilder = ImmutableMultiMap.builder(false) + .add("X-Amz-Content-Sha256", "STREAMING-UNSIGNED-PAYLOAD-TRAILER") + .add("Content-Encoding", "aws-chunked") + .add("x-amz-trailer", "x-some-trailer"); + + assertThat(doCustomHttpChunkedUpload( bucket, - "with-content-type", - content, - partitionCount, - ImmutableMultiMap.builder(false) - .add("X-Amz-Content-Sha256", sha256) - .add("Content-Type", TEST_CONTENT_TYPE) - .add("Content-Encoding", "gzip,compress") - .add("x-amz-meta-foobar", "baz") - .build())).isEqualTo(200); - assertThat(getFileFromStorage(storageClient, bucket, "with-content-type")).isEqualTo(content); - HeadObjectResponse withFields = headObjectInStorage(storageClient, bucket, "with-content-type"); - assertThat(withFields.contentType()).contains(TEST_CONTENT_TYPE); - assertThat(withFields.contentEncoding()).isEqualTo("gzip,compress"); - assertThat(withFields.metadata()).containsEntry("foobar", "baz"); + "basic-upload", + 3, + headersBuilder.build(), + LOREM_IPSUM.length(), + signature -> generateUnsignedChunkedStream(LOREM_IPSUM, 3, + Optional.empty()) + )).isEqualTo(400); } @Test @@ -230,14 +283,18 @@ public void testHttpChunkedContainingAwsChunkedPayload() ImmutableMultiMap.Builder requestHeadersBuilder = ImmutableMultiMap.builder(false) .add("X-Amz-Content-Sha256", "STREAMING-AWS4-HMAC-SHA256-PAYLOAD") .add("Content-Encoding", "aws-chunked"); - assertThat(doCustomHttpChunkedUpload(bucket, "test-upload", 3, requestHeadersBuilder.build(), LOREM_IPSUM.length(), signature -> TestingChunkSigningSession.build(VALID_CREDENTIAL, signature).generateChunkedStream(LOREM_IPSUM, 3))).isEqualTo(200); + assertThat( + doCustomHttpChunkedUpload(bucket, "test-upload", 3, + requestHeadersBuilder.build(), LOREM_IPSUM.length(), + signature -> TestingChunkSigningSession.build(VALID_CREDENTIAL, signature).generateChunkedStream(LOREM_IPSUM, 3, Optional.empty()))) + .isEqualTo(200); assertThat(getFileFromStorage(storageClient, bucket, "test-upload")).isEqualTo(LOREM_IPSUM); requestHeadersBuilder .add("Content-Type", TEST_CONTENT_TYPE) .add("Content-Encoding", "gzip,compress") .add("x-amz-meta-foobar", "baz"); - assertThat(doCustomHttpChunkedUpload(bucket, "test-upload-with-metadata", 3, requestHeadersBuilder.build(), LOREM_IPSUM.length(), signature -> TestingChunkSigningSession.build(VALID_CREDENTIAL, signature).generateChunkedStream(LOREM_IPSUM, 3))).isEqualTo(200); + assertThat(doCustomHttpChunkedUpload(bucket, "test-upload-with-metadata", 3, requestHeadersBuilder.build(), LOREM_IPSUM.length(), signature -> TestingChunkSigningSession.build(VALID_CREDENTIAL, signature).generateChunkedStream(LOREM_IPSUM, 3, Optional.empty()))).isEqualTo(200); assertThat(getFileFromStorage(storageClient, bucket, "test-upload-with-metadata")).isEqualTo(LOREM_IPSUM); HeadObjectResponse objectMetadata = headObjectInStorage(storageClient, bucket, "test-upload-with-metadata"); assertThat(objectMetadata.contentType()).contains(TEST_CONTENT_TYPE); @@ -256,10 +313,135 @@ public void testHttpChunkedContainingAwsChunkedPayloadValidatesChunkSignatures() .add("Content-Encoding", "aws-chunked"); assertThat(doCustomHttpChunkedUpload( bucket, "test-upload", 3, requestHeadersBuilder.build(), LOREM_IPSUM.length(), - signature -> TestingChunkSigningSession.build(new Credential(UUID.randomUUID().toString(), UUID.randomUUID().toString()), signature).generateChunkedStream(LOREM_IPSUM, 3))).isEqualTo(401); + signature -> TestingChunkSigningSession.build(new Credential(UUID.randomUUID().toString(), UUID.randomUUID().toString()), signature).generateChunkedStream(LOREM_IPSUM, 3, Optional.empty()))).isEqualTo(401); assertFileNotInS3(storageClient, bucket, "test-upload"); } + @Test + public void testHttpChunkedContainingAwsChunkedPayloadWithTrailerHeaders() + throws IOException + { + String bucket = "http-chunked-aws-chunked"; + storageClient.createBucket(r -> r.bucket(bucket).build()); + Map trailerHeaders = ImmutableMap.of("x-some-trailer", "foobarval"); + + ImmutableMultiMap.Builder requestHeadersBuilder = ImmutableMultiMap.builder(false) + .add("X-Amz-Content-Sha256", "STREAMING-AWS4-HMAC-SHA256-PAYLOAD") + .add("Content-Encoding", "aws-chunked"); + requestHeadersBuilder.addAll("x-amz-trailer", buildTrailerHeaderValues(trailerHeaders)); + requestHeadersBuilder.add("x-amz-trailer", "x-amz-trailer-signature"); + + assertThat( + doCustomHttpChunkedUpload(bucket, "test-upload", 3, + requestHeadersBuilder.build(), LOREM_IPSUM.length(), + signature -> TestingChunkSigningSession.build(VALID_CREDENTIAL, signature).generateChunkedStream(LOREM_IPSUM, 3, Optional.of(buildTrailerChunk(trailerHeaders))))) + .isEqualTo(200); + assertThat(getFileFromStorage(storageClient, bucket, "test-upload")).isEqualTo(LOREM_IPSUM); + } + + private String buildTrailerChunk(Map trailerHeaders) + { + StringBuilder chunk = new StringBuilder(); + for (Map.Entry entry : trailerHeaders.entrySet()) { + chunk.append(entry.getKey()).append(":").append(entry.getValue()).append("\r\n"); + } + return chunk.toString(); + } + + private List buildTrailerHeaderValues(Map trailerHeaders) + { + return trailerHeaders.keySet().stream().toList(); + } + + public String generateUnsignedChunkedStream(String content, int partitions, Optional trailerHeaders) + { + checkArgument(partitions > 1, "partitions must be greater than 1"); + StringBuilder chunkedStream = new StringBuilder(); + int chunkSize = Math.ceilDiv(content.length(), partitions); + int index = 0; + while (index < content.length()) { + int thisLength = Math.min(chunkSize, content.length() - index); + String thisChunk = content.substring(index, index + thisLength); + chunkedStream.append(Integer.toHexString(thisLength)).append("\r\n"); + chunkedStream.append(thisChunk).append("\r\n"); + index += thisLength; + } + chunkedStream.append("0").append("\r\n"); + + // trailer headers and end stream + trailerHeaders.ifPresent(chunkedStream::append); + + // Mark end of entire streaming + chunkedStream.append("\r\n"); + return chunkedStream.toString(); + } + + private void testHttpChunkedWithAwsChunkedEncodingUnsignedPayloadSuccess(MultiMap extraHeaders, Optional> trailerHeaders, Optional expectedContent) + throws IOException + { + String bucket = "test-http-chunked-aws-chunked-header"; + testHttpChunkedWithAwsChunkedEncodingUnsignedPayload(bucket, extraHeaders, trailerHeaders, 200); + assertThat(getFileFromStorage(storageClient, bucket, "basic-upload")).isEqualTo(LOREM_IPSUM); + if (expectedContent.isPresent()) { + assertThat(headObjectInStorage(storageClient, bucket, "basic-upload").contentEncoding()).isEqualTo(expectedContent.get()); + } + else { + assertThat(headObjectInStorage(storageClient, bucket, "basic-upload").contentEncoding()).isNullOrEmpty(); + } + } + + private void testHttpChunkedWithAwsChunkedEncodingUnsignedPayload(String bucket, MultiMap extraHeaders, Optional> trailerHeaders, int expectedReturnCode) + { + storageClient.createBucket(r -> r.bucket(bucket).build()); + ImmutableMultiMap.Builder headersBuilder = ImmutableMultiMap.builder(false) + .add("X-Amz-Content-Sha256", "STREAMING-UNSIGNED-PAYLOAD-TRAILER") + .add("Content-Encoding", "aws-chunked"); + extraHeaders.forEach(headersBuilder::addAll); + trailerHeaders.ifPresent(e -> headersBuilder.addAll("x-amz-trailer", buildTrailerHeaderValues(e))); + + assertThat(doCustomHttpChunkedUpload( + bucket, + "basic-upload", + 3, + headersBuilder.build(), + LOREM_IPSUM.length(), + signature -> generateUnsignedChunkedStream(LOREM_IPSUM, 3, trailerHeaders.map(this::buildTrailerChunk))) + ).isEqualTo(expectedReturnCode); + } + + private void testHttpChunked(String bucket, String content, String sha256, int partitionCount) + throws IOException + { + assertThat(doHttpChunkedUpload( + bucket, + "basic-upload", + content, + partitionCount, + ImmutableMultiMap.builder(false) + .add("X-Amz-Content-Sha256", sha256).build())).isEqualTo(200); + assertThat(getFileFromStorage(storageClient, bucket, "basic-upload")).isEqualTo(content); + HeadObjectResponse basicUpload = headObjectInStorage(storageClient, bucket, "basic-upload"); + assertThat(basicUpload.contentEncoding()).isNullOrEmpty(); + assertThat(basicUpload.metadata()).isEmpty(); + + assertThat(doHttpChunkedUpload( + bucket, + "with-content-type", + content, + partitionCount, + ImmutableMultiMap.builder(false) + .add("X-Amz-Content-Sha256", sha256) + .add("Content-Type", TEST_CONTENT_TYPE) + .add("Content-Encoding", "gzip,compress") + .add("x-amz-meta-foobar", "baz") + .build())).isEqualTo(200); + assertThat(getFileFromStorage(storageClient, bucket, "with-content-type")).isEqualTo(content); + HeadObjectResponse withFields = headObjectInStorage(storageClient, bucket, "with-content-type"); + assertThat(withFields.contentType()).contains(TEST_CONTENT_TYPE); + assertThat(withFields.contentEncoding()).isEqualTo("gzip,compress"); + assertThat(withFields.metadata()).containsEntry("foobar", "baz"); + } + private int doHttpChunkedUpload(String bucket, String key, String content, int chunkCount, MultiMap extraHeaders) { return doCustomHttpChunkedUpload(bucket, key, chunkCount, extraHeaders, content.length(), _ -> content); diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/TestAwsChunkedInputStream.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/TestAwsChunkedInputStream.java index 2115bac1..88831783 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/TestAwsChunkedInputStream.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/TestAwsChunkedInputStream.java @@ -29,6 +29,8 @@ import java.io.IOException; import java.io.InputStream; import java.io.UncheckedIOException; +import java.util.List; +import java.util.Optional; import java.util.stream.Stream; import static java.nio.charset.StandardCharsets.UTF_8; @@ -68,35 +70,51 @@ public void testGood() throws IOException { TestingChunkSigningSession session = goodTestSigningSession(); - String chunkedStream = session.generateChunkedStream(GOOD_CONTENT, 3); + String chunkedStream = session.generateChunkedStream(GOOD_CONTENT, 3, Optional.empty()); - assertThat(readChunked(chunkedStream, session)).isEqualTo(GOOD_CONTENT.getBytes(UTF_8)); + assertThat(readChunked(chunkedStream, session, ImmutableList.of())).isEqualTo(GOOD_CONTENT.getBytes(UTF_8)); } @Test public void testRecreateSessionValidatesGoodPayload() throws IOException { - String chunkedStream = goodTestSigningSession().generateChunkedStream(GOOD_CONTENT, 3); + String chunkedStream = goodTestSigningSession().generateChunkedStream(GOOD_CONTENT, 3, Optional.empty()); - assertThat(readChunked(chunkedStream, goodTestSigningSession())).isEqualTo(GOOD_CONTENT.getBytes(UTF_8)); + assertThat(readChunked(chunkedStream, goodTestSigningSession(), ImmutableList.of())).isEqualTo(GOOD_CONTENT.getBytes(UTF_8)); + } + + @Test + public void testRecreateSessionValidatesGoodPayloadWithTrailerHeader() + throws IOException + { + String chunkedStream = goodTestSigningSession().generateChunkedStream(GOOD_CONTENT, 3, Optional.of("x-amz-header1:foo\r\n")); + assertThat(readChunked(chunkedStream, goodTestSigningSession(), ImmutableList.of("x-amz-header1", "x-amz-trailer-signature"))).isEqualTo(GOOD_CONTENT.getBytes(UTF_8)); + } + + @Test + public void testRecreateSessionValidatesGoodPayloadWithTrailerHeaders() + throws IOException + { + String chunkedStream = goodTestSigningSession().generateChunkedStream(GOOD_CONTENT, 3, Optional.of("x-amz-header2:foo\r\nx-amz-header1:foo\r\n")); + assertThat(readChunked(chunkedStream, goodTestSigningSession(), ImmutableList.of("x-amz-header1", "x-amz-header2", "x-amz-trailer-signature"))).isEqualTo(GOOD_CONTENT.getBytes(UTF_8)); } @Test public void testBadSeed() { - String chunkedStream = goodTestSigningSession().generateChunkedStream(GOOD_CONTENT, 3); + String chunkedStream = goodTestSigningSession().generateChunkedStream(GOOD_CONTENT, 3, Optional.empty()); - assertThatThrownBy(() -> readChunked(chunkedStream, TestingChunkSigningSession.build(GOOD_CREDENTIAL, BAD_SEED))) + assertThatThrownBy(() -> readChunked(chunkedStream, TestingChunkSigningSession.build(GOOD_CREDENTIAL, BAD_SEED), ImmutableList.of())) .isInstanceOf(WebApplicationException.class); } @Test public void testBadCredential() { - String chunkedStream = goodTestSigningSession().generateChunkedStream(GOOD_CONTENT, 3); + String chunkedStream = goodTestSigningSession().generateChunkedStream(GOOD_CONTENT, 3, Optional.empty()); - assertThatThrownBy(() -> readChunked(chunkedStream, TestingChunkSigningSession.build(BAD_CREDENTIAL, BAD_SEED))) + assertThatThrownBy(() -> readChunked(chunkedStream, TestingChunkSigningSession.build(BAD_CREDENTIAL, BAD_SEED), ImmutableList.of())) .isInstanceOf(WebApplicationException.class); } @@ -104,10 +122,10 @@ public void testBadCredential() public void testMultipleExtensions() throws IOException { - String chunkedStream = goodTestSigningSession().generateChunkedStream(GOOD_CONTENT, 3); + String chunkedStream = goodTestSigningSession().generateChunkedStream(GOOD_CONTENT, 3, Optional.empty()); chunkedStream = chunkedStream.replace(";chunk-signature=", ";foo=bar;chunk-signature="); - assertThat(readChunked(chunkedStream, goodTestSigningSession())).isEqualTo(GOOD_CONTENT.getBytes(UTF_8)); + assertThat(readChunked(chunkedStream, goodTestSigningSession(), ImmutableList.of())).isEqualTo(GOOD_CONTENT.getBytes(UTF_8)); } @Test @@ -251,7 +269,7 @@ private static void tryReadAwsChunkedDataBatch(String chunkedData, int decodedCo throws IOException { int remainingBytes = decodedContentLength; - try (InputStream in = new AwsChunkedInputStream(new ByteArrayInputStream(chunkedData.getBytes(UTF_8)), signingSession, decodedContentLength)) { + try (InputStream in = new AwsChunkedInputStream(new ByteArrayInputStream(chunkedData.getBytes(UTF_8)), Optional.of(signingSession), decodedContentLength, ImmutableList.of())) { while (remainingBytes > 0) { byte[] readBytes = new byte[bytesToReadAtATime]; int count = in.read(readBytes, 0, bytesToReadAtATime); @@ -268,7 +286,7 @@ private static void tryReadAwsChunkedData(String chunkedData, int decodedContent throws IOException { int remainingBytes = decodedContentLength; - try (InputStream in = new AwsChunkedInputStream(new ByteArrayInputStream(chunkedData.getBytes(UTF_8)), signingSession, decodedContentLength)) { + try (InputStream in = new AwsChunkedInputStream(new ByteArrayInputStream(chunkedData.getBytes(UTF_8)), Optional.of(signingSession), decodedContentLength, ImmutableList.of())) { while (remainingBytes-- > 0) { int readByte = in.read(); if (readByte == -1) { @@ -286,10 +304,10 @@ private static void testIllegalAwsChunkedData(String chunkedData, int decodedCon assertThat(testOutput.toByteArray().length).isLessThan(decodedContentLength); } - private static byte[] readChunked(String chunkedStream, TestingChunkSigningSession signingSession) + private static byte[] readChunked(String chunkedStream, TestingChunkSigningSession signingSession, List trailerHeaders) throws IOException { - try (InputStream in = new AwsChunkedInputStream(new ByteArrayInputStream(chunkedStream.getBytes(UTF_8)), signingSession, chunkedStream.length())) { + try (InputStream in = new AwsChunkedInputStream(new ByteArrayInputStream(chunkedStream.getBytes(UTF_8)), Optional.of(signingSession), chunkedStream.length(), ImmutableList.copyOf(trailerHeaders))) { return ByteStreams.toByteArray(in); } } @@ -308,7 +326,7 @@ public void testChunkedInputStreamLargeBuffer() { byte[] rawBytes = CHUNKED_INPUT.getBytes(UTF_8); ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); - InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length); + InputStream in = new AwsChunkedInputStream(inputStream, Optional.of(new DummyChunkSigningSession()), rawBytes.length, ImmutableList.of()); byte[] buffer = new byte[300]; ByteArrayOutputStream out = new ByteArrayOutputStream(); int len; @@ -331,7 +349,7 @@ public void testChunkedInputStreamSmallBuffer() { byte[] rawBytes = CHUNKED_INPUT.getBytes(UTF_8); ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); - InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length); + InputStream in = new AwsChunkedInputStream(inputStream, Optional.of(new DummyChunkSigningSession()), rawBytes.length, ImmutableList.of()); byte[] buffer = new byte[7]; ByteArrayOutputStream out = new ByteArrayOutputStream(); @@ -354,7 +372,7 @@ public void testChunkedInputStreamOneByteRead() String s = "5;chunk-signature=0\r\n01234\r\n5;chunk-signature=0\r\n56789\r\n0;chunk-signature=0\r\n\r\n"; byte[] rawBytes = s.getBytes(UTF_8); ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); - InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length); + InputStream in = new AwsChunkedInputStream(inputStream, Optional.of(new DummyChunkSigningSession()), rawBytes.length, ImmutableList.of()); int ch; int i = '0'; while ((ch = in.read()) != -1) { @@ -374,7 +392,7 @@ public void testChunkedInputStreamNoClosingChunk() String s = "5;chunk-signature=0\r\n01234\r\n"; byte[] rawBytes = s.getBytes(UTF_8); ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); - InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length); + InputStream in = new AwsChunkedInputStream(inputStream, Optional.of(new DummyChunkSigningSession()), rawBytes.length, ImmutableList.of()); byte[] tmp = new byte[5]; // altered from original test. Our AwsChunkedInputStream is improved and throws when the final chunk is missing or bad assertThrows(WebApplicationException.class, () -> in.read(tmp)); @@ -390,7 +408,7 @@ public void testCorruptChunkedInputStreamTruncatedCRLF() .forEach(s -> { byte[] rawBytes = s.getBytes(UTF_8); ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); - InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length); + InputStream in = new AwsChunkedInputStream(inputStream, Optional.of(new DummyChunkSigningSession()), rawBytes.length, ImmutableList.of()); byte[] tmp = new byte[5]; // altered from original test. Our AwsChunkedInputStream is improved and throws when the final chunk is missing or bad assertThrows(WebApplicationException.class, () -> in.read(tmp)); @@ -410,7 +428,7 @@ public void testCorruptChunkedInputStreamMissingCRLF() String s = "5;chunk-signature=0\r\n012345\r\n56789\r\n0;chunk-signature=0\r\n\r\n"; byte[] rawBytes = s.getBytes(UTF_8); ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); - InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length); + InputStream in = new AwsChunkedInputStream(inputStream, Optional.of(new DummyChunkSigningSession()), rawBytes.length, ImmutableList.of()); byte[] buffer = new byte[300]; ByteArrayOutputStream out = new ByteArrayOutputStream(); assertThrows(WebApplicationException.class, () -> { @@ -430,7 +448,7 @@ public void testCorruptChunkedInputStreamMissingLF() String s = "5;chunk-signature=0\r01234\r\n5;chunk-signature=0\r\n56789\r\n0;chunk-signature=0\r\n\r\n"; byte[] rawBytes = s.getBytes(UTF_8); ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); - InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length); + InputStream in = new AwsChunkedInputStream(inputStream, Optional.of(new DummyChunkSigningSession()), rawBytes.length, ImmutableList.of()); assertThrows(WebApplicationException.class, in::read); in.close(); } @@ -443,7 +461,7 @@ public void testCorruptChunkedInputStreamInvalidSize() String s = "whatever;chunk-signature=0\r\n01234\r\n5;chunk-signature=0\r\n56789\r\n0;chunk-signature=0\r\n\r\n"; byte[] rawBytes = s.getBytes(UTF_8); ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); - InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length); + InputStream in = new AwsChunkedInputStream(inputStream, Optional.of(new DummyChunkSigningSession()), rawBytes.length, ImmutableList.of()); assertThrows(WebApplicationException.class, in::read); in.close(); } @@ -456,7 +474,7 @@ public void testCorruptChunkedInputStreamNegativeSize() String s = "-5;chunk-signature=0\r\n01234\r\n5;chunk-signature=0\r\n56789\r\n0;chunk-signature=0\r\n\r\n"; byte[] rawBytes = s.getBytes(UTF_8); ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); - InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length); + InputStream in = new AwsChunkedInputStream(inputStream, Optional.of(new DummyChunkSigningSession()), rawBytes.length, ImmutableList.of()); assertThrows(WebApplicationException.class, in::read); in.close(); } @@ -469,7 +487,7 @@ public void testCorruptChunkedInputStreamTruncatedChunk() String s = "3;chunk-signature=0\r\n12"; byte[] rawBytes = s.getBytes(UTF_8); ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); - InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length); + InputStream in = new AwsChunkedInputStream(inputStream, Optional.of(new DummyChunkSigningSession()), rawBytes.length, ImmutableList.of()); byte[] buffer = new byte[300]; assertEquals(2, in.read(buffer)); assertThrows(WebApplicationException.class, () -> in.read(buffer)); @@ -482,7 +500,7 @@ public void testCorruptChunkedInputStreamClose() String s = "whatever;chunk-signature=0\r\n01234\r\n5;chunk-signature=0\r\n56789\r\n0;chunk-signature=0\r\n\r\n"; byte[] rawBytes = s.getBytes(UTF_8); ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); - InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length); + InputStream in = new AwsChunkedInputStream(inputStream, Optional.of(new DummyChunkSigningSession()), rawBytes.length, ImmutableList.of()); assertThrows(WebApplicationException.class, in::read); } @@ -493,7 +511,7 @@ public void testEmptyChunkedInputStream() String s = "0;chunk-signature=0\r\n\r\n"; byte[] rawBytes = s.getBytes(UTF_8); ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); - InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length); + InputStream in = new AwsChunkedInputStream(inputStream, Optional.of(new DummyChunkSigningSession()), rawBytes.length, ImmutableList.of()); byte[] buffer = new byte[300]; ByteArrayOutputStream out = new ByteArrayOutputStream(); int len; @@ -511,7 +529,7 @@ public void testHugeChunk() { byte[] rawBytes = "499602D2;chunk-signature=0\r\n01234567".getBytes(UTF_8); ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); - InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), 1234567890); + InputStream in = new AwsChunkedInputStream(inputStream, Optional.of(new DummyChunkSigningSession()), 1234567890, ImmutableList.of()); ByteArrayOutputStream out = new ByteArrayOutputStream(); for (int i = 0; i < 8; ++i) { @@ -522,6 +540,238 @@ public void testHugeChunk() assertEquals("01234567", result); } + @Test + public void testChunkedInputStreamWithTrailerHeaderChunk() + throws IOException + { + String s = "5;chunk-signature=0\r\n01234\r\n5;chunk-signature=0\r\n56789\r\n0;chunk-signature=0\r\nx-amz-header1:foo\r\nx-amz-trailer-signature:0\r\n\r\n"; + byte[] rawBytes = s.getBytes(UTF_8); + ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); + InputStream in = new AwsChunkedInputStream(inputStream, Optional.of(new DummyChunkSigningSession()), rawBytes.length, ImmutableList.of("x-amz-header1", "x-amz-trailer-signature")); + + byte[] buffer = new byte[7]; + ByteArrayOutputStream out = new ByteArrayOutputStream(); + int len; + while ((len = in.read(buffer)) > 0) { + out.write(buffer, 0, len); + } + assertEquals(-1, in.read(buffer)); + assertEquals(-1, in.read(buffer)); + + String result = out.toString(UTF_8); + assertEquals("0123456789", result); + + in.close(); + } + + @Test + public void testChunkedInputStreamWithTrailerHeaderChunkMultipleTrailers() + throws IOException + { + String s = "5;chunk-signature=0\r\n01234\r\n5;chunk-signature=0\r\n56789\r\n0;chunk-signature=0\r\nx-amz-header1:foo\r\nx-amz-header2:foo\r\nx-amz-trailer-signature:0\r\n\r\n"; + byte[] rawBytes = s.getBytes(UTF_8); + ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); + InputStream in = new AwsChunkedInputStream(inputStream, Optional.of(new DummyChunkSigningSession()), rawBytes.length, ImmutableList.of("x-amz-header1", "x-amz-header2", "x-amz-trailer-signature")); + + byte[] buffer = new byte[7]; + ByteArrayOutputStream out = new ByteArrayOutputStream(); + int len; + while ((len = in.read(buffer)) > 0) { + out.write(buffer, 0, len); + } + assertEquals(-1, in.read(buffer)); + assertEquals(-1, in.read(buffer)); + + String result = out.toString(UTF_8); + assertEquals("0123456789", result); + + in.close(); + } + + @Test + public void testCorruptChunkedInputStreamWithTrailerHeaderChunkNoneExpected() + throws IOException + { + String s = "5;chunk-signature=0\r\n01234\r\n5;chunk-signature=0\r\n56789\r\n0;chunk-signature=0\r\nx-amz-header1:foo\r\nx-amz-trailer-signature:0\r\n\r\n"; + byte[] rawBytes = s.getBytes(UTF_8); + ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); + InputStream in = new AwsChunkedInputStream(inputStream, Optional.of(new DummyChunkSigningSession()), rawBytes.length, ImmutableList.of()); + + byte[] buffer = new byte[7]; + assertThrows(WebApplicationException.class, () -> { + while (in.read(buffer) > 0) { + // do nothing + } + }); + in.close(); + } + + @Test + public void testCorruptChunkedInputStreamWithTrailerHeaderSignatureBeforeHeader() + throws IOException + { + String s = "5;chunk-signature=0\r\n01234\r\n5;chunk-signature=0\r\n56789\r\n0;chunk-signature=0\r\nx-amz-trailer-signature:0\r\nx-amz-header1:foo\r\n\r\n"; + byte[] rawBytes = s.getBytes(UTF_8); + ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); + InputStream in = new AwsChunkedInputStream(inputStream, Optional.of(new DummyChunkSigningSession()), rawBytes.length, ImmutableList.of("x-amz-trailer-signature", "x-amz-header1")); + + byte[] buffer = new byte[7]; + assertThrows(WebApplicationException.class, () -> { + while (in.read(buffer) > 0) { + // do nothing + } + }); + in.close(); + } + + @Test + public void testCorruptChunkedInputStreamWithTrailerHeaderChunkNoSignature() + throws IOException + { + String s = "5;chunk-signature=0\r\n01234\r\n5;chunk-signature=0\r\n56789\r\n0;chunk-signature=0\r\nx-amz-header1:foo\r\n\r\n"; + byte[] rawBytes = s.getBytes(UTF_8); + ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); + InputStream in = new AwsChunkedInputStream(inputStream, Optional.of(new DummyChunkSigningSession()), rawBytes.length, ImmutableList.of("x-amz-header1")); + + byte[] buffer = new byte[7]; + assertThrows(WebApplicationException.class, () -> { + while (in.read(buffer) > 0) { + // do nothing + } + }); + in.close(); + } + + @Test + public void testChunkedInputStreamUnsigned() + throws IOException + { + String s = "5\r\n01234\r\n5\r\n56789\r\n0\r\n\r\n"; + byte[] rawBytes = s.getBytes(UTF_8); + ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); + InputStream in = new AwsChunkedInputStream(inputStream, Optional.empty(), rawBytes.length, ImmutableList.of()); + + byte[] buffer = new byte[7]; + ByteArrayOutputStream out = new ByteArrayOutputStream(); + int len; + while ((len = in.read(buffer)) > 0) { + out.write(buffer, 0, len); + } + assertEquals(-1, in.read(buffer)); + assertEquals(-1, in.read(buffer)); + + String result = out.toString(UTF_8); + assertEquals("0123456789", result); + + in.close(); + } + + @Test + public void testChunkedInputStreamUnsignedWithTrailerHeaderChunk1Header() + throws IOException + { + String s = "5\r\n01234\r\n5\r\n56789\r\n0\r\nx-amz-header1:val\r\n\r\n"; + byte[] rawBytes = s.getBytes(UTF_8); + ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); + InputStream in = new AwsChunkedInputStream(inputStream, Optional.empty(), rawBytes.length, ImmutableList.of("x-amz-header1")); + + byte[] buffer = new byte[7]; + ByteArrayOutputStream out = new ByteArrayOutputStream(); + int len; + while ((len = in.read(buffer)) > 0) { + out.write(buffer, 0, len); + } + assertEquals(-1, in.read(buffer)); + assertEquals(-1, in.read(buffer)); + + String result = out.toString(UTF_8); + assertEquals("0123456789", result); + + in.close(); + } + + @Test + public void testChunkedInputStreamUnsignedWithTrailerHeaderChunkMultipleHeader() + throws IOException + { + String s = "5\r\n01234\r\n5\r\n56789\r\n0\r\nx-amz-header1:val\r\nx-amz-header2:val\r\n\r\n"; + byte[] rawBytes = s.getBytes(UTF_8); + ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); + InputStream in = new AwsChunkedInputStream(inputStream, Optional.empty(), rawBytes.length, ImmutableList.of("x-amz-header1", "x-amz-header2")); + + byte[] buffer = new byte[7]; + ByteArrayOutputStream out = new ByteArrayOutputStream(); + int len; + while ((len = in.read(buffer)) > 0) { + out.write(buffer, 0, len); + } + assertEquals(-1, in.read(buffer)); + assertEquals(-1, in.read(buffer)); + + String result = out.toString(UTF_8); + assertEquals("0123456789", result); + + in.close(); + } + + @Test + public void testChunkedInputStreamUnsignedWithTrailerHeaderChunkMultipleHeaderOutOfOrder() + throws IOException + { + String s = "5\r\n01234\r\n5\r\n56789\r\n0\r\nx-amz-header1:val\r\nx-amz-header2:val\r\n\r\n"; + byte[] rawBytes = s.getBytes(UTF_8); + ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); + InputStream in = new AwsChunkedInputStream(inputStream, Optional.empty(), rawBytes.length, ImmutableList.of("x-amz-header2", "x-amz-header1")); + + byte[] buffer = new byte[7]; + ByteArrayOutputStream out = new ByteArrayOutputStream(); + int len; + while ((len = in.read(buffer)) > 0) { + out.write(buffer, 0, len); + } + assertEquals(-1, in.read(buffer)); + assertEquals(-1, in.read(buffer)); + + String result = out.toString(UTF_8); + assertEquals("0123456789", result); + + in.close(); + } + + @Test + public void testCorruptUnsignedChunkedInputStreamTrailerHeadersWithNoneExpected() + throws IOException + { + String s = "5\r\n01234\r\n5\r\n56789\r\n0\r\nx-amz-trailer:val\r\n\r\n"; + byte[] rawBytes = s.getBytes(UTF_8); + ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); + InputStream in = new AwsChunkedInputStream(inputStream, Optional.empty(), rawBytes.length, ImmutableList.of()); + byte[] buffer = new byte[7]; + assertThrows(WebApplicationException.class, () -> { + while (in.read(buffer) > 0) { + // do nothing + } + }); + in.close(); + } + + @Test + public void testCorruptUnsignedChunkedInputStreamMissingTrailerHeaderChunk() + throws IOException + { + String s = "5\r\n01234\r\n5\r\n56789\r\n0\r\n\r\n"; + byte[] rawBytes = s.getBytes(UTF_8); + ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); + InputStream in = new AwsChunkedInputStream(inputStream, Optional.empty(), rawBytes.length, ImmutableList.of("x-amz-trailer-foo")); + byte[] buffer = new byte[7]; + assertThrows(WebApplicationException.class, () -> { + while (in.read(buffer) > 0) { + // do nothing + } + }); + in.close(); + } + private static class DummyChunkSigningSession implements ChunkSigningSession { diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/TestRequestHeadersBuilder.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/TestRequestHeadersBuilder.java index 51a65eaa..edc2c9a5 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/TestRequestHeadersBuilder.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/TestRequestHeadersBuilder.java @@ -25,6 +25,8 @@ import java.util.Optional; +import static io.trino.aws.proxy.spi.rest.RequestContent.ContentType.AWS_CHUNKED_IN_W3C_CHUNKED_UNSIGNED; +import static io.trino.aws.proxy.spi.rest.RequestContent.ContentType.AWS_CHUNKED_UNSIGNED; import static io.trino.aws.proxy.spi.rest.RequestContent.ContentType.W3C_CHUNKED; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -83,9 +85,9 @@ public void testBuildHeadersHttpAndAwsChunked() private void testBuildHeadersAwsChunkedPayload(MultiMap baseHeaders, ContentType expectedContentType) { assertThatThrownBy(() -> doBuildHeaders(mergeMaps(baseHeaders, ImmutableMultiMap.builder(false) - .add("X-Amz-Content-Sha256", "UNSIGNED-PAYLOAD") - .add("Content-Encoding", "aws-chunked") - .build()))).isInstanceOf(WebApplicationException.class); + .add("X-Amz-Content-Sha256", "UNSIGNED-PAYLOAD") + .add("Content-Encoding", "aws-chunked") + .build()))).isInstanceOf(WebApplicationException.class); testBuildHeaders( mergeMaps( @@ -121,6 +123,52 @@ private void testBuildHeadersAwsChunkedPayload(MultiMap baseHeaders, ContentType Optional.of(expectedContentType)); } + @Test + public void testBuildHeadersHttpAndAwsChunkedUnsignedPayload() + { + // Testing our corner case where we want to handle aws-chunked requests with a STREAMING-UNSIGNED-PAYLOAD hash + // as if they were W3C chunked, and not aws-chunked. + ImmutableMultiMap baseHeaders = ImmutableMultiMap.builder(false) + .add("transfer-encoding", "chunked") + .add("content-encoding", "aws-chunked") + .add("Content-Encoding", "gzip") + .add("x-amz-trailer", "x-amz-checksum-crc32") + .add("x-amz-content-sha256", "STREAMING-UNSIGNED-PAYLOAD-TRAILER") + .add("x-amz-decoded-content-length", "1234") + .build(); + + testBuildHeaders( + mergeMaps( + baseHeaders, + ImmutableMultiMap.builder(false) + .build()), + ImmutableMultiMap.builder(false).add("Content-Encoding", "gzip").build(), + Optional.of(AWS_CHUNKED_IN_W3C_CHUNKED_UNSIGNED)); + } + + @Test + public void testBuildHeadersAwsChunkedUnsignedPayload() + { + // Testing our corner case where we want to handle aws-chunked requests with a STREAMING-UNSIGNED-PAYLOAD hash + // as if they were W3C chunked, and not aws-chunked. + ImmutableMultiMap baseHeaders = ImmutableMultiMap.builder(false) + .add("content-encoding", "aws-chunked") + .add("Content-Encoding", "gzip") + .add("x-amz-trailer", "x-amz-checksum-crc32") + .add("x-amz-content-sha256", "STREAMING-UNSIGNED-PAYLOAD-TRAILER") + .add("Content-Length", "1234") + .add("X-Amz-Decoded-Content-Length", "1234") + .build(); + + testBuildHeaders( + mergeMaps( + baseHeaders, + ImmutableMultiMap.builder(false) + .build()), + ImmutableMultiMap.builder(false).add("Content-Encoding", "gzip").build(), + Optional.of(AWS_CHUNKED_UNSIGNED)); + } + @Test public void testBuildHeadersHttpChunked() { @@ -129,13 +177,6 @@ public void testBuildHeadersHttpChunked() .add("Transfer-Encoding", "chunked") .add("X-Amz-Decoded-Content-Length", "1000") .build(); - testBuildHeaders(baseHttpChunkedHeaders, ImmutableMultiMap.empty(), Optional.of(W3C_CHUNKED)); - - MultiMap metadataHeaders = ImmutableMultiMap.builder(false).add("X-Amz-Some-Metadata", "foo").build(); - testBuildHeaders( - mergeMaps(baseHttpChunkedHeaders, metadataHeaders), - metadataHeaders, - Optional.of(W3C_CHUNKED)); testBuildHeaders( mergeMaps( diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/signing/TestingChunkSigningSession.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/signing/TestingChunkSigningSession.java index 55bdd9ba..19e73a75 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/signing/TestingChunkSigningSession.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/signing/TestingChunkSigningSession.java @@ -26,7 +26,9 @@ import java.time.Instant; import java.time.ZoneId; import java.time.format.DateTimeFormatter; +import java.util.Arrays; import java.util.Locale; +import java.util.Optional; import java.util.UUID; import static com.google.common.base.Preconditions.checkArgument; @@ -90,7 +92,7 @@ public static int getExpectedChunkedStreamSize(String rawContent, int partitions } @SuppressWarnings("UnstableApiUsage") - public String generateChunkedStream(String content, int partitions) + public String generateChunkedStream(String content, int partitions, Optional trailerHeaders) { checkArgument(partitions > 1, "partitions must be greater than 1"); @@ -112,7 +114,16 @@ public String generateChunkedStream(String content, int partitions) } String thisSignature = chunkSigner.signChunk(Hashing.sha256().newHasher().hash(), previousSignature); - chunkedStream.append("0;chunk-signature=").append(thisSignature).append("\r\n\r\n"); + chunkedStream.append("0;chunk-signature=").append(thisSignature).append("\r\n"); + + trailerHeaders.ifPresent(headers -> { + chunkedStream.append(headers); + String trailerSignature = getChunkSignature(Arrays.stream(headers.split("\r\n")).reduce("", (prev, next) -> prev + next), thisSignature); + chunkedStream.append("x-amz-trailer-signature:").append(trailerSignature).append("\r\n"); + }); + + // Mark end of entire Streaming + chunkedStream.append("\r\n"); return chunkedStream.toString(); }