Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.web.RedirectStrategy;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.util.Assert;

Expand Down Expand Up @@ -177,6 +178,8 @@ public final class AuthorizationCodeGrantConfigurer {

private OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient;

private AuthenticationSuccessHandler authenticationSuccessHandler;

private AuthorizationCodeGrantConfigurer() {
}

Expand Down Expand Up @@ -231,6 +234,20 @@ public AuthorizationCodeGrantConfigurer accessTokenResponseClient(
return this;
}

/**
* Sets the {@link AuthenticationSuccessHandler} used for handling a successful
* authorization response.
* @param authenticationSuccessHandler the handler used for handling a successful
* authorization response
* @return the {@link AuthorizationCodeGrantConfigurer} for further configuration
*/
public AuthorizationCodeGrantConfigurer authenticationSuccessHandler(
AuthenticationSuccessHandler authenticationSuccessHandler) {
Assert.notNull(authenticationSuccessHandler, "authenticationSuccessHandler cannot be null");
this.authenticationSuccessHandler = authenticationSuccessHandler;
return this;
}

private void init(B builder) {
OAuth2AuthorizationCodeAuthenticationProvider authorizationCodeAuthenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider(
getAccessTokenResponseClient());
Expand Down Expand Up @@ -288,6 +305,9 @@ private OAuth2AuthorizationCodeGrantFilter createAuthorizationCodeGrantFilter(B
if (requestCache != null) {
authorizationCodeGrantFilter.setRequestCache(requestCache);
}
if (this.authenticationSuccessHandler != null) {
authorizationCodeGrantFilter.setAuthenticationSuccessHandler(this.authenticationSuccessHandler);
}
return authorizationCodeGrantFilter;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import org.springframework.security.web.DefaultRedirectStrategy;
import org.springframework.security.web.RedirectStrategy;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
Expand Down Expand Up @@ -106,6 +107,8 @@ public class OAuth2ClientConfigurerTests {

private static RequestCache requestCache;

private static AuthenticationSuccessHandler authenticationSuccessHandler;

public final SpringTestContext spring = new SpringTestContext(this);

@Autowired
Expand Down Expand Up @@ -146,6 +149,7 @@ public void setup() {
given(accessTokenResponseClient.getTokenResponse(any(OAuth2AuthorizationCodeGrantRequest.class)))
.willReturn(accessTokenResponse);
requestCache = mock(RequestCache.class);
authenticationSuccessHandler = null;
}

@Test
Expand Down Expand Up @@ -345,6 +349,45 @@ public void configureWhenOAuth2LoginBeansConfiguredThenNotShared() throws Except
verifyNoInteractions(clientRegistrationRepository, authorizedClientRepository);
}

@Test
public void configureWhenCustomAuthenticationSuccessHandlerSetThenAuthenticationSuccessHandlerUsed()
throws Exception {
authenticationSuccessHandler = mock(AuthenticationSuccessHandler.class);
this.spring.register(OAuth2ClientConfig.class).autowire();
Map<String, Object> attributes = new HashMap<>();
attributes.put(OAuth2ParameterNames.REGISTRATION_ID, this.registration1.getRegistrationId());
// @formatter:off
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri(this.registration1.getProviderDetails().getAuthorizationUri())
.clientId(this.registration1.getClientId())
.redirectUri("http://localhost/client-1")
.state("state")
.attributes(attributes)
.build();
// @formatter:on
AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository();
MockHttpServletRequest request = new MockHttpServletRequest("GET", "");
MockHttpServletResponse response = new MockHttpServletResponse();
authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
MockHttpSession session = (MockHttpSession) request.getSession();
String principalName = "user1";
TestingAuthenticationToken authentication = new TestingAuthenticationToken(principalName, "password");
// @formatter:off
MockHttpServletRequestBuilder clientRequest = get("/client-1")
.param(OAuth2ParameterNames.CODE, "code")
.param(OAuth2ParameterNames.STATE, "state")
.with(authentication(authentication))
.session(session);
this.mockMvc.perform(clientRequest)
.andExpect(status().isOk());
// @formatter:on
verify(authenticationSuccessHandler).onAuthenticationSuccess(any(HttpServletRequest.class),
any(HttpServletResponse.class), any());
OAuth2AuthorizedClient authorizedClient = authorizedClientRepository
.loadAuthorizedClient(this.registration1.getRegistrationId(), authentication, request);
assertThat(authorizedClient).isNotNull();
}

@EnableWebSecurity
@Configuration
@EnableWebMvc
Expand All @@ -359,10 +402,14 @@ SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
.requestCache((cache) -> cache
.requestCache(requestCache))
.oauth2Client((client) -> client
.authorizationCodeGrant((code) -> code
.authorizationRequestResolver(authorizationRequestResolver)
.authorizationRedirectStrategy(authorizationRedirectStrategy)
.accessTokenResponseClient(accessTokenResponseClient)));
.authorizationCodeGrant((code) -> {
code.authorizationRequestResolver(authorizationRequestResolver)
.authorizationRedirectStrategy(authorizationRedirectStrategy)
.accessTokenResponseClient(accessTokenResponseClient);
if (authenticationSuccessHandler != null) {
code.authenticationSuccessHandler(authenticationSuccessHandler);
}
}));
return http.build();
// @formatter:on
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.web.DefaultRedirectStrategy;
import org.springframework.security.web.RedirectStrategy;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
import org.springframework.security.web.savedrequest.RequestCache;
Expand Down Expand Up @@ -121,6 +122,8 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {

private RequestCache requestCache = new HttpSessionRequestCache();

private AuthenticationSuccessHandler authenticationSuccessHandler;

/**
* Constructs an {@code OAuth2AuthorizationCodeGrantFilter} using the provided
* parameters.
Expand Down Expand Up @@ -162,6 +165,18 @@ public final void setRequestCache(RequestCache requestCache) {
this.requestCache = requestCache;
}

/**
* Sets the {@link AuthenticationSuccessHandler} used for handling a successful
* authorization response.
* @param authenticationSuccessHandler the handler used for handling a successful
* authorization response
* @since 7.1
*/
public final void setAuthenticationSuccessHandler(AuthenticationSuccessHandler authenticationSuccessHandler) {
Assert.notNull(authenticationSuccessHandler, "authenticationSuccessHandler cannot be null");
this.authenticationSuccessHandler = authenticationSuccessHandler;
}

/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
Expand Down Expand Up @@ -217,7 +232,7 @@ private boolean matchesAuthorizationResponse(HttpServletRequest request) {
}

private void processAuthorizationResponse(HttpServletRequest request, HttpServletResponse response)
throws IOException {
throws IOException, ServletException {
OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestRepository
.removeAuthorizationRequest(request, response);
String registrationId = authorizationRequest.getAttribute(OAuth2ParameterNames.REGISTRATION_ID);
Expand Down Expand Up @@ -254,6 +269,10 @@ private void processAuthorizationResponse(HttpServletRequest request, HttpServle
authenticationResult.getRefreshToken());
this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, currentAuthentication, request,
response);
if (this.authenticationSuccessHandler != null) {
this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, authenticationResult);
return;
}
String redirectUrl = authorizationRequest.getRedirectUri();
SavedRequest savedRequest = this.requestCache.getRequest(request, response);
if (savedRequest != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationExchanges;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.security.web.savedrequest.SavedRequest;
Expand Down Expand Up @@ -152,6 +153,11 @@ public void setRequestCacheWhenRequestCacheIsNullThenThrowIllegalArgumentExcepti
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequestCache(null));
}

@Test
public void setAuthenticationSuccessHandlerWhenAuthenticationSuccessHandlerIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthenticationSuccessHandler(null));
}

@Test
public void doFilterWhenNotAuthorizationResponseThenNotProcessed() throws Exception {
String requestUri = "/path";
Expand Down Expand Up @@ -308,6 +314,27 @@ public void doFilterWhenAuthorizationSucceedsThenRedirected() throws Exception {
assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/callback/client-1");
}

@Test
public void doFilterWhenAuthorizationSucceedsAndAuthenticationSuccessHandlerConfiguredThenAuthenticationSuccessHandlerUsed()
throws Exception {
MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
this.setUpAuthenticationResult(this.registration1);
AuthenticationSuccessHandler authenticationSuccessHandler = mock(AuthenticationSuccessHandler.class);
this.filter.setAuthenticationSuccessHandler(authenticationSuccessHandler);
this.filter.doFilter(authorizationResponse, response, filterChain);
verify(authenticationSuccessHandler).onAuthenticationSuccess(any(HttpServletRequest.class),
any(HttpServletResponse.class), any(Authentication.class));
verifyNoInteractions(filterChain);
OAuth2AuthorizedClient authorizedClient = this.authorizedClientService
.loadAuthorizedClient(this.registration1.getRegistrationId(), this.principalName1);
assertThat(authorizedClient).isNotNull();
assertThat(response.getRedirectedUrl()).isNull();
}

@Test
public void doFilterWhenCustomSecurityContextHolderStrategyThenUses() throws Exception {
MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
Expand Down
Loading