Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,12 @@ private <I extends SerializableStruct, O extends SerializableStruct> O afterIden
var identity = identityResult.unwrap();
call.context.put(CallContext.IDENTITY, identity);

// TODO: what to do with supportedAuthSchemes of an endpoint?
Endpoint endpoint = resolveEndpoint(call);
call.context.put(CallContext.ENDPOINT, endpoint);

// Augment signer properties with endpoint auth scheme overrides if present.
resolvedAuthScheme = applyEndpointAuthSchemeOverrides(endpoint, resolvedAuthScheme);

RequestT req = protocol.setServiceEndpoint(requestHook.request(), endpoint);
var signResult = resolvedAuthScheme.sign(req);
req = signResult.signedRequest();
Expand Down Expand Up @@ -311,6 +313,34 @@ private <I extends SerializableStruct, O extends SerializableStruct> Endpoint re
return call.endpointResolver.resolveEndpoint(request);
}

@SuppressWarnings("unchecked")
private <IdentityT extends Identity> ResolvedScheme<IdentityT, RequestT> applyEndpointAuthSchemeOverrides(
Endpoint endpoint,
ResolvedScheme<IdentityT, RequestT> resolvedScheme
) {
var endpointAuthSchemes = endpoint.authSchemes();
if (!endpointAuthSchemes.isEmpty()) {
var schemeId = resolvedScheme.authScheme().schemeId().toString();
for (var endpointAuthScheme : endpointAuthSchemes) {
if (schemeId.equals(endpointAuthScheme.authSchemeId())) {
var overrides = endpointAuthScheme.properties();
if (overrides.isEmpty()) {
return resolvedScheme;
}
// Apply the found overrides for the auth scheme.
var merged = Context.create();
resolvedScheme.signerProperties().copyTo(merged);
for (var key : overrides) {
merged.put((Context.Key<Object>) key, endpointAuthScheme.property(key));
}
return new ResolvedScheme<>(merged, resolvedScheme.authScheme(), resolvedScheme.identity());
}
}
}

return resolvedScheme;
}

private <I extends SerializableStruct, O extends SerializableStruct> O deserialize(
ClientCall<I, O> call,
RequestT request,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,29 @@

import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import software.amazon.smithy.java.auth.api.SignResult;
import software.amazon.smithy.java.auth.api.identity.Identity;
import software.amazon.smithy.java.auth.api.identity.IdentityResolver;
import software.amazon.smithy.java.auth.api.identity.IdentityResult;
import software.amazon.smithy.java.aws.client.restjson.RestJsonClientProtocol;
import software.amazon.smithy.java.client.core.auth.scheme.AuthScheme;
import software.amazon.smithy.java.client.core.auth.scheme.AuthSchemeOption;
import software.amazon.smithy.java.client.core.auth.scheme.AuthSchemeResolver;
import software.amazon.smithy.java.client.core.endpoint.Endpoint;
import software.amazon.smithy.java.client.core.endpoint.EndpointAuthScheme;
import software.amazon.smithy.java.client.core.endpoint.EndpointResolver;
import software.amazon.smithy.java.client.http.JavaHttpClientTransport;
import software.amazon.smithy.java.client.http.mock.MockPlugin;
import software.amazon.smithy.java.client.http.mock.MockQueue;
import software.amazon.smithy.java.context.Context;
import software.amazon.smithy.java.core.serde.document.Document;
import software.amazon.smithy.java.dynamicclient.DynamicClient;
import software.amazon.smithy.java.http.api.HttpRequest;
import software.amazon.smithy.java.http.api.HttpResponse;
import software.amazon.smithy.java.io.datastream.DataStream;
import software.amazon.smithy.java.retries.api.AcquireInitialTokenRequest;
Expand Down Expand Up @@ -188,6 +200,65 @@ public Builder toBuilder() {
assertThat(calls, contains("Acquire", "Refresh", "Success: 1"));
}

@Test
public void endpointAuthSchemeOverridesAugmentSignerProperties() {
var service = ShapeId.from("smithy.example#Sprockets");
var testSchemeId = ShapeId.from("smithy.test#testAuth");
var TEST_KEY = Context.<String>key("test-signing-override");
var capturedProperties = new AtomicReference<Context>();

// Auth scheme with a signer that captures the properties it receives.
var testScheme = AuthScheme.of(
testSchemeId,
HttpRequest.class,
Identity.class,
(request, identity, properties) -> {
capturedProperties.set(properties);
return new SignResult<>(request);
});

// Endpoint resolver that returns an endpoint with an auth scheme override.
EndpointResolver endpointResolver = params -> Endpoint.builder()
.uri("https://example.com")
.addAuthScheme(
EndpointAuthScheme.builder()
.authSchemeId(testSchemeId.toString())
.putProperty(TEST_KEY, "overridden-value")
.build())
.build();

var mockQueue = new MockQueue()
.enqueue(HttpResponse.builder()
.statusCode(200)
.body(DataStream.ofString("{\"id\":\"1\"}"))
.build());
var mock = MockPlugin.builder().addQueue(mockQueue).build();

var client = DynamicClient.builder()
.serviceId(service)
.model(MODEL)
.addPlugin(mock)
.endpointResolver(endpointResolver)
.authSchemeResolver(params -> List.of(new AuthSchemeOption(testSchemeId)))
.putSupportedAuthSchemes(testScheme)
.addIdentityResolver(new IdentityResolver<>() {
@Override
public IdentityResult<Identity> resolveIdentity(Context requestProperties) {
return IdentityResult.of(new Identity() {});
}

@Override
public Class<Identity> identityType() {
return Identity.class;
}
})
.build();

client.call("GetSprocket", Document.ofObject(Map.of("id", "1")));

assertThat(capturedProperties.get().get(TEST_KEY), equalTo("overridden-value"));
}

private static final class Token implements RetryToken {
int retry;

Expand Down
Loading