Add tests to oauth2-client

Fixes gh-4299
This commit is contained in:
Joe Grandja 2017-11-07 12:14:03 -05:00
parent f2ccc53549
commit 473ac0e37c
32 changed files with 3361 additions and 297 deletions

View File

@ -9,5 +9,9 @@ dependencies {
optional project(':spring-security-oauth2-jose')
testCompile powerMock2Dependencies
testCompile 'com.squareup.okhttp3:mockwebserver'
testCompile 'com.fasterxml.jackson.core:jackson-databind'
provided 'javax.servlet:javax.servlet-api'
}

View File

@ -100,7 +100,9 @@ public class NimbusAuthorizationCodeTokenResponseClient implements OAuth2AccessT
httpRequest.setReadTimeout(30000);
tokenResponse = com.nimbusds.oauth2.sdk.TokenResponse.parse(httpRequest.send());
} catch (ParseException pe) {
throw new OAuth2AuthenticationException(new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE), pe);
OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE,
"An error occurred parsing the Access Token response: " + pe.getMessage(), null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), pe);
} catch (IOException ioe) {
throw new AuthenticationServiceException("An error occurred while sending the Access Token Request: " +
ioe.getMessage(), ioe);

View File

@ -262,7 +262,7 @@ public class OidcAuthorizationCodeAuthenticationProvider implements Authenticati
// 10. The iat Claim can be used to reject tokens that were issued too far away from the current time,
// limiting the amount of time that nonces need to be stored to prevent attacks.
// The acceptable range is Client specific.
Instant maxIssuedAt = now.plusSeconds(30);
Instant maxIssuedAt = Instant.now().plusSeconds(30);
if (issuedAt.isAfter(maxIssuedAt)) {
this.throwInvalidIdTokenException();
}

View File

@ -27,6 +27,7 @@ import org.springframework.http.HttpHeaders;
import org.springframework.http.client.AbstractClientHttpResponse;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.http.converter.GenericHttpMessageConverter;
import org.springframework.http.converter.HttpMessageNotReadableException;
import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter;
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
@ -57,7 +58,7 @@ final class NimbusUserInfoResponseClient {
userInfoRequest.getClientRegistration(), userInfoRequest.getAccessToken());
try {
return (T) this.genericHttpMessageConverter.read(returnType, userInfoResponse);
} catch (IOException ex) {
} catch (IOException | HttpMessageNotReadableException ex) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE,
"An error occurred reading the UserInfo Success response: " + ex.getMessage(), null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex);
@ -69,7 +70,7 @@ final class NimbusUserInfoResponseClient {
userInfoRequest.getClientRegistration(), userInfoRequest.getAccessToken());
try {
return (T) this.genericHttpMessageConverter.read(typeReference.getType(), null, userInfoResponse);
} catch (IOException ex) {
} catch (IOException | HttpMessageNotReadableException ex) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE,
"An error occurred reading the UserInfo Success response: " + ex.getMessage(), null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex);

View File

@ -26,6 +26,7 @@ import org.springframework.security.oauth2.core.oidc.OidcUserInfo;
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import java.util.Arrays;
@ -52,6 +53,7 @@ public class OidcUserService implements OAuth2UserService<OidcUserRequest, OidcU
@Override
public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException {
Assert.notNull(userRequest, "userRequest cannot be null");
OidcUserInfo userInfo = null;
if (this.shouldRetrieveUserInfo(userRequest)) {
ParameterizedTypeReference<Map<String, Object>> typeReference =

View File

@ -49,6 +49,7 @@ public class CustomUserTypesOAuth2UserService implements OAuth2UserService<OAuth
@Override
public OAuth2User loadUser(OAuth2UserRequest userRequest) throws OAuth2AuthenticationException {
Assert.notNull(userRequest, "userRequest cannot be null");
String registrationId = userRequest.getClientRegistration().getRegistrationId();
Class<? extends OAuth2User> customUserType;
if ((customUserType = this.customUserTypes.get(registrationId)) == null) {

View File

@ -23,6 +23,7 @@ import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.oauth2.core.user.OAuth2UserAuthority;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import java.util.HashSet;
@ -52,6 +53,7 @@ public class DefaultOAuth2UserService implements OAuth2UserService<OAuth2UserReq
@Override
public OAuth2User loadUser(OAuth2UserRequest userRequest) throws OAuth2AuthenticationException {
Assert.notNull(userRequest, "userRequest cannot be null");
String userNameAttributeName = userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName();
if (!StringUtils.hasText(userNameAttributeName)) {
OAuth2Error oauth2Error = new OAuth2Error(

View File

@ -51,6 +51,7 @@ public class DelegatingOAuth2UserService<R extends OAuth2UserRequest, U extends
@Override
public U loadUser(R userRequest) throws OAuth2AuthenticationException {
Assert.notNull(userRequest, "userRequest cannot be null");
return this.userServices.stream()
.map(userService -> userService.loadUser(userRequest))
.filter(Objects::nonNull)

View File

@ -27,6 +27,7 @@ import org.springframework.http.HttpHeaders;
import org.springframework.http.client.AbstractClientHttpResponse;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.http.converter.GenericHttpMessageConverter;
import org.springframework.http.converter.HttpMessageNotReadableException;
import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter;
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
@ -54,7 +55,7 @@ final class NimbusUserInfoResponseClient {
userInfoRequest.getClientRegistration(), userInfoRequest.getAccessToken());
try {
return (T) this.genericHttpMessageConverter.read(returnType, userInfoResponse);
} catch (IOException ex) {
} catch (IOException | HttpMessageNotReadableException ex) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE,
"An error occurred reading the UserInfo Success response: " + ex.getMessage(), null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex);
@ -66,7 +67,7 @@ final class NimbusUserInfoResponseClient {
userInfoRequest.getClientRegistration(), userInfoRequest.getAccessToken());
try {
return (T) this.genericHttpMessageConverter.read(typeReference.getType(), null, userInfoResponse);
} catch (IOException ex) {
} catch (IOException | HttpMessageNotReadableException ex) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE,
"An error occurred reading the UserInfo Success response: " + ex.getMessage(), null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex);

View File

@ -16,6 +16,7 @@
package org.springframework.security.oauth2.client.web;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.util.Assert;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
@ -36,6 +37,7 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au
@Override
public OAuth2AuthorizationRequest loadAuthorizationRequest(HttpServletRequest request) {
Assert.notNull(request, "request cannot be null");
HttpSession session = request.getSession(false);
if (session != null) {
return (OAuth2AuthorizationRequest) session.getAttribute(this.sessionAttributeName);
@ -46,6 +48,8 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au
@Override
public void saveAuthorizationRequest(OAuth2AuthorizationRequest authorizationRequest, HttpServletRequest request,
HttpServletResponse response) {
Assert.notNull(request, "request cannot be null");
Assert.notNull(response, "response cannot be null");
if (authorizationRequest == null) {
this.removeAuthorizationRequest(request);
return;
@ -55,6 +59,7 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au
@Override
public OAuth2AuthorizationRequest removeAuthorizationRequest(HttpServletRequest request) {
Assert.notNull(request, "request cannot be null");
OAuth2AuthorizationRequest authorizationRequest = this.loadAuthorizationRequest(request);
if (authorizationRequest != null) {
request.getSession().removeAttribute(this.sessionAttributeName);

View File

@ -17,6 +17,7 @@ package org.springframework.security.oauth2.client.web;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponentsBuilder;
@ -35,6 +36,7 @@ import java.util.Set;
class OAuth2AuthorizationRequestUriBuilder {
URI build(OAuth2AuthorizationRequest authorizationRequest) {
Assert.notNull(authorizationRequest, "authorizationRequest cannot be null");
Set<String> scopes = authorizationRequest.getScopes();
UriComponentsBuilder uriBuilder = UriComponentsBuilder
.fromUriString(authorizationRequest.getAuthorizationUri())

View File

@ -0,0 +1,188 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client;
import org.junit.Test;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
* Tests for {@link InMemoryOAuth2AuthorizedClientService}.
*
* @author Joe Grandja
*/
public class InMemoryOAuth2AuthorizedClientServiceTests {
private String registrationId1 = "registration-1";
private String registrationId2 = "registration-2";
private String registrationId3 = "registration-3";
private String principalName1 = "principal-1";
private String principalName2 = "principal-2";
private ClientRegistration registration1 = ClientRegistration.withRegistrationId(this.registrationId1)
.clientId("client-1")
.clientSecret("secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUri("{scheme}://{serverName}:{serverPort}{contextPath}/login/oauth2/code/{registrationId}")
.scope("user")
.authorizationUri("https://provider.com/oauth2/authorize")
.tokenUri("https://provider.com/oauth2/token")
.userInfoUri("https://provider.com/oauth2/user")
.userNameAttributeName("id")
.clientName("client-1")
.build();
private ClientRegistration registration2 = ClientRegistration.withRegistrationId(this.registrationId2)
.clientId("client-2")
.clientSecret("secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUri("{scheme}://{serverName}:{serverPort}{contextPath}/login/oauth2/code/{registrationId}")
.scope("openid", "profile", "email")
.authorizationUri("https://provider.com/oauth2/authorize")
.tokenUri("https://provider.com/oauth2/token")
.userInfoUri("https://provider.com/oauth2/userinfo")
.jwkSetUri("https://provider.com/oauth2/keys")
.clientName("client-2")
.build();
private ClientRegistration registration3 = ClientRegistration.withRegistrationId(this.registrationId3)
.clientId("client-3")
.clientSecret("secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUri("{scheme}://{serverName}:{serverPort}{contextPath}/login/oauth2/code/{registrationId}")
.scope("openid", "profile")
.authorizationUri("https://provider.com/oauth2/authorize")
.tokenUri("https://provider.com/oauth2/token")
.userInfoUri("https://provider.com/oauth2/userinfo")
.jwkSetUri("https://provider.com/oauth2/keys")
.clientName("client-3")
.build();
private ClientRegistrationRepository clientRegistrationRepository =
new InMemoryClientRegistrationRepository(this.registration1, this.registration2, this.registration3);
private InMemoryOAuth2AuthorizedClientService authorizedClientService =
new InMemoryOAuth2AuthorizedClientService(this.clientRegistrationRepository);
@Test(expected = IllegalArgumentException.class)
public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() {
new InMemoryOAuth2AuthorizedClientService(null);
}
@Test(expected = IllegalArgumentException.class)
public void loadAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() {
this.authorizedClientService.loadAuthorizedClient(null, this.principalName1);
}
@Test(expected = IllegalArgumentException.class)
public void loadAuthorizedClientWhenPrincipalNameIsNullThenThrowIllegalArgumentException() {
this.authorizedClientService.loadAuthorizedClient(this.registrationId1, null);
}
@Test
public void loadAuthorizedClientWhenClientRegistrationNotFoundThenReturnNull() {
OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient(
"registration-not-found", this.principalName1);
assertThat(authorizedClient).isNull();
}
@Test
public void loadAuthorizedClientWhenClientRegistrationFoundButNotAssociatedToPrincipalThenReturnNull() {
OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient(
this.registrationId1, "principal-not-found");
assertThat(authorizedClient).isNull();
}
@Test
public void loadAuthorizedClientWhenClientRegistrationFoundAndAssociatedToPrincipalThenReturnAuthorizedClient() {
Authentication authentication = mock(Authentication.class);
when(authentication.getName()).thenReturn(this.principalName1);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
this.registration1, this.principalName1, mock(OAuth2AccessToken.class));
this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication);
OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService.loadAuthorizedClient(
this.registrationId1, this.principalName1);
assertThat(loadedAuthorizedClient).isEqualTo(authorizedClient);
}
@Test(expected = IllegalArgumentException.class)
public void saveAuthorizedClientWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() {
this.authorizedClientService.saveAuthorizedClient(null, mock(Authentication.class));
}
@Test(expected = IllegalArgumentException.class)
public void saveAuthorizedClientWhenPrincipalIsNullThenThrowIllegalArgumentException() {
this.authorizedClientService.saveAuthorizedClient(mock(OAuth2AuthorizedClient.class), null);
}
@Test
public void saveAuthorizedClientWhenSavedThenCanLoad() {
Authentication authentication = mock(Authentication.class);
when(authentication.getName()).thenReturn(this.principalName2);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
this.registration3, this.principalName2, mock(OAuth2AccessToken.class));
this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication);
OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService.loadAuthorizedClient(
this.registrationId3, this.principalName2);
assertThat(loadedAuthorizedClient).isEqualTo(authorizedClient);
}
@Test(expected = IllegalArgumentException.class)
public void removeAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() {
this.authorizedClientService.removeAuthorizedClient(null, this.principalName2);
}
@Test(expected = IllegalArgumentException.class)
public void removeAuthorizedClientWhenPrincipalNameIsNullThenThrowIllegalArgumentException() {
this.authorizedClientService.removeAuthorizedClient(this.registrationId2, null);
}
@Test
public void removeAuthorizedClientWhenSavedThenRemoved() {
Authentication authentication = mock(Authentication.class);
when(authentication.getName()).thenReturn(this.principalName2);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
this.registration2, this.principalName2, mock(OAuth2AccessToken.class));
this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication);
OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService.loadAuthorizedClient(
this.registrationId2, this.principalName2);
assertThat(loadedAuthorizedClient).isNotNull();
this.authorizedClientService.removeAuthorizedClient(this.registrationId2, this.principalName2);
loadedAuthorizedClient = this.authorizedClientService.loadAuthorizedClient(
this.registrationId2, this.principalName2);
assertThat(loadedAuthorizedClient).isNull();
}
}

View File

@ -0,0 +1,72 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link OAuth2AuthorizedClient}.
*
* @author Joe Grandja
*/
@RunWith(PowerMockRunner.class)
@PrepareForTest(ClientRegistration.class)
public class OAuth2AuthorizedClientTests {
private ClientRegistration clientRegistration;
private String principalName;
private OAuth2AccessToken accessToken;
@Before
public void setUp() {
this.clientRegistration = mock(ClientRegistration.class);
this.principalName = "principal";
this.accessToken = mock(OAuth2AccessToken.class);
}
@Test(expected = IllegalArgumentException.class)
public void constructorWhenClientRegistrationIsNullThenThrowIllegalArgumentException() {
new OAuth2AuthorizedClient(null, this.principalName, this.accessToken);
}
@Test(expected = IllegalArgumentException.class)
public void constructorWhenPrincipalNameIsNullThenThrowIllegalArgumentException() {
new OAuth2AuthorizedClient(this.clientRegistration, null, this.accessToken);
}
@Test(expected = IllegalArgumentException.class)
public void constructorWhenAccessTokenIsNullThenThrowIllegalArgumentException() {
new OAuth2AuthorizedClient(this.clientRegistration, this.principalName, null);
}
@Test
public void constructorWhenAllParametersProvidedAndValidThenCreated() {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
this.clientRegistration, this.principalName, this.accessToken);
assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principalName);
assertThat(authorizedClient.getAccessToken()).isEqualTo(this.accessToken);
}
}

View File

@ -0,0 +1,77 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.authentication;
import org.junit.Before;
import org.junit.Test;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.oauth2.core.user.OAuth2User;
import java.util.Collection;
import java.util.Collections;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link OAuth2AuthenticationToken}.
*
* @author Joe Grandja
*/
public class OAuth2AuthenticationTokenTests {
private OAuth2User principal;
private Collection<? extends GrantedAuthority> authorities;
private String authorizedClientRegistrationId;
@Before
public void setUp() {
this.principal = mock(OAuth2User.class);
this.authorities = Collections.emptyList();
this.authorizedClientRegistrationId = "client-registration-1";
}
@Test(expected = IllegalArgumentException.class)
public void constructorWhenPrincipalIsNullThenThrowIllegalArgumentException() {
new OAuth2AuthenticationToken(null, this.authorities, this.authorizedClientRegistrationId);
}
@Test
public void constructorWhenAuthoritiesIsNullThenCreated() {
new OAuth2AuthenticationToken(this.principal, null, this.authorizedClientRegistrationId);
}
@Test
public void constructorWhenAuthoritiesIsEmptyThenCreated() {
new OAuth2AuthenticationToken(this.principal, Collections.emptyList(), this.authorizedClientRegistrationId);
}
@Test(expected = IllegalArgumentException.class)
public void constructorWhenAuthorizedClientRegistrationIdIsNullThenThrowIllegalArgumentException() {
new OAuth2AuthenticationToken(this.principal, this.authorities, null);
}
@Test
public void constructorWhenAllParametersProvidedAndValidThenCreated() {
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(
this.principal, this.authorities, this.authorizedClientRegistrationId);
assertThat(authentication.getPrincipal()).isEqualTo(this.principal);
assertThat(authentication.getCredentials()).isEqualTo("");
assertThat(authentication.getAuthorities()).isEqualTo(this.authorities);
assertThat(authentication.getAuthorizedClientRegistrationId()).isEqualTo(this.authorizedClientRegistrationId);
assertThat(authentication.isAuthenticated()).isEqualTo(true);
}
}

View File

@ -0,0 +1,215 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.authentication;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.mockito.stubbing.Answer;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
import org.springframework.security.oauth2.core.user.OAuth2User;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.CoreMatchers.containsString;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyCollection;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
* Tests for {@link OAuth2LoginAuthenticationProvider}.
*
* @author Joe Grandja
*/
@PrepareForTest({ClientRegistration.class, OAuth2AuthorizationRequest.class,
OAuth2AuthorizationResponse.class, OAuth2AccessTokenResponse.class})
@RunWith(PowerMockRunner.class)
public class OAuth2LoginAuthenticationProviderTests {
private ClientRegistration clientRegistration;
private OAuth2AuthorizationRequest authorizationRequest;
private OAuth2AuthorizationResponse authorizationResponse;
private OAuth2AuthorizationExchange authorizationExchange;
private OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient;
private OAuth2UserService<OAuth2UserRequest, OAuth2User> userService;
private OAuth2LoginAuthenticationProvider authenticationProvider;
@Rule
public ExpectedException exception = ExpectedException.none();
@Before
@SuppressWarnings("unchecked")
public void setUp() throws Exception {
this.clientRegistration = mock(ClientRegistration.class);
this.authorizationRequest = mock(OAuth2AuthorizationRequest.class);
this.authorizationResponse = mock(OAuth2AuthorizationResponse.class);
this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, this.authorizationResponse);
this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class);
this.userService = mock(OAuth2UserService.class);
this.authenticationProvider = new OAuth2LoginAuthenticationProvider(this.accessTokenResponseClient, this.userService);
when(this.authorizationRequest.getScopes()).thenReturn(new LinkedHashSet<>(Arrays.asList("scope1", "scope2")));
when(this.authorizationRequest.getState()).thenReturn("12345");
when(this.authorizationResponse.getState()).thenReturn("12345");
when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example.com");
when(this.authorizationResponse.getRedirectUri()).thenReturn("http://example.com");
}
@Test
public void constructorWhenAccessTokenResponseClientIsNullThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class);
new OAuth2LoginAuthenticationProvider(null, this.userService);
}
@Test
public void constructorWhenUserServiceIsNullThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class);
new OAuth2LoginAuthenticationProvider(this.accessTokenResponseClient, null);
}
@Test
public void setAuthoritiesMapperWhenAuthoritiesMapperIsNullThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class);
this.authenticationProvider.setAuthoritiesMapper(null);
}
@Test
public void supportsWhenTypeOAuth2LoginAuthenticationTokenThenReturnTrue() {
assertThat(this.authenticationProvider.supports(OAuth2LoginAuthenticationToken.class)).isTrue();
}
@Test
public void authenticateWhenAuthorizationRequestContainsOpenidScopeThenReturnNull() {
when(this.authorizationRequest.getScopes()).thenReturn(new LinkedHashSet<>(Collections.singleton("openid")));
OAuth2LoginAuthenticationToken authentication =
(OAuth2LoginAuthenticationToken)this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
assertThat(authentication).isNull();
}
@Test
public void authenticateWhenAuthorizationErrorResponseThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString(OAuth2ErrorCodes.INVALID_REQUEST));
when(this.authorizationResponse.statusError()).thenReturn(true);
when(this.authorizationResponse.getError()).thenReturn(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST));
this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
}
@Test
public void authenticateWhenAuthorizationResponseStateNotEqualAuthorizationRequestStateThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_state_parameter"));
when(this.authorizationRequest.getState()).thenReturn("12345");
when(this.authorizationResponse.getState()).thenReturn("67890");
this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
}
@Test
public void authenticateWhenAuthorizationResponseRedirectUriNotEqualAuthorizationRequestRedirectUriThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_redirect_uri_parameter"));
when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example.com");
when(this.authorizationResponse.getRedirectUri()).thenReturn("http://example2.com");
this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
}
@Test
public void authenticateWhenLoginSuccessThenReturnAuthentication() {
OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class);
OAuth2AccessTokenResponse accessTokenResponse = mock(OAuth2AccessTokenResponse.class);
when(accessTokenResponse.getAccessToken()).thenReturn(accessToken);
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
OAuth2User principal = mock(OAuth2User.class);
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
when(principal.getAuthorities()).thenAnswer(
(Answer<List<GrantedAuthority>>) invocation -> authorities);
when(this.userService.loadUser(any())).thenReturn(principal);
OAuth2LoginAuthenticationToken authentication =
(OAuth2LoginAuthenticationToken)this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
assertThat(authentication.isAuthenticated()).isTrue();
assertThat(authentication.getPrincipal()).isEqualTo(principal);
assertThat(authentication.getCredentials()).isEqualTo("");
assertThat(authentication.getAuthorities()).isEqualTo(authorities);
assertThat(authentication.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(authentication.getAuthorizationExchange()).isEqualTo(this.authorizationExchange);
assertThat(authentication.getAccessToken()).isEqualTo(accessToken);
}
@Test
public void authenticateWhenAuthoritiesMapperSetThenReturnMappedAuthorities() {
OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class);
OAuth2AccessTokenResponse accessTokenResponse = mock(OAuth2AccessTokenResponse.class);
when(accessTokenResponse.getAccessToken()).thenReturn(accessToken);
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
OAuth2User principal = mock(OAuth2User.class);
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
when(principal.getAuthorities()).thenAnswer(
(Answer<List<GrantedAuthority>>) invocation -> authorities);
when(this.userService.loadUser(any())).thenReturn(principal);
List<GrantedAuthority> mappedAuthorities = AuthorityUtils.createAuthorityList("ROLE_OAUTH2_USER");
GrantedAuthoritiesMapper authoritiesMapper = mock(GrantedAuthoritiesMapper.class);
when(authoritiesMapper.mapAuthorities(anyCollection())).thenAnswer(
(Answer<List<GrantedAuthority>>) invocation -> mappedAuthorities);
this.authenticationProvider.setAuthoritiesMapper(authoritiesMapper);
OAuth2LoginAuthenticationToken authentication =
(OAuth2LoginAuthenticationToken)this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
assertThat(authentication.getAuthorities()).isEqualTo(mappedAuthorities);
}
}

View File

@ -0,0 +1,131 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.authentication;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.user.OAuth2User;
import java.util.Collection;
import java.util.Collections;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link OAuth2LoginAuthenticationToken}.
*
* @author Joe Grandja
*/
@RunWith(PowerMockRunner.class)
@PrepareForTest({ClientRegistration.class, OAuth2AuthorizationExchange.class})
public class OAuth2LoginAuthenticationTokenTests {
private OAuth2User principal;
private Collection<? extends GrantedAuthority> authorities;
private ClientRegistration clientRegistration;
private OAuth2AuthorizationExchange authorizationExchange;
private OAuth2AccessToken accessToken;
@Before
public void setUp() {
this.principal = mock(OAuth2User.class);
this.authorities = Collections.emptyList();
this.clientRegistration = mock(ClientRegistration.class);
this.authorizationExchange = mock(OAuth2AuthorizationExchange.class);
this.accessToken = mock(OAuth2AccessToken.class);
}
@Test(expected = IllegalArgumentException.class)
public void constructorAuthorizationRequestResponseWhenClientRegistrationIsNullThenThrowIllegalArgumentException() {
new OAuth2LoginAuthenticationToken(null, this.authorizationExchange);
}
@Test(expected = IllegalArgumentException.class)
public void constructorAuthorizationRequestResponseWhenAuthorizationExchangeIsNullThenThrowIllegalArgumentException() {
new OAuth2LoginAuthenticationToken(this.clientRegistration, null);
}
@Test
public void constructorAuthorizationRequestResponseWhenAllParametersProvidedAndValidThenCreated() {
OAuth2LoginAuthenticationToken authentication = new OAuth2LoginAuthenticationToken(
this.clientRegistration, this.authorizationExchange);
assertThat(authentication.getPrincipal()).isNull();
assertThat(authentication.getCredentials()).isEqualTo("");
assertThat(authentication.getAuthorities()).isEqualTo(Collections.emptyList());
assertThat(authentication.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(authentication.getAuthorizationExchange()).isEqualTo(this.authorizationExchange);
assertThat(authentication.getAccessToken()).isNull();
assertThat(authentication.isAuthenticated()).isEqualTo(false);
}
@Test(expected = IllegalArgumentException.class)
public void constructorTokenRequestResponseWhenClientRegistrationIsNullThenThrowIllegalArgumentException() {
new OAuth2LoginAuthenticationToken(null, this.authorizationExchange, this.principal,
this.authorities, this.accessToken);
}
@Test(expected = IllegalArgumentException.class)
public void constructorTokenRequestResponseWhenAuthorizationExchangeIsNullThenThrowIllegalArgumentException() {
new OAuth2LoginAuthenticationToken(this.clientRegistration, null, this.principal,
this.authorities, this.accessToken);
}
@Test(expected = IllegalArgumentException.class)
public void constructorTokenRequestResponseWhenPrincipalIsNullThenThrowIllegalArgumentException() {
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange, null,
this.authorities, this.accessToken);
}
@Test
public void constructorTokenRequestResponseWhenAuthoritiesIsNullThenCreated() {
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange,
this.principal, null, this.accessToken);
}
@Test
public void constructorTokenRequestResponseWhenAuthoritiesIsEmptyThenCreated() {
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange,
this.principal, Collections.emptyList(), this.accessToken);
}
@Test(expected = IllegalArgumentException.class)
public void constructorTokenRequestResponseWhenAccessTokenIsNullThenThrowIllegalArgumentException() {
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange, this.principal,
this.authorities, null);
}
@Test
public void constructorTokenRequestResponseWhenAllParametersProvidedAndValidThenCreated() {
OAuth2LoginAuthenticationToken authentication = new OAuth2LoginAuthenticationToken(
this.clientRegistration, this.authorizationExchange, this.principal, this.authorities, this.accessToken);
assertThat(authentication.getPrincipal()).isEqualTo(this.principal);
assertThat(authentication.getCredentials()).isEqualTo("");
assertThat(authentication.getAuthorities()).isEqualTo(this.authorities);
assertThat(authentication.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(authentication.getAuthorizationExchange()).isEqualTo(this.authorizationExchange);
assertThat(authentication.getAccessToken()).isEqualTo(this.accessToken);
assertThat(authentication.isAuthenticated()).isEqualTo(true);
}
}

View File

@ -0,0 +1,300 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.endpoint;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
import java.time.Instant;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.Set;
import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.CoreMatchers.containsString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
* Tests for {@link NimbusAuthorizationCodeTokenResponseClient}.
*
* @author Joe Grandja
*/
@PowerMockIgnore("okhttp3.*")
@PrepareForTest({ClientRegistration.class, OAuth2AuthorizationRequest.class, OAuth2AuthorizationResponse.class, OAuth2AuthorizationExchange.class})
@RunWith(PowerMockRunner.class)
public class NimbusAuthorizationCodeTokenResponseClientTests {
private ClientRegistration clientRegistration;
private ClientRegistration.ProviderDetails providerDetails;
private OAuth2AuthorizationRequest authorizationRequest;
private OAuth2AuthorizationResponse authorizationResponse;
private OAuth2AuthorizationExchange authorizationExchange;
private NimbusAuthorizationCodeTokenResponseClient tokenResponseClient = new NimbusAuthorizationCodeTokenResponseClient();
@Rule
public ExpectedException exception = ExpectedException.none();
@Before
public void setUp() throws Exception {
this.clientRegistration = mock(ClientRegistration.class);
this.providerDetails = mock(ClientRegistration.ProviderDetails.class);
this.authorizationRequest = mock(OAuth2AuthorizationRequest.class);
this.authorizationResponse = mock(OAuth2AuthorizationResponse.class);
this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, this.authorizationResponse);
when(this.clientRegistration.getProviderDetails()).thenReturn(this.providerDetails);
when(this.clientRegistration.getClientId()).thenReturn("client-id");
when(this.clientRegistration.getClientSecret()).thenReturn("secret");
when(this.clientRegistration.getClientAuthenticationMethod()).thenReturn(ClientAuthenticationMethod.BASIC);
when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example.com");
when(this.authorizationResponse.getCode()).thenReturn("code");
}
@Test
public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception {
MockWebServer server = new MockWebServer();
String accessTokenSuccessResponse = "{\n" +
" \"access_token\": \"access-token-1234\",\n" +
" \"token_type\": \"bearer\",\n" +
" \"expires_in\": \"3600\",\n" +
" \"scope\": \"openid profile\",\n" +
" \"custom_parameter_1\": \"custom-value-1\",\n" +
" \"custom_parameter_2\": \"custom-value-2\"\n" +
"}\n";
server.enqueue(new MockResponse()
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.setBody(accessTokenSuccessResponse));
server.start();
String tokenUri = server.url("/oauth2/token").toString();
when(this.providerDetails.getTokenUri()).thenReturn(tokenUri);
Instant expiresAtBefore = Instant.now().plusSeconds(3600);
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(
new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
Instant expiresAtAfter = Instant.now().plusSeconds(3600);
server.shutdown();
assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234");
assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER);
assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter);
assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("openid", "profile");
assertThat(accessTokenResponse.getAdditionalParameters().size()).isEqualTo(2);
assertThat(accessTokenResponse.getAdditionalParameters()).containsEntry("custom_parameter_1", "custom-value-1");
assertThat(accessTokenResponse.getAdditionalParameters()).containsEntry("custom_parameter_2", "custom-value-2");
}
@Test
public void getTokenResponseWhenRedirectUriMalformedThenThrowIllegalArgumentException() throws Exception {
this.exception.expect(IllegalArgumentException.class);
String redirectUri = "http:\\example.com";
when(this.clientRegistration.getRedirectUri()).thenReturn(redirectUri);
this.tokenResponseClient.getTokenResponse(
new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
}
@Test
public void getTokenResponseWhenTokenUriMalformedThenThrowIllegalArgumentException() throws Exception {
this.exception.expect(IllegalArgumentException.class);
String tokenUri = "http:\\provider.com\\oauth2\\token";
when(this.providerDetails.getTokenUri()).thenReturn(tokenUri);
this.tokenResponseClient.getTokenResponse(
new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
}
@Test
public void getTokenResponseWhenSuccessResponseInvalidThenThrowOAuth2AuthenticationException() throws Exception {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_token_response"));
MockWebServer server = new MockWebServer();
String accessTokenSuccessResponse = "{\n" +
" \"access_token\": \"access-token-1234\",\n" +
" \"token_type\": \"bearer\",\n" +
" \"expires_in\": \"3600\",\n" +
" \"scope\": \"openid profile\",\n" +
" \"custom_parameter_1\": \"custom-value-1\",\n" +
" \"custom_parameter_2\": \"custom-value-2\"\n";
// "}\n"; // Make the JSON invalid/malformed
server.enqueue(new MockResponse()
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.setBody(accessTokenSuccessResponse));
server.start();
String tokenUri = server.url("/oauth2/token").toString();
when(this.providerDetails.getTokenUri()).thenReturn(tokenUri);
try {
this.tokenResponseClient.getTokenResponse(
new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
} finally {
server.shutdown();
}
}
@Test
public void getTokenResponseWhenTokenUriInvalidThenThrowAuthenticationServiceException() throws Exception {
this.exception.expect(AuthenticationServiceException.class);
String tokenUri = "http://invalid-provider.com/oauth2/token";
when(this.providerDetails.getTokenUri()).thenReturn(tokenUri);
this.tokenResponseClient.getTokenResponse(
new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
}
@Test
public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthenticationException() throws Exception {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("unauthorized_client"));
MockWebServer server = new MockWebServer();
String accessTokenErrorResponse = "{\n" +
" \"error\": \"unauthorized_client\"\n" +
"}\n";
server.enqueue(new MockResponse()
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.setResponseCode(500)
.setBody(accessTokenErrorResponse));
server.start();
String tokenUri = server.url("/oauth2/token").toString();
when(this.providerDetails.getTokenUri()).thenReturn(tokenUri);
try {
this.tokenResponseClient.getTokenResponse(
new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
} finally {
server.shutdown();
}
}
@Test
public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthenticationException() throws Exception {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_token_response"));
MockWebServer server = new MockWebServer();
String accessTokenSuccessResponse = "{\n" +
" \"access_token\": \"access-token-1234\",\n" +
" \"token_type\": \"not-bearer\",\n" +
" \"expires_in\": \"3600\"\n" +
"}\n";
server.enqueue(new MockResponse()
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.setBody(accessTokenSuccessResponse));
server.start();
String tokenUri = server.url("/oauth2/token").toString();
when(this.providerDetails.getTokenUri()).thenReturn(tokenUri);
try {
this.tokenResponseClient.getTokenResponse(
new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
} finally {
server.shutdown();
}
}
@Test
public void getTokenResponseWhenSuccessResponseIncludesScopeThenReturnAccessTokenResponseUsingResponseScope() throws Exception {
MockWebServer server = new MockWebServer();
String accessTokenSuccessResponse = "{\n" +
" \"access_token\": \"access-token-1234\",\n" +
" \"token_type\": \"bearer\",\n" +
" \"expires_in\": \"3600\",\n" +
" \"scope\": \"openid profile\"\n" +
"}\n";
server.enqueue(new MockResponse()
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.setBody(accessTokenSuccessResponse));
server.start();
String tokenUri = server.url("/oauth2/token").toString();
when(this.providerDetails.getTokenUri()).thenReturn(tokenUri);
Set<String> requestedScopes = new LinkedHashSet<>(Arrays.asList("openid", "profile", "email", "address"));
when(this.authorizationRequest.getScopes()).thenReturn(requestedScopes);
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(
new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
server.shutdown();
assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("openid", "profile");
}
@Test
public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenReturnAccessTokenResponseUsingRequestedScope() throws Exception {
MockWebServer server = new MockWebServer();
String accessTokenSuccessResponse = "{\n" +
" \"access_token\": \"access-token-1234\",\n" +
" \"token_type\": \"bearer\",\n" +
" \"expires_in\": \"3600\"\n" +
"}\n";
server.enqueue(new MockResponse()
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.setBody(accessTokenSuccessResponse));
server.start();
String tokenUri = server.url("/oauth2/token").toString();
when(this.providerDetails.getTokenUri()).thenReturn(tokenUri);
Set<String> requestedScopes = new LinkedHashSet<>(Arrays.asList("openid", "profile", "email", "address"));
when(this.authorizationRequest.getScopes()).thenReturn(requestedScopes);
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(
new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
server.shutdown();
assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("openid", "profile", "email", "address");
}
}

View File

@ -0,0 +1,66 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.endpoint;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link OAuth2AuthorizationCodeGrantRequest}.
*
* @author Joe Grandja
*/
@RunWith(PowerMockRunner.class)
@PrepareForTest({ClientRegistration.class, OAuth2AuthorizationExchange.class})
public class OAuth2AuthorizationCodeGrantRequestTests {
private ClientRegistration clientRegistration;
private OAuth2AuthorizationExchange authorizationExchange;
@Before
public void setUp() {
this.clientRegistration = mock(ClientRegistration.class);
this.authorizationExchange = mock(OAuth2AuthorizationExchange.class);
}
@Test(expected = IllegalArgumentException.class)
public void constructorWhenClientRegistrationIsNullThenThrowIllegalArgumentException() {
new OAuth2AuthorizationCodeGrantRequest(null, this.authorizationExchange);
}
@Test(expected = IllegalArgumentException.class)
public void constructorWhenAuthorizationExchangeIsNullThenThrowIllegalArgumentException() {
new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, null);
}
@Test
public void constructorWhenAllParametersProvidedAndValidThenCreated() {
OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest =
new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange);
assertThat(authorizationCodeGrantRequest.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(authorizationCodeGrantRequest.getAuthorizationExchange()).isEqualTo(this.authorizationExchange);
assertThat(authorizationCodeGrantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
}
}

View File

@ -0,0 +1,414 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.oidc.authentication;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.mockito.stubbing.Answer;
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.CoreMatchers.containsString;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
* Tests for {@link OidcAuthorizationCodeAuthenticationProvider}.
*
* @author Joe Grandja
*/
@PrepareForTest({ClientRegistration.class, OAuth2AuthorizationRequest.class, OAuth2AuthorizationResponse.class,
OAuth2AccessTokenResponse.class, OidcAuthorizationCodeAuthenticationProvider.class})
@RunWith(PowerMockRunner.class)
public class OidcAuthorizationCodeAuthenticationProviderTests {
private ClientRegistration clientRegistration;
private ClientRegistration.ProviderDetails providerDetails;
private OAuth2AuthorizationRequest authorizationRequest;
private OAuth2AuthorizationResponse authorizationResponse;
private OAuth2AuthorizationExchange authorizationExchange;
private OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient;
private OAuth2AccessTokenResponse accessTokenResponse;
private OAuth2AccessToken accessToken;
private OAuth2UserService<OidcUserRequest, OidcUser> userService;
private OidcAuthorizationCodeAuthenticationProvider authenticationProvider;
@Rule
public ExpectedException exception = ExpectedException.none();
@Before
@SuppressWarnings("unchecked")
public void setUp() throws Exception {
this.clientRegistration = mock(ClientRegistration.class);
this.providerDetails = mock(ClientRegistration.ProviderDetails.class);
this.authorizationRequest = mock(OAuth2AuthorizationRequest.class);
this.authorizationResponse = mock(OAuth2AuthorizationResponse.class);
this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, this.authorizationResponse);
this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class);
this.accessTokenResponse = mock(OAuth2AccessTokenResponse.class);
this.accessToken = mock(OAuth2AccessToken.class);
this.userService = mock(OAuth2UserService.class);
this.authenticationProvider = PowerMockito.spy(
new OidcAuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient, this.userService));
when(this.clientRegistration.getRegistrationId()).thenReturn("client-registration-id-1");
when(this.clientRegistration.getClientId()).thenReturn("client1");
when(this.clientRegistration.getProviderDetails()).thenReturn(this.providerDetails);
when(this.providerDetails.getJwkSetUri()).thenReturn("https://provider.com/oauth2/keys");
when(this.authorizationRequest.getScopes()).thenReturn(new LinkedHashSet<>(Arrays.asList("openid", "profile", "email")));
when(this.authorizationRequest.getState()).thenReturn("12345");
when(this.authorizationResponse.getState()).thenReturn("12345");
when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example.com");
when(this.authorizationResponse.getRedirectUri()).thenReturn("http://example.com");
when(this.accessTokenResponse.getAccessToken()).thenReturn(this.accessToken);
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put(OidcParameterNames.ID_TOKEN, "id-token");
when(this.accessTokenResponse.getAdditionalParameters()).thenReturn(additionalParameters);
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(this.accessTokenResponse);
}
@Test
public void constructorWhenAccessTokenResponseClientIsNullThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class);
new OidcAuthorizationCodeAuthenticationProvider(null, this.userService);
}
@Test
public void constructorWhenUserServiceIsNullThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class);
new OidcAuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient, null);
}
@Test
public void setAuthoritiesMapperWhenAuthoritiesMapperIsNullThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class);
this.authenticationProvider.setAuthoritiesMapper(null);
}
@Test
public void supportsWhenTypeOAuth2LoginAuthenticationTokenThenReturnTrue() {
assertThat(this.authenticationProvider.supports(OAuth2LoginAuthenticationToken.class)).isTrue();
}
@Test
public void authenticateWhenAuthorizationRequestDoesNotContainOpenidScopeThenReturnNull() {
when(this.authorizationRequest.getScopes()).thenReturn(new LinkedHashSet<>(Collections.singleton("scope1")));
OAuth2LoginAuthenticationToken authentication =
(OAuth2LoginAuthenticationToken)this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
assertThat(authentication).isNull();
}
@Test
public void authenticateWhenAuthorizationErrorResponseThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString(OAuth2ErrorCodes.INVALID_SCOPE));
when(this.authorizationResponse.statusError()).thenReturn(true);
when(this.authorizationResponse.getError()).thenReturn(new OAuth2Error(OAuth2ErrorCodes.INVALID_SCOPE));
this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
}
@Test
public void authenticateWhenAuthorizationResponseStateNotEqualAuthorizationRequestStateThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_state_parameter"));
when(this.authorizationRequest.getState()).thenReturn("34567");
when(this.authorizationResponse.getState()).thenReturn("89012");
this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
}
@Test
public void authenticateWhenAuthorizationResponseRedirectUriNotEqualAuthorizationRequestRedirectUriThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_redirect_uri_parameter"));
when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example1.com");
when(this.authorizationResponse.getRedirectUri()).thenReturn("http://example2.com");
this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
}
@Test
public void authenticateWhenTokenResponseDoesNotContainIdTokenThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_id_token"));
when(this.accessTokenResponse.getAdditionalParameters()).thenReturn(Collections.emptyMap());
this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
}
@Test
public void authenticateWhenJwkSetUriNotSetThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("missing_signature_verifier"));
when(this.providerDetails.getJwkSetUri()).thenReturn(null);
this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
}
@Test
public void authenticateWhenIdTokenIssuerClaimIsNullThenThrowOAuth2AuthenticationException() throws Exception {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_id_token"));
Map<String, Object> claims = new HashMap<>();
claims.put(IdTokenClaimNames.SUB, "subject1");
this.setUpIdToken(claims);
this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
}
@Test
public void authenticateWhenIdTokenSubjectClaimIsNullThenThrowOAuth2AuthenticationException() throws Exception {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_id_token"));
Map<String, Object> claims = new HashMap<>();
claims.put(IdTokenClaimNames.ISS, "https://provider.com");
this.setUpIdToken(claims);
this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
}
@Test
public void authenticateWhenIdTokenAudienceClaimIsNullThenThrowOAuth2AuthenticationException() throws Exception {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_id_token"));
Map<String, Object> claims = new HashMap<>();
claims.put(IdTokenClaimNames.ISS, "https://provider.com");
claims.put(IdTokenClaimNames.SUB, "subject1");
this.setUpIdToken(claims);
this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
}
@Test
public void authenticateWhenIdTokenAudienceClaimDoesNotContainClientIdThenThrowOAuth2AuthenticationException() throws Exception {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_id_token"));
Map<String, Object> claims = new HashMap<>();
claims.put(IdTokenClaimNames.ISS, "https://provider.com");
claims.put(IdTokenClaimNames.SUB, "subject1");
claims.put(IdTokenClaimNames.AUD, Collections.singletonList("other-client"));
this.setUpIdToken(claims);
this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
}
@Test
public void authenticateWhenIdTokenMultipleAudienceClaimAndAuthorizedPartyClaimIsNullThenThrowOAuth2AuthenticationException() throws Exception {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_id_token"));
Map<String, Object> claims = new HashMap<>();
claims.put(IdTokenClaimNames.ISS, "https://provider.com");
claims.put(IdTokenClaimNames.SUB, "subject1");
claims.put(IdTokenClaimNames.AUD, Arrays.asList("client1", "client2"));
this.setUpIdToken(claims);
this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
}
@Test
public void authenticateWhenIdTokenAuthorizedPartyClaimNotEqualToClientIdThenThrowOAuth2AuthenticationException() throws Exception {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_id_token"));
Map<String, Object> claims = new HashMap<>();
claims.put(IdTokenClaimNames.ISS, "https://provider.com");
claims.put(IdTokenClaimNames.SUB, "subject1");
claims.put(IdTokenClaimNames.AUD, Arrays.asList("client1", "client2"));
claims.put(IdTokenClaimNames.AZP, "other-client");
this.setUpIdToken(claims);
this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
}
@Test
public void authenticateWhenIdTokenExpiresAtIsBeforeNowThenThrowOAuth2AuthenticationException() throws Exception {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_id_token"));
Map<String, Object> claims = new HashMap<>();
claims.put(IdTokenClaimNames.ISS, "https://provider.com");
claims.put(IdTokenClaimNames.SUB, "subject1");
claims.put(IdTokenClaimNames.AUD, Arrays.asList("client1", "client2"));
claims.put(IdTokenClaimNames.AZP, "client1");
Instant issuedAt = Instant.now().minusSeconds(10);
Instant expiresAt = Instant.from(issuedAt).plusSeconds(5);
this.setUpIdToken(claims, issuedAt, expiresAt);
this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
}
@Test
public void authenticateWhenIdTokenIssuedAtIsAfterMaxIssuedAtThenThrowOAuth2AuthenticationException() throws Exception {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_id_token"));
Map<String, Object> claims = new HashMap<>();
claims.put(IdTokenClaimNames.ISS, "https://provider.com");
claims.put(IdTokenClaimNames.SUB, "subject1");
claims.put(IdTokenClaimNames.AUD, Arrays.asList("client1", "client2"));
claims.put(IdTokenClaimNames.AZP, "client1");
Instant issuedAt = Instant.now().plusSeconds(35);
Instant expiresAt = Instant.from(issuedAt).plusSeconds(60);
this.setUpIdToken(claims, issuedAt, expiresAt);
this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
}
@Test
public void authenticateWhenLoginSuccessThenReturnAuthentication() throws Exception {
Map<String, Object> claims = new HashMap<>();
claims.put(IdTokenClaimNames.ISS, "https://provider.com");
claims.put(IdTokenClaimNames.SUB, "subject1");
claims.put(IdTokenClaimNames.AUD, Arrays.asList("client1", "client2"));
claims.put(IdTokenClaimNames.AZP, "client1");
this.setUpIdToken(claims);
OidcUser principal = mock(OidcUser.class);
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
when(principal.getAuthorities()).thenAnswer(
(Answer<List<GrantedAuthority>>) invocation -> authorities);
when(this.userService.loadUser(any())).thenReturn(principal);
OAuth2LoginAuthenticationToken authentication =
(OAuth2LoginAuthenticationToken)this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
assertThat(authentication.isAuthenticated()).isTrue();
assertThat(authentication.getPrincipal()).isEqualTo(principal);
assertThat(authentication.getCredentials()).isEqualTo("");
assertThat(authentication.getAuthorities()).isEqualTo(authorities);
assertThat(authentication.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(authentication.getAuthorizationExchange()).isEqualTo(this.authorizationExchange);
assertThat(authentication.getAccessToken()).isEqualTo(this.accessToken);
}
@Test
public void authenticateWhenAuthoritiesMapperSetThenReturnMappedAuthorities() throws Exception {
Map<String, Object> claims = new HashMap<>();
claims.put(IdTokenClaimNames.ISS, "https://provider.com");
claims.put(IdTokenClaimNames.SUB, "subject1");
claims.put(IdTokenClaimNames.AUD, Arrays.asList("client1", "client2"));
claims.put(IdTokenClaimNames.AZP, "client1");
this.setUpIdToken(claims);
OidcUser principal = mock(OidcUser.class);
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
when(principal.getAuthorities()).thenAnswer(
(Answer<List<GrantedAuthority>>) invocation -> authorities);
when(this.userService.loadUser(any())).thenReturn(principal);
List<GrantedAuthority> mappedAuthorities = AuthorityUtils.createAuthorityList("ROLE_OIDC_USER");
GrantedAuthoritiesMapper authoritiesMapper = mock(GrantedAuthoritiesMapper.class);
when(authoritiesMapper.mapAuthorities(anyCollection())).thenAnswer(
(Answer<List<GrantedAuthority>>) invocation -> mappedAuthorities);
this.authenticationProvider.setAuthoritiesMapper(authoritiesMapper);
OAuth2LoginAuthenticationToken authentication =
(OAuth2LoginAuthenticationToken)this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
assertThat(authentication.getAuthorities()).isEqualTo(mappedAuthorities);
}
private void setUpIdToken(Map<String, Object> claims) throws Exception {
Instant issuedAt = Instant.now();
Instant expiresAt = Instant.from(issuedAt).plusSeconds(3600);
this.setUpIdToken(claims, issuedAt, expiresAt);
}
private void setUpIdToken(Map<String, Object> claims, Instant issuedAt, Instant expiresAt) throws Exception {
Map<String, Object> headers = new HashMap<>();
headers.put("alg", "RS256");
Jwt idToken = new Jwt("id-token", issuedAt, expiresAt, headers, claims);
JwtDecoder jwtDecoder = mock(JwtDecoder.class);
when(jwtDecoder.decode(anyString())).thenReturn(idToken);
PowerMockito.doReturn(jwtDecoder).when(this.authenticationProvider, "getJwtDecoder", any(ClientRegistration.class));
}
}

View File

@ -0,0 +1,73 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.oidc.userinfo;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link OidcUserRequest}.
*
* @author Joe Grandja
*/
@RunWith(PowerMockRunner.class)
@PrepareForTest(ClientRegistration.class)
public class OidcUserRequestTests {
private ClientRegistration clientRegistration;
private OAuth2AccessToken accessToken;
private OidcIdToken idToken;
@Before
public void setUp() {
this.clientRegistration = mock(ClientRegistration.class);
this.accessToken = mock(OAuth2AccessToken.class);
this.idToken = mock(OidcIdToken.class);
}
@Test(expected = IllegalArgumentException.class)
public void constructorWhenClientRegistrationIsNullThenThrowIllegalArgumentException() {
new OidcUserRequest(null, this.accessToken, this.idToken);
}
@Test(expected = IllegalArgumentException.class)
public void constructorWhenAccessTokenIsNullThenThrowIllegalArgumentException() {
new OidcUserRequest(this.clientRegistration, null, this.idToken);
}
@Test(expected = IllegalArgumentException.class)
public void constructorWhenIdTokenIsNullThenThrowIllegalArgumentException() {
new OidcUserRequest(this.clientRegistration, this.accessToken, null);
}
@Test
public void constructorWhenAllParametersProvidedAndValidThenCreated() {
OidcUserRequest userRequest = new OidcUserRequest(
this.clientRegistration, this.accessToken, this.idToken);
assertThat(userRequest.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(userRequest.getAccessToken()).isEqualTo(this.accessToken);
assertThat(userRequest.getIdToken()).isEqualTo(this.idToken);
}
}

View File

@ -0,0 +1,260 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.oidc.userinfo;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.core.oidc.OidcScopes;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.CoreMatchers.containsString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
* Tests for {@link OidcUserService}.
*
* @author Joe Grandja
*/
@PowerMockIgnore("okhttp3.*")
@PrepareForTest(ClientRegistration.class)
@RunWith(PowerMockRunner.class)
public class OidcUserServiceTests {
private ClientRegistration clientRegistration;
private ClientRegistration.ProviderDetails providerDetails;
private ClientRegistration.ProviderDetails.UserInfoEndpoint userInfoEndpoint;
private OAuth2AccessToken accessToken;
private OidcIdToken idToken;
private OidcUserService userService = new OidcUserService();
@Rule
public ExpectedException exception = ExpectedException.none();
@Before
public void setUp() throws Exception {
this.clientRegistration = mock(ClientRegistration.class);
this.providerDetails = mock(ClientRegistration.ProviderDetails.class);
this.userInfoEndpoint = mock(ClientRegistration.ProviderDetails.UserInfoEndpoint.class);
when(this.clientRegistration.getProviderDetails()).thenReturn(this.providerDetails);
when(this.providerDetails.getUserInfoEndpoint()).thenReturn(this.userInfoEndpoint);
when(this.clientRegistration.getAuthorizationGrantType()).thenReturn(AuthorizationGrantType.AUTHORIZATION_CODE);
this.accessToken = mock(OAuth2AccessToken.class);
Set<String> authorizedScopes = new LinkedHashSet<>(Arrays.asList(OidcScopes.OPENID, OidcScopes.PROFILE));
when(this.accessToken.getScopes()).thenReturn(authorizedScopes);
this.idToken = mock(OidcIdToken.class);
Map<String, Object> idTokenClaims = new HashMap<>();
idTokenClaims.put(IdTokenClaimNames.ISS, "https://provider.com");
idTokenClaims.put(IdTokenClaimNames.SUB, "subject1");
when(this.idToken.getClaims()).thenReturn(idTokenClaims);
when(this.idToken.getSubject()).thenReturn("subject1");
}
@Test
public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class);
this.userService.loadUser(null);
}
@Test
public void loadUserWhenUserInfoUriIsNullThenUserInfoEndpointNotRequested() {
when(this.userInfoEndpoint.getUri()).thenReturn(null);
OidcUser user = this.userService.loadUser(
new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
assertThat(user.getUserInfo()).isNull();
}
@Test
public void loadUserWhenAuthorizedScopesDoesNotContainUserInfoScopesThenUserInfoEndpointNotRequested() {
Set<String> authorizedScopes = new LinkedHashSet<>(Arrays.asList("scope1", "scope2"));
when(this.accessToken.getScopes()).thenReturn(authorizedScopes);
when(this.userInfoEndpoint.getUri()).thenReturn("http://provider.com/user");
OidcUser user = this.userService.loadUser(
new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
assertThat(user.getUserInfo()).isNull();
}
@Test
public void loadUserWhenUserInfoSuccessResponseThenReturnUser() throws Exception {
MockWebServer server = new MockWebServer();
String userInfoResponse = "{\n" +
" \"sub\": \"subject1\",\n" +
" \"name\": \"first last\",\n" +
" \"given_name\": \"first\",\n" +
" \"family_name\": \"last\",\n" +
" \"preferred_username\": \"user1\",\n" +
" \"email\": \"user1@example.com\"\n" +
"}\n";
server.enqueue(new MockResponse()
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
.setBody(userInfoResponse));
server.start();
String userInfoUri = server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
OidcUser user = this.userService.loadUser(
new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
server.shutdown();
assertThat(user.getIdToken()).isNotNull();
assertThat(user.getUserInfo()).isNotNull();
assertThat(user.getUserInfo().getClaims().size()).isEqualTo(6);
assertThat(user.getIdToken()).isEqualTo(this.idToken);
assertThat(user.getName()).isEqualTo("subject1");
assertThat(user.getUserInfo().getSubject()).isEqualTo("subject1");
assertThat(user.getUserInfo().getFullName()).isEqualTo("first last");
assertThat(user.getUserInfo().getGivenName()).isEqualTo("first");
assertThat(user.getUserInfo().getFamilyName()).isEqualTo("last");
assertThat(user.getUserInfo().getPreferredUsername()).isEqualTo("user1");
assertThat(user.getUserInfo().getEmail()).isEqualTo("user1@example.com");
assertThat(user.getAuthorities().size()).isEqualTo(1);
assertThat(user.getAuthorities().iterator().next()).isInstanceOf(OidcUserAuthority.class);
OidcUserAuthority userAuthority = (OidcUserAuthority)user.getAuthorities().iterator().next();
assertThat(userAuthority.getAuthority()).isEqualTo("ROLE_USER");
assertThat(userAuthority.getIdToken()).isEqualTo(user.getIdToken());
assertThat(userAuthority.getUserInfo()).isEqualTo(user.getUserInfo());
}
@Test
public void loadUserWhenUserInfoSuccessResponseAndUserInfoSubjectNotSameAsIdTokenSubjectThenThrowOAuth2AuthenticationException() throws Exception {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_user_info_response"));
MockWebServer server = new MockWebServer();
String userInfoResponse = "{\n" +
" \"sub\": \"other-subject\"\n" +
"}\n";
server.enqueue(new MockResponse()
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
.setBody(userInfoResponse));
server.start();
String userInfoUri = server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
try {
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
} finally {
server.shutdown();
}
}
@Test
public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() throws Exception {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_user_info_response"));
MockWebServer server = new MockWebServer();
String userInfoResponse = "{\n" +
" \"sub\": \"subject1\",\n" +
" \"name\": \"first last\",\n" +
" \"given_name\": \"first\",\n" +
" \"family_name\": \"last\",\n" +
" \"preferred_username\": \"user1\",\n" +
" \"email\": \"user1@example.com\"\n";
// "}\n"; // Make the JSON invalid/malformed
server.enqueue(new MockResponse()
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
.setBody(userInfoResponse));
server.start();
String userInfoUri = server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
try {
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
} finally {
server.shutdown();
}
}
@Test
public void loadUserWhenUserInfoErrorResponseThenThrowOAuth2AuthenticationException() throws Exception {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_user_info_response"));
MockWebServer server = new MockWebServer();
server.enqueue(new MockResponse().setResponseCode(500));
server.start();
String userInfoUri = server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
try {
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
} finally {
server.shutdown();
}
}
@Test
public void loadUserWhenUserInfoUriInvalidThenThrowAuthenticationServiceException() throws Exception {
this.exception.expect(AuthenticationServiceException.class);
String userInfoUri = "http://invalid-provider.com/user";
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
}
}

View File

@ -0,0 +1,354 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.registration;
import org.junit.Test;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.Set;
import static org.assertj.core.api.Assertions.assertThat;
/**
* Tests for {@link ClientRegistration}.
*
* @author Joe Grandja
*/
public class ClientRegistrationTests {
private static final String REGISTRATION_ID = "registration-1";
private static final String CLIENT_ID = "client-1";
private static final String CLIENT_SECRET = "secret";
private static final String REDIRECT_URI = "https://example.com";
private static final Set<String> SCOPES = new LinkedHashSet<>(Arrays.asList("openid", "profile", "email"));
private static final String AUTHORIZATION_URI = "https://provider.com/oauth2/authorization";
private static final String TOKEN_URI = "https://provider.com/oauth2/token";
private static final String JWK_SET_URI = "https://provider.com/oauth2/keys";
private static final String CLIENT_NAME = "Client 1";
@Test(expected = IllegalArgumentException.class)
public void buildWhenAuthorizationGrantTypeIsNullThenThrowIllegalArgumentException() {
ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID)
.clientSecret(CLIENT_SECRET)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(null)
.redirectUri(REDIRECT_URI)
.scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(TOKEN_URI)
.jwkSetUri(JWK_SET_URI)
.clientName(CLIENT_NAME)
.build();
}
@Test
public void buildWhenAuthorizationCodeGrantAllAttributesProvidedThenAllAttributesAreSet() {
ClientRegistration registration = ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID)
.clientSecret(CLIENT_SECRET)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUri(REDIRECT_URI)
.scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(TOKEN_URI)
.jwkSetUri(JWK_SET_URI)
.clientName(CLIENT_NAME)
.build();
assertThat(registration.getRegistrationId()).isEqualTo(REGISTRATION_ID);
assertThat(registration.getClientId()).isEqualTo(CLIENT_ID);
assertThat(registration.getClientSecret()).isEqualTo(CLIENT_SECRET);
assertThat(registration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.BASIC);
assertThat(registration.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
assertThat(registration.getRedirectUri()).isEqualTo(REDIRECT_URI);
assertThat(registration.getScopes()).isEqualTo(SCOPES);
assertThat(registration.getProviderDetails().getAuthorizationUri()).isEqualTo(AUTHORIZATION_URI);
assertThat(registration.getProviderDetails().getTokenUri()).isEqualTo(TOKEN_URI);
assertThat(registration.getProviderDetails().getJwkSetUri()).isEqualTo(JWK_SET_URI);
assertThat(registration.getClientName()).isEqualTo(CLIENT_NAME);
}
@Test(expected = IllegalArgumentException.class)
public void buildWhenAuthorizationCodeGrantRegistrationIdIsNullThenThrowIllegalArgumentException() {
ClientRegistration.withRegistrationId(null)
.clientId(CLIENT_ID)
.clientSecret(CLIENT_SECRET)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUri(REDIRECT_URI)
.scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(TOKEN_URI)
.jwkSetUri(JWK_SET_URI)
.clientName(CLIENT_NAME)
.build();
}
@Test(expected = IllegalArgumentException.class)
public void buildWhenAuthorizationCodeGrantClientIdIsNullThenThrowIllegalArgumentException() {
ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(null)
.clientSecret(CLIENT_SECRET)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUri(REDIRECT_URI)
.scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(TOKEN_URI)
.jwkSetUri(JWK_SET_URI)
.clientName(CLIENT_NAME)
.build();
}
@Test(expected = IllegalArgumentException.class)
public void buildWhenAuthorizationCodeGrantClientSecretIsNullThenThrowIllegalArgumentException() {
ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID)
.clientSecret(null)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUri(REDIRECT_URI)
.scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(TOKEN_URI)
.jwkSetUri(JWK_SET_URI)
.clientName(CLIENT_NAME)
.build();
}
@Test(expected = IllegalArgumentException.class)
public void buildWhenAuthorizationCodeGrantClientAuthenticationMethodIsNullThenThrowIllegalArgumentException() {
ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID)
.clientSecret(CLIENT_SECRET)
.clientAuthenticationMethod(null)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUri(REDIRECT_URI)
.scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(TOKEN_URI)
.jwkSetUri(JWK_SET_URI)
.clientName(CLIENT_NAME)
.build();
}
@Test(expected = IllegalArgumentException.class)
public void buildWhenAuthorizationCodeGrantRedirectUriIsNullThenThrowIllegalArgumentException() {
ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID)
.clientSecret(CLIENT_SECRET)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUri(null)
.scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(TOKEN_URI)
.jwkSetUri(JWK_SET_URI)
.clientName(CLIENT_NAME)
.build();
}
@Test(expected = IllegalArgumentException.class)
public void buildWhenAuthorizationCodeGrantScopeIsNullThenThrowIllegalArgumentException() {
ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID)
.clientSecret(CLIENT_SECRET)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUri(REDIRECT_URI)
.scope(null)
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(TOKEN_URI)
.jwkSetUri(JWK_SET_URI)
.clientName(CLIENT_NAME)
.build();
}
@Test(expected = IllegalArgumentException.class)
public void buildWhenAuthorizationCodeGrantAuthorizationUriIsNullThenThrowIllegalArgumentException() {
ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID)
.clientSecret(CLIENT_SECRET)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUri(REDIRECT_URI)
.scope(SCOPES.toArray(new String[0]))
.authorizationUri(null)
.tokenUri(TOKEN_URI)
.jwkSetUri(JWK_SET_URI)
.clientName(CLIENT_NAME)
.build();
}
@Test(expected = IllegalArgumentException.class)
public void buildWhenAuthorizationCodeGrantTokenUriIsNullThenThrowIllegalArgumentException() {
ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID)
.clientSecret(CLIENT_SECRET)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUri(REDIRECT_URI)
.scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(null)
.jwkSetUri(JWK_SET_URI)
.clientName(CLIENT_NAME)
.build();
}
@Test(expected = IllegalArgumentException.class)
public void buildWhenAuthorizationCodeGrantJwkSetUriIsNullThenThrowIllegalArgumentException() {
ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID)
.clientSecret(CLIENT_SECRET)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUri(REDIRECT_URI)
.scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(TOKEN_URI)
.jwkSetUri(null)
.clientName(CLIENT_NAME)
.build();
}
@Test(expected = IllegalArgumentException.class)
public void buildWhenAuthorizationCodeGrantClientNameIsNullThenThrowIllegalArgumentException() {
ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID)
.clientSecret(CLIENT_SECRET)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUri(REDIRECT_URI)
.scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(TOKEN_URI)
.jwkSetUri(JWK_SET_URI)
.clientName(null)
.build();
}
@Test
public void buildWhenAuthorizationCodeGrantScopeDoesNotContainOpenidThenJwkSetUriNotRequired() {
ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID)
.clientSecret(CLIENT_SECRET)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUri(REDIRECT_URI)
.scope("scope1")
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(TOKEN_URI)
.clientName(CLIENT_NAME)
.build();
}
@Test
public void buildWhenImplicitGrantAllAttributesProvidedThenAllAttributesAreSet() {
ClientRegistration registration = ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID)
.authorizationGrantType(AuthorizationGrantType.IMPLICIT)
.redirectUri(REDIRECT_URI)
.scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.clientName(CLIENT_NAME)
.build();
assertThat(registration.getRegistrationId()).isEqualTo(REGISTRATION_ID);
assertThat(registration.getClientId()).isEqualTo(CLIENT_ID);
assertThat(registration.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.IMPLICIT);
assertThat(registration.getRedirectUri()).isEqualTo(REDIRECT_URI);
assertThat(registration.getScopes()).isEqualTo(SCOPES);
assertThat(registration.getProviderDetails().getAuthorizationUri()).isEqualTo(AUTHORIZATION_URI);
assertThat(registration.getClientName()).isEqualTo(CLIENT_NAME);
}
@Test(expected = IllegalArgumentException.class)
public void buildWhenImplicitGrantRegistrationIdIsNullThenThrowIllegalArgumentException() {
ClientRegistration.withRegistrationId(null)
.clientId(CLIENT_ID)
.authorizationGrantType(AuthorizationGrantType.IMPLICIT)
.redirectUri(REDIRECT_URI)
.scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.clientName(CLIENT_NAME)
.build();
}
@Test(expected = IllegalArgumentException.class)
public void buildWhenImplicitGrantClientIdIsNullThenThrowIllegalArgumentException() {
ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(null)
.authorizationGrantType(AuthorizationGrantType.IMPLICIT)
.redirectUri(REDIRECT_URI)
.scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.clientName(CLIENT_NAME)
.build();
}
@Test(expected = IllegalArgumentException.class)
public void buildWhenImplicitGrantRedirectUriIsNullThenThrowIllegalArgumentException() {
ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID)
.authorizationGrantType(AuthorizationGrantType.IMPLICIT)
.redirectUri(null)
.scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.clientName(CLIENT_NAME)
.build();
}
@Test(expected = IllegalArgumentException.class)
public void buildWhenImplicitGrantScopeIsNullThenThrowIllegalArgumentException() {
ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID)
.authorizationGrantType(AuthorizationGrantType.IMPLICIT)
.redirectUri(REDIRECT_URI)
.scope(null)
.authorizationUri(AUTHORIZATION_URI)
.clientName(CLIENT_NAME)
.build();
}
@Test(expected = IllegalArgumentException.class)
public void buildWhenImplicitGrantAuthorizationUriIsNullThenThrowIllegalArgumentException() {
ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID)
.authorizationGrantType(AuthorizationGrantType.IMPLICIT)
.redirectUri(REDIRECT_URI)
.scope(SCOPES.toArray(new String[0]))
.authorizationUri(null)
.clientName(CLIENT_NAME)
.build();
}
@Test(expected = IllegalArgumentException.class)
public void buildWhenImplicitGrantClientNameIsNullThenThrowIllegalArgumentException() {
ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID)
.authorizationGrantType(AuthorizationGrantType.IMPLICIT)
.redirectUri(REDIRECT_URI)
.scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.clientName(null)
.build();
}
}

View File

@ -24,9 +24,11 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.assertj.core.api.Assertions.*;
import static org.assertj.core.api.Assertions.assertThat;
/**
* Tests for {@link InMemoryClientRegistrationRepository}.
*
* @author Rob Winch
* @since 5.0
*/

View File

@ -0,0 +1,267 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.userinfo;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.user.OAuth2User;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.CoreMatchers.containsString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
* Tests for {@link CustomUserTypesOAuth2UserService}.
*
* @author Joe Grandja
*/
@PowerMockIgnore("okhttp3.*")
@PrepareForTest(ClientRegistration.class)
@RunWith(PowerMockRunner.class)
public class CustomUserTypesOAuth2UserServiceTests {
private ClientRegistration clientRegistration;
private ClientRegistration.ProviderDetails providerDetails;
private ClientRegistration.ProviderDetails.UserInfoEndpoint userInfoEndpoint;
private OAuth2AccessToken accessToken;
private CustomUserTypesOAuth2UserService userService;
@Rule
public ExpectedException exception = ExpectedException.none();
@Before
public void setUp() throws Exception {
this.clientRegistration = mock(ClientRegistration.class);
this.providerDetails = mock(ClientRegistration.ProviderDetails.class);
this.userInfoEndpoint = mock(ClientRegistration.ProviderDetails.UserInfoEndpoint.class);
when(this.clientRegistration.getProviderDetails()).thenReturn(this.providerDetails);
when(this.providerDetails.getUserInfoEndpoint()).thenReturn(this.userInfoEndpoint);
String registrationId = "client-registration-id-1";
when(this.clientRegistration.getRegistrationId()).thenReturn(registrationId);
this.accessToken = mock(OAuth2AccessToken.class);
Map<String, Class<? extends OAuth2User>> customUserTypes = new HashMap<>();
customUserTypes.put(registrationId, CustomOAuth2User.class);
this.userService = new CustomUserTypesOAuth2UserService(customUserTypes);
}
@Test
public void constructorWhenCustomUserTypesIsNullThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class);
new CustomUserTypesOAuth2UserService(null);
}
@Test
public void constructorWhenCustomUserTypesIsEmptyThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class);
new CustomUserTypesOAuth2UserService(Collections.emptyMap());
}
@Test
public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class);
this.userService.loadUser(null);
}
@Test
public void loadUserWhenCustomUserTypeNotFoundThenReturnNull() {
when(this.clientRegistration.getRegistrationId()).thenReturn("other-client-registration-id-1");
OAuth2User user = this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
assertThat(user).isNull();
}
@Test
public void loadUserWhenUserInfoSuccessResponseThenReturnUser() throws Exception {
MockWebServer server = new MockWebServer();
String userInfoResponse = "{\n" +
" \"id\": \"12345\",\n" +
" \"name\": \"first last\",\n" +
" \"login\": \"user1\",\n" +
" \"email\": \"user1@example.com\"\n" +
"}\n";
server.enqueue(new MockResponse()
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
.setBody(userInfoResponse));
server.start();
String userInfoUri = server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
OAuth2User user = this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
server.shutdown();
assertThat(user.getName()).isEqualTo("first last");
assertThat(user.getAttributes().size()).isEqualTo(4);
assertThat(user.getAttributes().get("id")).isEqualTo("12345");
assertThat(user.getAttributes().get("name")).isEqualTo("first last");
assertThat(user.getAttributes().get("login")).isEqualTo("user1");
assertThat(user.getAttributes().get("email")).isEqualTo("user1@example.com");
assertThat(user.getAuthorities().size()).isEqualTo(1);
assertThat(user.getAuthorities().iterator().next().getAuthority()).isEqualTo("ROLE_USER");
}
@Test
public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() throws Exception {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_user_info_response"));
MockWebServer server = new MockWebServer();
String userInfoResponse = "{\n" +
" \"id\": \"12345\",\n" +
" \"name\": \"first last\",\n" +
" \"login\": \"user1\",\n" +
" \"email\": \"user1@example.com\"\n";
// "}\n"; // Make the JSON invalid/malformed
server.enqueue(new MockResponse()
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
.setBody(userInfoResponse));
server.start();
String userInfoUri = server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
try {
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
} finally {
server.shutdown();
}
}
@Test
public void loadUserWhenUserInfoErrorResponseThenThrowOAuth2AuthenticationException() throws Exception {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_user_info_response"));
MockWebServer server = new MockWebServer();
server.enqueue(new MockResponse().setResponseCode(500));
server.start();
String userInfoUri = server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
try {
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
} finally {
server.shutdown();
}
}
@Test
public void loadUserWhenUserInfoUriInvalidThenThrowAuthenticationServiceException() throws Exception {
this.exception.expect(AuthenticationServiceException.class);
String userInfoUri = "http://invalid-provider.com/user";
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
}
public static class CustomOAuth2User implements OAuth2User {
private List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
private String id;
private String name;
private String login;
private String email;
public CustomOAuth2User() {
}
@Override
public Collection<? extends GrantedAuthority> getAuthorities() {
return this.authorities;
}
@Override
public Map<String, Object> getAttributes() {
Map<String, Object> attributes = new HashMap<>();
attributes.put("id", this.getId());
attributes.put("name", this.getName());
attributes.put("login", this.getLogin());
attributes.put("email", this.getEmail());
return attributes;
}
public String getId() {
return this.id;
}
public void setId(String id) {
this.id = id;
}
@Override
public String getName() {
return this.name;
}
public void setName(String name) {
this.name = name;
}
public String getLogin() {
return this.login;
}
public void setLogin(String login) {
this.login = login;
}
public String getEmail() {
return this.email;
}
public void setEmail(String email) {
this.email = email;
}
}
}

View File

@ -0,0 +1,197 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.userinfo;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.oauth2.core.user.OAuth2UserAuthority;
import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.CoreMatchers.containsString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
* Tests for {@link DefaultOAuth2UserService}.
*
* @author Joe Grandja
*/
@PowerMockIgnore("okhttp3.*")
@PrepareForTest(ClientRegistration.class)
@RunWith(PowerMockRunner.class)
public class DefaultOAuth2UserServiceTests {
private ClientRegistration clientRegistration;
private ClientRegistration.ProviderDetails providerDetails;
private ClientRegistration.ProviderDetails.UserInfoEndpoint userInfoEndpoint;
private OAuth2AccessToken accessToken;
private DefaultOAuth2UserService userService = new DefaultOAuth2UserService();
@Rule
public ExpectedException exception = ExpectedException.none();
@Before
public void setUp() throws Exception {
this.clientRegistration = mock(ClientRegistration.class);
this.providerDetails = mock(ClientRegistration.ProviderDetails.class);
this.userInfoEndpoint = mock(ClientRegistration.ProviderDetails.UserInfoEndpoint.class);
when(this.clientRegistration.getProviderDetails()).thenReturn(this.providerDetails);
when(this.providerDetails.getUserInfoEndpoint()).thenReturn(this.userInfoEndpoint);
this.accessToken = mock(OAuth2AccessToken.class);
}
@Test
public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class);
this.userService.loadUser(null);
}
@Test
public void loadUserWhenUserNameAttributeNameIsNullThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("missing_user_name_attribute"));
when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn(null);
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
}
@Test
public void loadUserWhenUserInfoSuccessResponseThenReturnUser() throws Exception {
MockWebServer server = new MockWebServer();
String userInfoResponse = "{\n" +
" \"user-name\": \"user1\",\n" +
" \"first-name\": \"first\",\n" +
" \"last-name\": \"last\",\n" +
" \"middle-name\": \"middle\",\n" +
" \"address\": \"address\",\n" +
" \"email\": \"user1@example.com\"\n" +
"}\n";
server.enqueue(new MockResponse()
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
.setBody(userInfoResponse));
server.start();
String userInfoUri = server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name");
when(this.accessToken.getTokenValue()).thenReturn("access-token");
OAuth2User user = this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
server.shutdown();
assertThat(user.getName()).isEqualTo("user1");
assertThat(user.getAttributes().size()).isEqualTo(6);
assertThat(user.getAttributes().get("user-name")).isEqualTo("user1");
assertThat(user.getAttributes().get("first-name")).isEqualTo("first");
assertThat(user.getAttributes().get("last-name")).isEqualTo("last");
assertThat(user.getAttributes().get("middle-name")).isEqualTo("middle");
assertThat(user.getAttributes().get("address")).isEqualTo("address");
assertThat(user.getAttributes().get("email")).isEqualTo("user1@example.com");
assertThat(user.getAuthorities().size()).isEqualTo(1);
assertThat(user.getAuthorities().iterator().next()).isInstanceOf(OAuth2UserAuthority.class);
OAuth2UserAuthority userAuthority = (OAuth2UserAuthority)user.getAuthorities().iterator().next();
assertThat(userAuthority.getAuthority()).isEqualTo("ROLE_USER");
assertThat(userAuthority.getAttributes()).isEqualTo(user.getAttributes());
}
@Test
public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() throws Exception {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_user_info_response"));
MockWebServer server = new MockWebServer();
String userInfoResponse = "{\n" +
" \"user-name\": \"user1\",\n" +
" \"first-name\": \"first\",\n" +
" \"last-name\": \"last\",\n" +
" \"middle-name\": \"middle\",\n" +
" \"address\": \"address\",\n" +
" \"email\": \"user1@example.com\"\n";
// "}\n"; // Make the JSON invalid/malformed
server.enqueue(new MockResponse()
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
.setBody(userInfoResponse));
server.start();
String userInfoUri = server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name");
when(this.accessToken.getTokenValue()).thenReturn("access-token");
try {
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
} finally {
server.shutdown();
}
}
@Test
public void loadUserWhenUserInfoErrorResponseThenThrowOAuth2AuthenticationException() throws Exception {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_user_info_response"));
MockWebServer server = new MockWebServer();
server.enqueue(new MockResponse().setResponseCode(500));
server.start();
String userInfoUri = server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name");
when(this.accessToken.getTokenValue()).thenReturn("access-token");
try {
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
} finally {
server.shutdown();
}
}
@Test
public void loadUserWhenUserInfoUriInvalidThenThrowAuthenticationServiceException() throws Exception {
this.exception.expect(AuthenticationServiceException.class);
String userInfoUri = "http://invalid-provider.com/user";
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name");
when(this.accessToken.getTokenValue()).thenReturn("access-token");
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
}
}

View File

@ -0,0 +1,84 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.userinfo;
import org.junit.Test;
import org.springframework.security.oauth2.core.user.OAuth2User;
import java.util.Arrays;
import java.util.Collections;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
* Tests for {@link DelegatingOAuth2UserService}.
*
* @author Joe Grandja
*/
public class DelegatingOAuth2UserServiceTests {
@Test(expected = IllegalArgumentException.class)
public void constructorWhenUserServicesIsNullThenThrowIllegalArgumentException() {
new DelegatingOAuth2UserService<>(null);
}
@Test(expected = IllegalArgumentException.class)
public void constructorWhenUserServicesIsEmptyThenThrowIllegalArgumentException() {
new DelegatingOAuth2UserService<>(Collections.emptyList());
}
@Test(expected = IllegalArgumentException.class)
@SuppressWarnings("unchecked")
public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() {
DelegatingOAuth2UserService<OAuth2UserRequest, OAuth2User> delegatingUserService =
new DelegatingOAuth2UserService<>(
Arrays.asList(mock(OAuth2UserService.class), mock(OAuth2UserService.class)));
delegatingUserService.loadUser(null);
}
@Test
@SuppressWarnings("unchecked")
public void loadUserWhenUserServiceCanLoadThenReturnUser() {
OAuth2UserService<OAuth2UserRequest, OAuth2User> userService1 = mock(OAuth2UserService.class);
OAuth2UserService<OAuth2UserRequest, OAuth2User> userService2 = mock(OAuth2UserService.class);
OAuth2UserService<OAuth2UserRequest, OAuth2User> userService3 = mock(OAuth2UserService.class);
OAuth2User mockUser = mock(OAuth2User.class);
when(userService3.loadUser(any(OAuth2UserRequest.class))).thenReturn(mockUser);
DelegatingOAuth2UserService<OAuth2UserRequest, OAuth2User> delegatingUserService =
new DelegatingOAuth2UserService<>(Arrays.asList(userService1, userService2, userService3));
OAuth2User loadedUser = delegatingUserService.loadUser(mock(OAuth2UserRequest.class));
assertThat(loadedUser).isEqualTo(mockUser);
}
@Test
@SuppressWarnings("unchecked")
public void loadUserWhenUserServiceCannotLoadThenReturnNull() {
OAuth2UserService<OAuth2UserRequest, OAuth2User> userService1 = mock(OAuth2UserService.class);
OAuth2UserService<OAuth2UserRequest, OAuth2User> userService2 = mock(OAuth2UserService.class);
OAuth2UserService<OAuth2UserRequest, OAuth2User> userService3 = mock(OAuth2UserService.class);
DelegatingOAuth2UserService<OAuth2UserRequest, OAuth2User> delegatingUserService =
new DelegatingOAuth2UserService<>(Arrays.asList(userService1, userService2, userService3));
OAuth2User loadedUser = delegatingUserService.loadUser(mock(OAuth2UserRequest.class));
assertThat(loadedUser).isNull();
}
}

View File

@ -0,0 +1,63 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.userinfo;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link OAuth2UserRequest}.
*
* @author Joe Grandja
*/
@RunWith(PowerMockRunner.class)
@PrepareForTest(ClientRegistration.class)
public class OAuth2UserRequestTests {
private ClientRegistration clientRegistration;
private OAuth2AccessToken accessToken;
@Before
public void setUp() {
this.clientRegistration = mock(ClientRegistration.class);
this.accessToken = mock(OAuth2AccessToken.class);
}
@Test(expected = IllegalArgumentException.class)
public void constructorWhenClientRegistrationIsNullThenThrowIllegalArgumentException() {
new OAuth2UserRequest(null, this.accessToken);
}
@Test(expected = IllegalArgumentException.class)
public void constructorWhenAccessTokenIsNullThenThrowIllegalArgumentException() {
new OAuth2UserRequest(this.clientRegistration, null);
}
@Test
public void constructorWhenAllParametersProvidedAndValidThenCreated() {
OAuth2UserRequest userRequest = new OAuth2UserRequest(this.clientRegistration, this.accessToken);
assertThat(userRequest.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(userRequest.getAccessToken()).isEqualTo(this.accessToken);
}
}

View File

@ -0,0 +1,138 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link HttpSessionOAuth2AuthorizationRequestRepository}.
*
* @author Joe Grandja
*/
@PrepareForTest(OAuth2AuthorizationRequest.class)
@RunWith(PowerMockRunner.class)
public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
private HttpSessionOAuth2AuthorizationRequestRepository authorizationRequestRepository =
new HttpSessionOAuth2AuthorizationRequestRepository();
@Test(expected = IllegalArgumentException.class)
public void loadAuthorizationRequestWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() {
this.authorizationRequestRepository.loadAuthorizationRequest(null);
}
@Test
public void loadAuthorizationRequestWhenNotSavedThenReturnNull() {
OAuth2AuthorizationRequest authorizationRequest =
this.authorizationRequestRepository.loadAuthorizationRequest(new MockHttpServletRequest());
assertThat(authorizationRequest).isNull();
}
@Test
public void loadAuthorizationRequestWhenSavedThenReturnAuthorizationRequest() {
MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();
OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
OAuth2AuthorizationRequest loadedAuthorizationRequest =
this.authorizationRequestRepository.loadAuthorizationRequest(request);
assertThat(loadedAuthorizationRequest).isEqualTo(authorizationRequest);
}
@Test(expected = IllegalArgumentException.class)
public void saveAuthorizationRequestWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() {
this.authorizationRequestRepository.saveAuthorizationRequest(
mock(OAuth2AuthorizationRequest.class), null, new MockHttpServletResponse());
}
@Test(expected = IllegalArgumentException.class)
public void saveAuthorizationRequestWhenHttpServletResponseIsNullThenThrowIllegalArgumentException() {
this.authorizationRequestRepository.saveAuthorizationRequest(
mock(OAuth2AuthorizationRequest.class), new MockHttpServletRequest(), null);
}
@Test
public void saveAuthorizationRequestWhenNotNullThenSaved() {
MockHttpServletRequest request = new MockHttpServletRequest();
OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
this.authorizationRequestRepository.saveAuthorizationRequest(
authorizationRequest, request, new MockHttpServletResponse());
OAuth2AuthorizationRequest loadedAuthorizationRequest =
this.authorizationRequestRepository.loadAuthorizationRequest(request);
assertThat(loadedAuthorizationRequest).isEqualTo(authorizationRequest);
}
@Test
public void saveAuthorizationRequestWhenNullThenRemoved() {
MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();
OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
this.authorizationRequestRepository.saveAuthorizationRequest( // Save
authorizationRequest, request, response);
this.authorizationRequestRepository.saveAuthorizationRequest( // Null value removes
null, request, response);
OAuth2AuthorizationRequest loadedAuthorizationRequest =
this.authorizationRequestRepository.loadAuthorizationRequest(request);
assertThat(loadedAuthorizationRequest).isNull();
}
@Test(expected = IllegalArgumentException.class)
public void removeAuthorizationRequestWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() {
this.authorizationRequestRepository.removeAuthorizationRequest(null);
}
@Test
public void removeAuthorizationRequestWhenSavedThenRemoved() {
MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();
OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
this.authorizationRequestRepository.saveAuthorizationRequest(
authorizationRequest, request, response);
OAuth2AuthorizationRequest removedAuthorizationRequest =
this.authorizationRequestRepository.removeAuthorizationRequest(request);
OAuth2AuthorizationRequest loadedAuthorizationRequest =
this.authorizationRequestRepository.loadAuthorizationRequest(request);
assertThat(removedAuthorizationRequest).isNotNull();
assertThat(loadedAuthorizationRequest).isNull();
}
@Test
public void removeAuthorizationRequestWhenNotSavedThenNotRemoved() {
MockHttpServletRequest request = new MockHttpServletRequest();
OAuth2AuthorizationRequest removedAuthorizationRequest =
this.authorizationRequestRepository.removeAuthorizationRequest(request);
assertThat(removedAuthorizationRequest).isNull();
}
}

View File

@ -15,109 +15,259 @@
*/
package org.springframework.security.oauth2.client.web;
import org.assertj.core.api.Assertions;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Matchers;
import org.mockito.Mockito;
import org.springframework.http.HttpStatus;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;
/**
* Tests {@link OAuth2AuthorizationRequestRedirectFilter}.
* Tests for {@link OAuth2AuthorizationRequestRedirectFilter}.
*
* @author Joe Grandja
*/
public class OAuth2AuthorizationRequestRedirectFilterTests {
private ClientRegistration registration1;
private ClientRegistration registration2;
private ClientRegistration registration3;
private ClientRegistrationRepository clientRegistrationRepository;
private OAuth2AuthorizationRequestRedirectFilter filter;
@Before
public void setUp() {
this.registration1 = ClientRegistration.withRegistrationId("registration-1")
.clientId("client-1")
.clientSecret("secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUri("{scheme}://{serverName}:{serverPort}{contextPath}/login/oauth2/code/{registrationId}")
.scope("user")
.authorizationUri("https://provider.com/oauth2/authorize")
.tokenUri("https://provider.com/oauth2/token")
.userInfoUri("https://provider.com/oauth2/user")
.userNameAttributeName("id")
.clientName("client-1")
.build();
this.registration2 = ClientRegistration.withRegistrationId("registration-2")
.clientId("client-2")
.clientSecret("secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUri("{scheme}://{serverName}:{serverPort}{contextPath}/login/oauth2/code/{registrationId}")
.scope("openid", "profile", "email")
.authorizationUri("https://provider.com/oauth2/authorize")
.tokenUri("https://provider.com/oauth2/token")
.userInfoUri("https://provider.com/oauth2/userinfo")
.jwkSetUri("https://provider.com/oauth2/keys")
.clientName("client-2")
.build();
this.registration3 = ClientRegistration.withRegistrationId("registration-3")
.clientId("client-3")
.authorizationGrantType(AuthorizationGrantType.IMPLICIT)
.redirectUri("{scheme}://{serverName}:{serverPort}{contextPath}/login/oauth2/implicit/{registrationId}")
.scope("openid", "profile", "email")
.authorizationUri("https://provider.com/oauth2/authorize")
.tokenUri("https://provider.com/oauth2/token")
.userInfoUri("https://provider.com/oauth2/userinfo")
.clientName("client-3")
.build();
this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(
this.registration1, this.registration2, this.registration3);
this.filter = new OAuth2AuthorizationRequestRedirectFilter(this.clientRegistrationRepository);
}
@Test(expected = IllegalArgumentException.class)
public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() {
new OAuth2AuthorizationRequestRedirectFilter(null);
}
@Test
public void doFilterWhenRequestDoesNotMatchClientThenContinueChain() throws Exception {
ClientRegistration clientRegistration = TestUtil.googleClientRegistration();
String authorizationUri = clientRegistration.getProviderDetails().getAuthorizationUri().toString();
OAuth2AuthorizationRequestRedirectFilter filter =
setupFilter(authorizationUri, clientRegistration);
@Test(expected = IllegalArgumentException.class)
public void constructorWhenAuthorizationRequestBaseUriIsNullThenThrowIllegalArgumentException() {
new OAuth2AuthorizationRequestRedirectFilter(null, this.clientRegistrationRepository);
}
String requestURI = "/path";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestURI);
request.setServletPath(requestURI);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = Mockito.mock(FilterChain.class);
filter.doFilter(request, response, filterChain);
Mockito.verify(filterChain).doFilter(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class));
@Test(expected = IllegalArgumentException.class)
public void setAuthorizationRequestRepositoryWhenAuthorizationRequestRepositoryIsNullThenThrowIllegalArgumentException() {
this.filter.setAuthorizationRequestRepository(null);
}
@Test
public void doFilterWhenRequestMatchesClientThenRedirectForAuthorization() throws Exception {
ClientRegistration clientRegistration = TestUtil.googleClientRegistration();
String authorizationUri = clientRegistration.getProviderDetails().getAuthorizationUri().toString();
OAuth2AuthorizationRequestRedirectFilter filter =
setupFilter(authorizationUri, clientRegistration);
String requestUri = TestUtil.AUTHORIZATION_BASE_URI + "/" + clientRegistration.getRegistrationId();
public void doFilterWhenNotAuthorizationRequestThenNextFilter() throws Exception {
String requestUri = "/path";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = Mockito.mock(FilterChain.class);
FilterChain filterChain = mock(FilterChain.class);
filter.doFilter(request, response, filterChain);
this.filter.doFilter(request, response, filterChain);
Mockito.verifyZeroInteractions(filterChain); // Request should not proceed up the chain
Assertions.assertThat(response.getRedirectedUrl()).matches("https://accounts.google.com/o/oauth2/auth\\?response_type=code&client_id=google-client-id&scope=openid%20email%20profile&state=.{15,}&redirect_uri=https://localhost:8080/login/oauth2/code/google");
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
}
@Test
public void doFilterWhenRequestMatchesClientThenAuthorizationRequestSavedInSession() throws Exception {
ClientRegistration clientRegistration = TestUtil.githubClientRegistration();
String authorizationUri = clientRegistration.getProviderDetails().getAuthorizationUri().toString();
OAuth2AuthorizationRequestRedirectFilter filter =
setupFilter(authorizationUri, clientRegistration);
public void doFilterWhenAuthorizationRequestWithInvalidClientThenStatusBadRequest() throws Exception {
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
"/" + this.registration1.getRegistrationId() + "-invalid";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyZeroInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
assertThat(response.getErrorMessage()).isEqualTo(HttpStatus.BAD_REQUEST.getReasonPhrase());
}
@Test
public void doFilterWhenAuthorizationRequestAuthorizationCodeGrantThenRedirectForAuthorization() throws Exception {
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
"/" + this.registration1.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyZeroInteractions(filterChain);
assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http://localhost:80/login/oauth2/code/registration-1");
}
@Test
public void doFilterWhenAuthorizationRequestAuthorizationCodeGrantThenAuthorizationRequestSavedInSession() throws Exception {
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
"/" + this.registration2.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
new HttpSessionOAuth2AuthorizationRequestRepository();
filter.setAuthorizationRequestRepository(authorizationRequestRepository);
this.filter.setAuthorizationRequestRepository(authorizationRequestRepository);
String requestUri = TestUtil.AUTHORIZATION_BASE_URI + "/" + clientRegistration.getRegistrationId();
this.filter.doFilter(request, response, filterChain);
verifyZeroInteractions(filterChain);
OAuth2AuthorizationRequest authorizationRequest = authorizationRequestRepository.loadAuthorizationRequest(request);
assertThat(authorizationRequest).isNotNull();
assertThat(authorizationRequest.getAuthorizationUri()).isEqualTo(
this.registration2.getProviderDetails().getAuthorizationUri());
assertThat(authorizationRequest.getGrantType()).isEqualTo(
this.registration2.getAuthorizationGrantType());
assertThat(authorizationRequest.getResponseType()).isEqualTo(
OAuth2AuthorizationResponseType.CODE);
assertThat(authorizationRequest.getClientId()).isEqualTo(
this.registration2.getClientId());
assertThat(authorizationRequest.getRedirectUri()).isEqualTo(
"http://localhost:80/login/oauth2/code/registration-2");
assertThat(authorizationRequest.getScopes()).isEqualTo(
this.registration2.getScopes());
assertThat(authorizationRequest.getState()).isNotNull();
assertThat(authorizationRequest.getAdditionalParameters()
.get(OAuth2ParameterNames.REGISTRATION_ID)).isEqualTo(this.registration2.getRegistrationId());
}
@Test
public void doFilterWhenAuthorizationRequestImplicitGrantThenRedirectForAuthorization() throws Exception {
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
"/" + this.registration3.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = Mockito.mock(FilterChain.class);
FilterChain filterChain = mock(FilterChain.class);
filter.doFilter(request, response, filterChain);
this.filter.doFilter(request, response, filterChain);
Mockito.verifyZeroInteractions(filterChain); // Request should not proceed up the chain
verifyZeroInteractions(filterChain);
// The authorization request attributes are saved in the session before the redirect happens
OAuth2AuthorizationRequest authorizationRequest =
authorizationRequestRepository.loadAuthorizationRequest(request);
Assertions.assertThat(authorizationRequest).isNotNull();
Assertions.assertThat(authorizationRequest.getAuthorizationUri()).isNotNull();
Assertions.assertThat(authorizationRequest.getGrantType()).isNotNull();
Assertions.assertThat(authorizationRequest.getResponseType()).isNotNull();
Assertions.assertThat(authorizationRequest.getClientId()).isNotNull();
Assertions.assertThat(authorizationRequest.getRedirectUri()).isNotNull();
Assertions.assertThat(authorizationRequest.getScopes()).isNotNull();
Assertions.assertThat(authorizationRequest.getState()).isNotNull();
assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=token&client_id=client-3&scope=openid%20profile%20email&state=.{15,}&redirect_uri=http://localhost:80/login/oauth2/implicit/registration-3");
}
private OAuth2AuthorizationRequestRedirectFilter setupFilter(String authorizationUri,
ClientRegistration... clientRegistrations) throws Exception {
ClientRegistrationRepository clientRegistrationRepository = new InMemoryClientRegistrationRepository(clientRegistrations);
OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(clientRegistrationRepository);
return filter;
@Test
public void doFilterWhenAuthorizationRequestImplicitGrantThenAuthorizationRequestNotSavedInSession() throws Exception {
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
"/" + this.registration3.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
new HttpSessionOAuth2AuthorizationRequestRepository();
this.filter.setAuthorizationRequestRepository(authorizationRequestRepository);
this.filter.doFilter(request, response, filterChain);
verifyZeroInteractions(filterChain);
OAuth2AuthorizationRequest authorizationRequest = authorizationRequestRepository.loadAuthorizationRequest(request);
assertThat(authorizationRequest).isNull();
}
@Test
public void doFilterWhenCustomAuthorizationRequestBaseUriThenRedirectForAuthorization() throws Exception {
String authorizationRequestBaseUri = "/custom/authorization";
this.filter = new OAuth2AuthorizationRequestRedirectFilter(authorizationRequestBaseUri, this.clientRegistrationRepository);
String requestUri = authorizationRequestBaseUri + "/" + this.registration1.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyZeroInteractions(filterChain);
assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http://localhost:80/login/oauth2/code/registration-1");
}
@Test
public void doFilterWhenAuthorizationRequestRedirectUriTemplatedThenRedirectUriExpanded() throws Exception {
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
"/" + this.registration2.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
new HttpSessionOAuth2AuthorizationRequestRepository();
this.filter.setAuthorizationRequestRepository(authorizationRequestRepository);
this.filter.doFilter(request, response, filterChain);
verifyZeroInteractions(filterChain);
OAuth2AuthorizationRequest authorizationRequest = authorizationRequestRepository.loadAuthorizationRequest(request);
assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(
this.registration2.getRedirectUri());
assertThat(authorizationRequest.getRedirectUri()).isEqualTo(
"http://localhost:80/login/oauth2/code/registration-2");
}
}

View File

@ -27,12 +27,19 @@ import java.util.HashSet;
import static org.assertj.core.api.Assertions.assertThat;
/**
* Tests for {@link OAuth2AuthorizationRequestUriBuilder}.
*
* @author Rob Winch
* @since 5.0
*/
public class OAuth2AuthorizationRequestUriBuilderTests {
private OAuth2AuthorizationRequestUriBuilder builder = new OAuth2AuthorizationRequestUriBuilder();
@Test(expected = IllegalArgumentException.class)
public void buildWhenAuthorizationRequestIsNullThenThrowIllegalArgumentException() {
this.builder.build(null);
}
@Test
public void buildWhenScopeMultiThenSeparatedByEncodedSpace() {
OAuth2AuthorizationRequest request = OAuth2AuthorizationRequest.implicit()

View File

@ -15,32 +15,36 @@
*/
package org.springframework.security.oauth2.client.web;
import org.assertj.core.api.Assertions;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Matchers;
import org.mockito.Mockito;
import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
@ -48,183 +52,235 @@ import javax.servlet.http.HttpServletResponse;
import java.util.HashMap;
import java.util.Map;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;
/**
* Tests {@link OAuth2LoginAuthenticationFilter}.
* Tests for {@link OAuth2LoginAuthenticationFilter}.
*
* @author Joe Grandja
*/
@PowerMockIgnore("javax.security.*")
@PrepareForTest({OAuth2AuthorizationRequest.class, OAuth2AuthorizationExchange.class})
@RunWith(PowerMockRunner.class)
public class OAuth2LoginAuthenticationFilterTests {
private ClientRegistration registration1;
private ClientRegistration registration2;
private String principalName1 = "principal-1";
private ClientRegistrationRepository clientRegistrationRepository;
private OAuth2AuthorizedClientService authorizedClientService;
private AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository;
private AuthenticationFailureHandler failureHandler;
private AuthenticationManager authenticationManager;
private OAuth2LoginAuthenticationFilter filter;
@Test
public void doFilterWhenNotAuthorizationCodeResponseThenContinueChain() throws Exception {
ClientRegistration clientRegistration = TestUtil.googleClientRegistration();
@Before
public void setUp() {
this.registration1 = ClientRegistration.withRegistrationId("registration-1")
.clientId("client-1")
.clientSecret("secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUri("{scheme}://{serverName}:{serverPort}{contextPath}/login/oauth2/code/{registrationId}")
.scope("user")
.authorizationUri("https://provider.com/oauth2/authorize")
.tokenUri("https://provider.com/oauth2/token")
.userInfoUri("https://provider.com/oauth2/user")
.userNameAttributeName("id")
.clientName("client-1")
.build();
this.registration2 = ClientRegistration.withRegistrationId("registration-2")
.clientId("client-2")
.clientSecret("secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUri("{scheme}://{serverName}:{serverPort}{contextPath}/login/oauth2/code/{registrationId}")
.scope("openid", "profile", "email")
.authorizationUri("https://provider.com/oauth2/authorize")
.tokenUri("https://provider.com/oauth2/token")
.userInfoUri("https://provider.com/oauth2/userinfo")
.jwkSetUri("https://provider.com/oauth2/keys")
.clientName("client-2")
.build();
this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(
this.registration1, this.registration2);
this.authorizedClientService = new InMemoryOAuth2AuthorizedClientService(this.clientRegistrationRepository);
this.authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository();
this.failureHandler = mock(AuthenticationFailureHandler.class);
this.authenticationManager = mock(AuthenticationManager.class);
this.filter = spy(new OAuth2LoginAuthenticationFilter(
this.clientRegistrationRepository, this.authorizedClientService));
this.filter.setAuthorizationRequestRepository(this.authorizationRequestRepository);
this.filter.setAuthenticationFailureHandler(this.failureHandler);
this.filter.setAuthenticationManager(this.authenticationManager);
}
OAuth2LoginAuthenticationFilter filter = Mockito.spy(setupFilter(clientRegistration));
@Test(expected = IllegalArgumentException.class)
public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() {
new OAuth2LoginAuthenticationFilter(null, this.authorizedClientService);
}
String requestURI = "/path";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestURI);
request.setServletPath(requestURI);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
@Test(expected = IllegalArgumentException.class)
public void constructorWhenAuthorizedClientServiceIsNullThenThrowIllegalArgumentException() {
new OAuth2LoginAuthenticationFilter(this.clientRegistrationRepository, null);
}
filter.doFilter(request, response, filterChain);
@Test(expected = IllegalArgumentException.class)
public void constructorWhenFilterProcessesUrlIsNullThenThrowIllegalArgumentException() {
new OAuth2LoginAuthenticationFilter(null, this.clientRegistrationRepository, this.authorizedClientService);
}
Mockito.verify(filterChain).doFilter(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class));
Mockito.verify(filter, Mockito.never()).attemptAuthentication(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class));
@Test(expected = IllegalArgumentException.class)
public void setAuthorizationRequestRepositoryWhenAuthorizationRequestRepositoryIsNullThenThrowIllegalArgumentException() {
this.filter.setAuthorizationRequestRepository(null);
}
@Test
public void doFilterWhenAuthorizationCodeErrorResponseThenAuthenticationFailureHandlerIsCalled() throws Exception {
ClientRegistration clientRegistration = TestUtil.githubClientRegistration();
OAuth2LoginAuthenticationFilter filter = Mockito.spy(setupFilter(clientRegistration));
AuthenticationFailureHandler failureHandler = mock(AuthenticationFailureHandler.class);
filter.setAuthenticationFailureHandler(failureHandler);
MockHttpServletRequest request = this.setupRequest(clientRegistration);
String errorCode = OAuth2ErrorCodes.INVALID_GRANT;
request.addParameter(OAuth2ParameterNames.ERROR, errorCode);
request.addParameter(OAuth2ParameterNames.STATE, "some state");
public void doFilterWhenNotAuthorizationResponseThenNextFilter() throws Exception {
String requestUri = "/path";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
filter.doFilter(request, response, filterChain);
this.filter.doFilter(request, response, filterChain);
Mockito.verify(filter).attemptAuthentication(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class));
Mockito.verify(failureHandler).onAuthenticationFailure(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class),
Matchers.any(AuthenticationException.class));
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
verify(this.filter, never()).attemptAuthentication(any(HttpServletRequest.class), any(HttpServletResponse.class));
}
@Test
public void doFilterWhenAuthorizationCodeSuccessResponseThenAuthenticationSuccessHandlerIsCalled() throws Exception {
ClientRegistration clientRegistration = TestUtil.githubClientRegistration();
OAuth2User oauth2User = mock(OAuth2User.class);
when(oauth2User.getName()).thenReturn("principal name");
public void doFilterWhenAuthorizationResponseInvalidThenInvalidRequestError() throws Exception {
String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
// NOTE:
// A valid Authorization Response contains either a 'code' or 'error' parameter.
// Don't set it to force an invalid Authorization Response.
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
ArgumentCaptor<AuthenticationException> authenticationExceptionArgCaptor = ArgumentCaptor.forClass(AuthenticationException.class);
verify(this.failureHandler).onAuthenticationFailure(any(HttpServletRequest.class), any(HttpServletResponse.class),
authenticationExceptionArgCaptor.capture());
assertThat(authenticationExceptionArgCaptor.getValue()).isInstanceOf(OAuth2AuthenticationException.class);
OAuth2AuthenticationException authenticationException = (OAuth2AuthenticationException) authenticationExceptionArgCaptor.getValue();
assertThat(authenticationException.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
}
@Test
public void doFilterWhenAuthorizationResponseAuthorizationRequestNotFoundThenAuthorizationRequestNotFoundError() throws Exception {
String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.addParameter(OAuth2ParameterNames.CODE, "code");
request.addParameter(OAuth2ParameterNames.STATE, "state");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
ArgumentCaptor<AuthenticationException> authenticationExceptionArgCaptor = ArgumentCaptor.forClass(AuthenticationException.class);
verify(this.failureHandler).onAuthenticationFailure(any(HttpServletRequest.class), any(HttpServletResponse.class),
authenticationExceptionArgCaptor.capture());
assertThat(authenticationExceptionArgCaptor.getValue()).isInstanceOf(OAuth2AuthenticationException.class);
OAuth2AuthenticationException authenticationException = (OAuth2AuthenticationException) authenticationExceptionArgCaptor.getValue();
assertThat(authenticationException.getError().getErrorCode()).isEqualTo("authorization_request_not_found");
}
@Test
public void doFilterWhenAuthorizationResponseValidThenAuthorizationRequestRemoved() throws Exception {
String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.addParameter(OAuth2ParameterNames.CODE, "code");
request.addParameter(OAuth2ParameterNames.STATE, "state");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.setUpAuthorizationRequest(request, response, this.registration2);
this.setUpAuthenticationResult(this.registration2);
this.filter.doFilter(request, response, filterChain);
assertThat(this.authorizationRequestRepository.loadAuthorizationRequest(request)).isNull();
}
@Test
public void doFilterWhenAuthorizationResponseValidThenAuthorizedClientSaved() throws Exception {
String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.addParameter(OAuth2ParameterNames.CODE, "code");
request.addParameter(OAuth2ParameterNames.STATE, "state");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.setUpAuthorizationRequest(request, response, this.registration1);
this.setUpAuthenticationResult(this.registration1);
this.filter.doFilter(request, response, filterChain);
OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient(
this.registration1.getRegistrationId(), this.principalName1);
assertThat(authorizedClient).isNotNull();
assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration1);
assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principalName1);
assertThat(authorizedClient.getAccessToken()).isNotNull();
}
@Test
public void doFilterWhenCustomFilterProcessesUrlThenFilterProcesses() throws Exception {
String filterProcessesUrl = "/login/oauth2/custom/*";
this.filter = spy(new OAuth2LoginAuthenticationFilter(filterProcessesUrl,
this.clientRegistrationRepository, this.authorizedClientService));
this.filter.setAuthenticationManager(this.authenticationManager);
String requestUri = "/login/oauth2/custom/" + this.registration2.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.addParameter(OAuth2ParameterNames.CODE, "code");
request.addParameter(OAuth2ParameterNames.STATE, "state");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.setUpAuthorizationRequest(request, response, this.registration2);
this.setUpAuthenticationResult(this.registration2);
this.filter.doFilter(request, response, filterChain);
verifyZeroInteractions(filterChain);
verify(this.filter).attemptAuthentication(any(HttpServletRequest.class), any(HttpServletResponse.class));
}
private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response,
ClientRegistration registration) {
OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId());
when(authorizationRequest.getAdditionalParameters()).thenReturn(additionalParameters);
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
}
private void setUpAuthenticationResult(ClientRegistration registration) {
OAuth2User user = mock(OAuth2User.class);
when(user.getName()).thenReturn(this.principalName1);
OAuth2LoginAuthenticationToken loginAuthentication = mock(OAuth2LoginAuthenticationToken.class);
when(loginAuthentication.getPrincipal()).thenReturn(oauth2User);
when(loginAuthentication.getClientRegistration()).thenReturn(clientRegistration);
when(loginAuthentication.getPrincipal()).thenReturn(user);
when(loginAuthentication.getAuthorities()).thenReturn(AuthorityUtils.createAuthorityList("ROLE_USER"));
when(loginAuthentication.getClientRegistration()).thenReturn(registration);
when(loginAuthentication.getAuthorizationExchange()).thenReturn(mock(OAuth2AuthorizationExchange.class));
when(loginAuthentication.getAccessToken()).thenReturn(mock(OAuth2AccessToken.class));
OAuth2AuthenticationToken userAuthentication = new OAuth2AuthenticationToken(
oauth2User, AuthorityUtils.NO_AUTHORITIES, clientRegistration.getRegistrationId());
SecurityContextHolder.getContext().setAuthentication(userAuthentication);
AuthenticationManager authenticationManager = mock(AuthenticationManager.class);
when(authenticationManager.authenticate(Matchers.any(Authentication.class))).thenReturn(loginAuthentication);
OAuth2LoginAuthenticationFilter filter = Mockito.spy(setupFilter(authenticationManager, clientRegistration));
AuthenticationSuccessHandler successHandler = mock(AuthenticationSuccessHandler.class);
filter.setAuthenticationSuccessHandler(successHandler);
AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
new HttpSessionOAuth2AuthorizationRequestRepository();
filter.setAuthorizationRequestRepository(authorizationRequestRepository);
MockHttpServletRequest request = this.setupRequest(clientRegistration);
String authCode = "some code";
String state = "some state";
request.addParameter(OAuth2ParameterNames.CODE, authCode);
request.addParameter(OAuth2ParameterNames.STATE, state);
MockHttpServletResponse response = new MockHttpServletResponse();
setupAuthorizationRequest(authorizationRequestRepository, request, response, clientRegistration, state);
FilterChain filterChain = mock(FilterChain.class);
filter.doFilter(request, response, filterChain);
Mockito.verify(filter).attemptAuthentication(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class));
ArgumentCaptor<Authentication> authenticationArgCaptor = ArgumentCaptor.forClass(Authentication.class);
Mockito.verify(successHandler).onAuthenticationSuccess(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class),
authenticationArgCaptor.capture());
Assertions.assertThat(authenticationArgCaptor.getValue()).isEqualTo(userAuthentication);
}
@Test
public void doFilterWhenAuthorizationCodeSuccessResponseAndNoMatchingAuthorizationRequestThenThrowOAuth2AuthenticationExceptionAuthorizationRequestNotFound() throws Exception {
ClientRegistration clientRegistration = TestUtil.githubClientRegistration();
OAuth2LoginAuthenticationFilter filter = Mockito.spy(setupFilter(clientRegistration));
AuthenticationFailureHandler failureHandler = mock(AuthenticationFailureHandler.class);
filter.setAuthenticationFailureHandler(failureHandler);
MockHttpServletRequest request = this.setupRequest(clientRegistration);
String authCode = "some code";
String state = "some state";
request.addParameter(OAuth2ParameterNames.CODE, authCode);
request.addParameter(OAuth2ParameterNames.STATE, state);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
filter.doFilter(request, response, filterChain);
verifyThrowsOAuth2AuthenticationExceptionWithErrorCode(filter, failureHandler, "authorization_request_not_found");
}
private void verifyThrowsOAuth2AuthenticationExceptionWithErrorCode(OAuth2LoginAuthenticationFilter filter,
AuthenticationFailureHandler failureHandler,
String errorCode) throws Exception {
Mockito.verify(filter).attemptAuthentication(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class));
ArgumentCaptor<AuthenticationException> authenticationExceptionArgCaptor =
ArgumentCaptor.forClass(AuthenticationException.class);
Mockito.verify(failureHandler).onAuthenticationFailure(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class),
authenticationExceptionArgCaptor.capture());
Assertions.assertThat(authenticationExceptionArgCaptor.getValue()).isInstanceOf(OAuth2AuthenticationException.class);
OAuth2AuthenticationException oauth2AuthenticationException =
(OAuth2AuthenticationException)authenticationExceptionArgCaptor.getValue();
Assertions.assertThat(oauth2AuthenticationException.getError()).isNotNull();
Assertions.assertThat(oauth2AuthenticationException.getError().getErrorCode()).isEqualTo(errorCode);
}
private OAuth2LoginAuthenticationFilter setupFilter(ClientRegistration... clientRegistrations) throws Exception {
AuthenticationManager authenticationManager = mock(AuthenticationManager.class);
return setupFilter(authenticationManager, clientRegistrations);
}
private OAuth2LoginAuthenticationFilter setupFilter(
AuthenticationManager authenticationManager, ClientRegistration... clientRegistrations) throws Exception {
ClientRegistrationRepository clientRegistrationRepository = new InMemoryClientRegistrationRepository(clientRegistrations);
OAuth2LoginAuthenticationFilter filter = new OAuth2LoginAuthenticationFilter(
clientRegistrationRepository, mock(OAuth2AuthorizedClientService.class));
filter.setAuthenticationManager(authenticationManager);
return filter;
}
private void setupAuthorizationRequest(AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository,
HttpServletRequest request,
HttpServletResponse response,
ClientRegistration clientRegistration,
String state) {
Map<String,Object> additionalParameters = new HashMap<>();
additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId());
OAuth2AuthorizationRequest authorizationRequest =
OAuth2AuthorizationRequest.authorizationCode()
.clientId(clientRegistration.getClientId())
.authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri())
.redirectUri(clientRegistration.getRedirectUri())
.scopes(clientRegistration.getScopes())
.state(state)
.additionalParameters(additionalParameters)
.build();
authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
}
private MockHttpServletRequest setupRequest(ClientRegistration clientRegistration) {
String requestURI = TestUtil.AUTHORIZE_BASE_URI + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestURI);
request.setScheme(TestUtil.DEFAULT_SCHEME);
request.setServerName(TestUtil.DEFAULT_SERVER_NAME);
request.setServerPort(TestUtil.DEFAULT_SERVER_PORT);
request.setServletPath(requestURI);
return request;
when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(loginAuthentication);
}
}

View File

@ -1,71 +0,0 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
/**
* @author Joe Grandja
*/
class TestUtil {
static final String DEFAULT_SCHEME = "https";
static final String DEFAULT_SERVER_NAME = "localhost";
static final int DEFAULT_SERVER_PORT = 8080;
static final String DEFAULT_SERVER_URL = DEFAULT_SCHEME + "://" + DEFAULT_SERVER_NAME + ":" + DEFAULT_SERVER_PORT;
static final String AUTHORIZATION_BASE_URI = "/oauth2/authorization";
static final String AUTHORIZE_BASE_URI = "/login/oauth2/code";
static final String GOOGLE_REGISTRATION_ID = "google";
static final String GITHUB_REGISTRATION_ID = "github";
static ClientRegistration googleClientRegistration() {
return googleClientRegistration(DEFAULT_SERVER_URL + AUTHORIZE_BASE_URI + "/" + GOOGLE_REGISTRATION_ID);
}
static ClientRegistration googleClientRegistration(String redirectUri) {
return ClientRegistration.withRegistrationId(GOOGLE_REGISTRATION_ID)
.clientId("google-client-id")
.clientSecret("secret")
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.clientName("Google Client")
.authorizationUri("https://accounts.google.com/o/oauth2/auth")
.tokenUri("https://accounts.google.com/o/oauth2/token")
.userInfoUri("https://www.googleapis.com/oauth2/v3/userinfo")
.jwkSetUri("https://www.googleapis.com/oauth2/v3/certs")
.redirectUri(redirectUri)
.scope("openid", "email", "profile")
.build();
}
static ClientRegistration githubClientRegistration() {
return githubClientRegistration(DEFAULT_SERVER_URL + AUTHORIZE_BASE_URI + "/" + GITHUB_REGISTRATION_ID);
}
static ClientRegistration githubClientRegistration(String redirectUri) {
return ClientRegistration.withRegistrationId(GITHUB_REGISTRATION_ID)
.clientId("github-client-id")
.clientSecret("secret")
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.clientName("GitHub Client")
.authorizationUri("https://github.com/login/oauth/authorize")
.tokenUri("https://github.com/login/oauth/access_token")
.userInfoUri("https://api.github.com/user")
.redirectUri(redirectUri)
.scope("user")
.build();
}
}