Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
*/
package io.trino.aws.proxy.spi.rest;

import com.google.common.collect.ImmutableList;

import java.io.InputStream;
import java.util.List;
import java.util.Optional;

@FunctionalInterface
Expand All @@ -27,7 +30,9 @@ enum ContentType
STANDARD,
W3C_CHUNKED,
AWS_CHUNKED,
AWS_CHUNKED_UNSIGNED,
AWS_CHUNKED_IN_W3C_CHUNKED,
AWS_CHUNKED_IN_W3C_CHUNKED_UNSIGNED,
}

default ContentType contentType()
Expand All @@ -48,5 +53,10 @@ default Optional<Integer> contentLength()
return Optional.empty();
}

default List<String> trailerHeaders()
{
return ImmutableList.of();
}

Optional<InputStream> inputStream();
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
package io.trino.aws.proxy.server.rest;

import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableList;
import io.trino.aws.proxy.spi.signing.ChunkSigningSession;
import jakarta.ws.rs.WebApplicationException;

import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand All @@ -30,7 +32,7 @@ class AwsChunkedInputStream
extends InputStream
{
private final InputStream delegate;
private final ChunkSigningSession chunkSigningSession;
private final Optional<ChunkSigningSession> chunkSigningSession;

private enum State
{
Expand All @@ -44,12 +46,14 @@ private enum State
private int bytesRemainingInChunk;
private int bytesAccountedFor;
private final int decodedContentLength;
private final List<String> trailerHeaders;

AwsChunkedInputStream(InputStream delegate, ChunkSigningSession chunkSigningSession, int decodedContentLength)
AwsChunkedInputStream(InputStream delegate, Optional<ChunkSigningSession> chunkSigningSession, int decodedContentLength, List<String> trailerHeaders)
{
this.delegate = requireNonNull(delegate, "delegate is null");
this.chunkSigningSession = requireNonNull(chunkSigningSession, "chunkSigningSession is null");
this.decodedContentLength = decodedContentLength;
this.trailerHeaders = requireNonNull(ImmutableList.copyOf(trailerHeaders), "trailerHeaders is null");
}

@Override
Expand All @@ -65,7 +69,7 @@ public int read()
throw new WebApplicationException("Unexpected end of stream", BAD_REQUEST);
}

chunkSigningSession.write((byte) (i & 0xff));
chunkSigningSession.ifPresent(chunkSigningSession -> chunkSigningSession.write((byte) (i & 0xff)));
updateBytesRemaining(1);

return i;
Expand All @@ -86,7 +90,7 @@ public int read(byte[] b, int off, int len)
throw new WebApplicationException("Unexpected end of stream", BAD_REQUEST);
}

chunkSigningSession.write(b, off, count);
chunkSigningSession.ifPresent(chunkSigningSession -> chunkSigningSession.write(b, off, count));
updateBytesRemaining(count);

return count;
Expand Down Expand Up @@ -155,9 +159,6 @@ private void nextChunk()
boolean success = false;
do {
List<String> parts = Splitter.on(';').trimResults().limit(2).splitToList(header);
if (parts.size() != 2) {
break;
}

int chunkSize;
try {
Expand All @@ -170,23 +171,41 @@ private void nextChunk()
break;
}

Optional<String> chunkSignature = Splitter.on(';').trimResults().withKeyValueSeparator('=').split(parts.get(1))
.entrySet()
.stream()
.filter(entry -> entry.getKey().equalsIgnoreCase("chunk-signature"))
.map(Map.Entry::getValue)
.findFirst();
if (chunkSigningSession.isPresent()) {
if (parts.size() != 2) {
break;
}

Optional<String> chunkSignature = Splitter.on(';').trimResults().withKeyValueSeparator('=').split(parts.get(1))
.entrySet()
.stream()
.filter(entry -> entry.getKey().equalsIgnoreCase("chunk-signature"))
.map(Map.Entry::getValue)
.findFirst();

if (chunkSignature.isEmpty()) {
break;
if (chunkSignature.isEmpty()) {
break;
}

chunkSigningSession.get().startChunk(chunkSignature.get());
}
else {
if (parts.size() != 1) {
break;
}
}

chunkSigningSession.startChunk(chunkSignature.get());
bytesRemainingInChunk = chunkSize;

if (chunkSize == 0) {
readEmptyLine();
chunkSigningSession.complete();
if (trailerHeaders.isEmpty()) {
readEmptyLine();
chunkSigningSession.ifPresent(ChunkSigningSession::complete);
}
else {
readTrailingHeaders();
readEmptyLine();
}
state = State.LAST_CHUNK;
}
bytesAccountedFor += chunkSize;
Expand Down Expand Up @@ -236,4 +255,44 @@ private String readLine()

return line.toString();
}

private TrailerHeaderChunk readTrailingHeadersChunk()
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): The logic for parsing trailer headers may break early if x-amz-trailer-signature is found before all expected headers.

The loop may exit before processing all expected trailer headers if the signature is encountered early. Please verify whether all headers should be read before breaking, or if the signature is always last.

throws IOException
{
Optional<String> signature = Optional.empty();
StringBuilder trailerHeadersChunkBuilder = new StringBuilder();
for (int i = 0; i < this.trailerHeaders.size(); i++) {
String trailerHeaders = readLine();
List<String> trailerHeadersValues = Splitter.on(":").trimResults().limit(2).splitToList(trailerHeaders);
String trailerHeaderName = trailerHeadersValues.getFirst();
if ((trailerHeadersValues.size() != 2) || !this.trailerHeaders.contains(trailerHeaderName)) {
throw new WebApplicationException("Trailer header is invalid: " + trailerHeaders, BAD_REQUEST);
Comment on lines +268 to +269
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (bug_risk): The check for trailer header validity may be too strict if header names are case-insensitive.

Normalize both the input and expected header names to lower-case before comparison to ensure case-insensitive matching.

Suggested implementation:

            String trailerHeaders = readLine();
            List<String> trailerHeadersValues = Splitter.on(":").trimResults().limit(2).splitToList(trailerHeaders);
            String trailerHeaderName = trailerHeadersValues.getFirst().toLowerCase();
            boolean validTrailerHeader = this.trailerHeaders.stream()
                .map(String::toLowerCase)
                .anyMatch(h -> h.equals(trailerHeaderName));
            if ((trailerHeadersValues.size() != 2) || !validTrailerHeader) {
                throw new WebApplicationException("Trailer header is invalid: " + trailerHeaders, BAD_REQUEST);
            }
            if (trailerHeaderName.equals("x-amz-trailer-signature")) {
            if (trailerHeaderName.equals("x-amz-trailer-signature")) {
                signature = Optional.of(trailerHeadersValues.getLast());
                break;
            }
            else {
                trailerHeadersChunkBuilder.append(trailerHeaders);
            }

}
if (trailerHeaderName.equals("x-amz-trailer-signature")) {
signature = Optional.of(trailerHeadersValues.getLast());
break;
}
else {
trailerHeadersChunkBuilder.append(trailerHeaders);
}
}
return new TrailerHeaderChunk(trailerHeadersChunkBuilder.toString(), signature);
}

private void readTrailingHeaders()
throws IOException
{
TrailerHeaderChunk trailerHeaderChunk = readTrailingHeadersChunk();
chunkSigningSession.ifPresent(chunkSigningSession -> {
if (trailerHeaderChunk.signature.isEmpty()) {
throw new WebApplicationException("Expected x-amz-trailer-signature, none found", BAD_REQUEST);
}
chunkSigningSession.startChunk(trailerHeaderChunk.signature.get());
byte[] trailerHeaderContent = trailerHeaderChunk.trailerHeaders.getBytes(StandardCharsets.UTF_8);
chunkSigningSession.write(trailerHeaderContent, 0, trailerHeaderContent.length);
chunkSigningSession.complete();
});
}

private record TrailerHeaderChunk(String trailerHeaders, Optional<String> signature) {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.google.common.base.Splitter;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import io.airlift.log.Logger;
import io.trino.aws.proxy.server.rest.RequestHeadersBuilder.InternalRequestHeaders;
import io.trino.aws.proxy.server.signing.SigningQueryParameters;
Expand Down Expand Up @@ -152,7 +153,7 @@ private static RequestContent buildRequestContent(InputStream requestEntityStrea

// AWS does not mandate x-amz-decoded-content length is required for chunked transfer encoding
// But we require it for simplicity (Content-Length is needed since we don't do chunking on outbound requests)
case AWS_CHUNKED, W3C_CHUNKED, AWS_CHUNKED_IN_W3C_CHUNKED -> () -> {
case AWS_CHUNKED, AWS_CHUNKED_UNSIGNED, W3C_CHUNKED, AWS_CHUNKED_IN_W3C_CHUNKED, AWS_CHUNKED_IN_W3C_CHUNKED_UNSIGNED -> () -> {
int contentLength = requestHeaders.decodedContentLength()
.orElseThrow(() -> new WebApplicationException(BAD_REQUEST));
return Optional.of(contentLength);
Expand Down Expand Up @@ -188,6 +189,12 @@ public Optional<InputStream> inputStream()
.map(bytes -> (InputStream) new ByteArrayInputStream(bytes))
.or(() -> Optional.of(requestEntityStream));
}

@Override
public List<String> trailerHeaders()
Comment on lines +193 to +194
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): The trailerHeaders() implementation assumes the header is present and non-null.

If the header is missing, ImmutableList.copyOf will throw a NullPointerException. Use Optional.ofNullable or default to an empty list to prevent this.

{
return ImmutableList.copyOf(requestHeaders.requestHeaders().unmodifiedHeaders().get("x-amz-trailer"));
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ private RequestHeadersBuilder() {}
"connection",
"amz-sdk-invocation-id",
"amx-sdk-request",
"host");
"host",
"x-amz-trailer");

record InternalRequestHeaders(
RequestHeaders requestHeaders,
Expand All @@ -75,7 +76,9 @@ static InternalRequestHeaders parseHeaders(MultiMap allRequestHeaders)
Builder builder = new Builder();
allRequestHeaders.forEach((headerName, headerValues) -> {
switch (headerName) {
case "authorization", "x-amz-security-token" -> {} // these get handled separately
case "authorization", "x-amz-security-token" -> {
// these get handled separately
}
case "content-length" -> builder.contentLength(headerValues);
case "x-amz-decoded-content-length" -> builder.decodedContentLength(headerValues);
case "content-encoding" -> builder.contentEncoding(headerValues);
Expand Down Expand Up @@ -104,6 +107,13 @@ private static class Builder
private Optional<String> contentSha256 = Optional.empty();
private Set<ContentType> seenRequestPayloadContentTypes = new HashSet<>();

private static final Set<ContentType> CHUNKED_CONTENT_TYPES = Set.of(
ContentType.AWS_CHUNKED,
ContentType.AWS_CHUNKED_IN_W3C_CHUNKED,
ContentType.W3C_CHUNKED,
ContentType.AWS_CHUNKED_UNSIGNED,
ContentType.AWS_CHUNKED_IN_W3C_CHUNKED_UNSIGNED);

private Builder() {}

private static Optional<String> parseHeaderValuesAsSingle(List<String> allValues)
Expand Down Expand Up @@ -193,14 +203,18 @@ private void addPassthroughHeader(String headerName, List<String> headerValues)
passthroughHeadersBuilder.addAll(headerName, headerValues);
}

private String requiredContentSha256()
{
return contentSha256.orElseThrow(() -> new WebApplicationException(BAD_REQUEST));
}
Comment on lines +206 to +209
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: The requiredContentSha256() method throws a generic BAD_REQUEST error without a message.

Include a descriptive error message in the exception to clarify when the x-amz-content-sha256 header is missing or invalid.

Suggested change
private String requiredContentSha256()
{
return contentSha256.orElseThrow(() -> new WebApplicationException(BAD_REQUEST));
}
private String requiredContentSha256()
{
return contentSha256.orElseThrow(() -> new WebApplicationException(
"Missing or invalid x-amz-content-sha256 header", BAD_REQUEST));
}


private void assertContentTypeValid(ContentType actualContentType)
{
if (actualContentType == ContentType.AWS_CHUNKED || actualContentType == ContentType.AWS_CHUNKED_IN_W3C_CHUNKED || actualContentType == ContentType.W3C_CHUNKED) {
if (CHUNKED_CONTENT_TYPES.contains(actualContentType)) {
if (decodedContentLength.isEmpty()) {
throw new WebApplicationException(LENGTH_REQUIRED);
}
String sha256 = contentSha256.orElseThrow(() -> new WebApplicationException(BAD_REQUEST));
if (actualContentType != ContentType.W3C_CHUNKED && !sha256.startsWith("STREAMING-")) {
if (actualContentType != ContentType.W3C_CHUNKED && !requiredContentSha256().startsWith("STREAMING-")) {
throw new WebApplicationException(BAD_REQUEST);
}
}
Expand All @@ -209,11 +223,21 @@ private void assertContentTypeValid(ContentType actualContentType)
private InternalRequestHeaders build(MultiMap allHeaders)
{
Optional<ContentType> applicableContentType = switch (seenRequestPayloadContentTypes.size()) {
case 0, 1 -> seenRequestPayloadContentTypes.stream().findFirst();
case 0 -> Optional.empty();
case 1 -> {
Optional<ContentType> contentType = seenRequestPayloadContentTypes.stream().findFirst();
if (contentType.get().equals(ContentType.AWS_CHUNKED) && requiredContentSha256().startsWith("STREAMING-UNSIGNED-PAYLOAD")) {
yield Optional.of(ContentType.AWS_CHUNKED_UNSIGNED);
}
yield contentType;
}
case 2 -> {
if (!seenRequestPayloadContentTypes.containsAll(ImmutableSet.of(ContentType.AWS_CHUNKED, ContentType.W3C_CHUNKED))) {
throw new WebApplicationException(BAD_REQUEST);
}
if (requiredContentSha256().startsWith("STREAMING-UNSIGNED-PAYLOAD")) {
yield Optional.of(ContentType.AWS_CHUNKED_IN_W3C_CHUNKED_UNSIGNED);
}
yield Optional.of(ContentType.AWS_CHUNKED_IN_W3C_CHUNKED);
}
default -> throw new WebApplicationException(BAD_REQUEST);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,16 @@ private Optional<InputStream> contentInputStream(RequestContent requestContent,
{
return switch (requestContent.contentType()) {
case AWS_CHUNKED, AWS_CHUNKED_IN_W3C_CHUNKED -> requestContent.inputStream()
.map(inputStream -> new AwsChunkedInputStream(limitStreamController.wrap(inputStream), signingMetadata.requiredSigningContext().chunkSigningSession(), requestContent.contentLength().orElseThrow()));

.map(inputStream -> new AwsChunkedInputStream(limitStreamController.wrap(inputStream),
Optional.of(signingMetadata.requiredSigningContext().chunkSigningSession()),
requestContent.contentLength().orElseThrow(),
requestContent.trailerHeaders()));
case AWS_CHUNKED_UNSIGNED, AWS_CHUNKED_IN_W3C_CHUNKED_UNSIGNED -> requestContent.inputStream()
.map(inputStream -> new AwsChunkedInputStream(
limitStreamController.wrap(inputStream),
Optional.empty(),
requestContent.contentLength().orElseThrow(),
requestContent.trailerHeaders()));
case STANDARD, W3C_CHUNKED -> requestContent.inputStream().map(inputStream -> {
SigningContext signingContext = signingMetadata.requiredSigningContext();
return signingContext.contentHash()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ private StatusResponse doAwsChunkedUpload(

URI requestUri = UriBuilder.fromUri(baseUri).path(bucket).path(key).build();
RequestAuthorization requestAuthorization = signRequest(requestSigningCredential, requestUri, requestDate, "PUT", requestHeaderBuilder.build());
String chunkedContent = chunkedPayloadMutator.apply(TestingChunkSigningSession.build(chunkSigningCredential, requestAuthorization.signature(), requestDate).generateChunkedStream(contentToUpload, partitionCount));
String chunkedContent = chunkedPayloadMutator.apply(TestingChunkSigningSession.build(chunkSigningCredential, requestAuthorization.signature(), requestDate).generateChunkedStream(contentToUpload, partitionCount, Optional.empty()));
Request.Builder requestBuilder = preparePut().setUri(requestUri);

requestHeaderBuilder.add("Authorization", requestAuthorization.authorization());
Expand Down
Loading