Remove PowerMock from oauth2-client

Issue: gh-6025
This commit is contained in:
Josh Cummings 2018-11-19 18:05:14 -07:00
parent 39933b10ff
commit 80e13bad41
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
13 changed files with 387 additions and 449 deletions

View File

@ -17,22 +17,19 @@ package org.springframework.security.oauth2.client;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes;
/**
* Tests for {@link OAuth2AuthorizedClient}.
*
* @author Joe Grandja
*/
@RunWith(PowerMockRunner.class)
@PrepareForTest(ClientRegistration.class)
public class OAuth2AuthorizedClientTests {
private ClientRegistration clientRegistration;
private String principalName;
@ -40,9 +37,9 @@ public class OAuth2AuthorizedClientTests {
@Before
public void setUp() {
this.clientRegistration = mock(ClientRegistration.class);
this.clientRegistration = clientRegistration().build();
this.principalName = "principal";
this.accessToken = mock(OAuth2AccessToken.class);
this.accessToken = noScopes();
}
@Test(expected = IllegalArgumentException.class)

View File

@ -15,62 +15,50 @@
*/
package org.springframework.security.oauth2.client.authentication;
import java.util.Collections;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
import java.util.Collections;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses.accessTokenResponse;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.error;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.success;
/**
* Tests for {@link OAuth2AuthorizationCodeAuthenticationProvider}.
*
* @author Joe Grandja
*/
@PrepareForTest({ClientRegistration.class, OAuth2AuthorizationRequest.class,
OAuth2AuthorizationResponse.class, OAuth2AccessTokenResponse.class})
@RunWith(PowerMockRunner.class)
public class OAuth2AuthorizationCodeAuthenticationProviderTests {
private ClientRegistration clientRegistration;
private OAuth2AuthorizationRequest authorizationRequest;
private OAuth2AuthorizationResponse authorizationResponse;
private OAuth2AuthorizationExchange authorizationExchange;
private OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient;
private OAuth2AuthorizationCodeAuthenticationProvider authenticationProvider;
@Before
@SuppressWarnings("unchecked")
public void setUp() throws Exception {
this.clientRegistration = mock(ClientRegistration.class);
this.authorizationRequest = mock(OAuth2AuthorizationRequest.class);
this.authorizationResponse = mock(OAuth2AuthorizationResponse.class);
this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, this.authorizationResponse);
public void setUp() {
this.clientRegistration = clientRegistration().build();
this.authorizationRequest = request().build();
this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class);
this.authenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient);
when(this.authorizationRequest.getState()).thenReturn("12345");
when(this.authorizationResponse.getState()).thenReturn("12345");
when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example.com");
when(this.authorizationResponse.getRedirectUri()).thenReturn("http://example.com");
}
@Test
@ -86,60 +74,62 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
@Test
public void authenticateWhenAuthorizationErrorResponseThenThrowOAuth2AuthorizationException() {
when(this.authorizationResponse.statusError()).thenReturn(true);
when(this.authorizationResponse.getError()).thenReturn(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST));
OAuth2AuthorizationResponse authorizationResponse = error().errorCode(OAuth2ErrorCodes.INVALID_REQUEST).build();
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(
this.authorizationRequest, authorizationResponse);
assertThatThrownBy(() -> {
this.authenticationProvider.authenticate(
new OAuth2AuthorizationCodeAuthenticationToken(
this.clientRegistration, this.authorizationExchange));
this.clientRegistration, authorizationExchange));
}).isInstanceOf(OAuth2AuthorizationException.class).hasMessageContaining(OAuth2ErrorCodes.INVALID_REQUEST);
}
@Test
public void authenticateWhenAuthorizationResponseStateNotEqualAuthorizationRequestStateThenThrowOAuth2AuthorizationException() {
when(this.authorizationRequest.getState()).thenReturn("12345");
when(this.authorizationResponse.getState()).thenReturn("67890");
OAuth2AuthorizationResponse authorizationResponse = success().state("67890").build();
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(
this.authorizationRequest, authorizationResponse);
assertThatThrownBy(() -> {
this.authenticationProvider.authenticate(
new OAuth2AuthorizationCodeAuthenticationToken(
this.clientRegistration, this.authorizationExchange));
this.clientRegistration, authorizationExchange));
}).isInstanceOf(OAuth2AuthorizationException.class).hasMessageContaining("invalid_state_parameter");
}
@Test
public void authenticateWhenAuthorizationResponseRedirectUriNotEqualAuthorizationRequestRedirectUriThenThrowOAuth2AuthorizationException() {
when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example.com");
when(this.authorizationResponse.getRedirectUri()).thenReturn("http://example2.com");
OAuth2AuthorizationResponse authorizationResponse = success().redirectUri("http://example2.com").build();
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(
this.authorizationRequest, authorizationResponse);
assertThatThrownBy(() -> {
this.authenticationProvider.authenticate(
new OAuth2AuthorizationCodeAuthenticationToken(
this.clientRegistration, this.authorizationExchange));
this.clientRegistration, authorizationExchange));
}).isInstanceOf(OAuth2AuthorizationException.class).hasMessageContaining("invalid_redirect_uri_parameter");
}
@Test
public void authenticateWhenAuthorizationSuccessResponseThenExchangedForAccessToken() {
OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class);
OAuth2RefreshToken refreshToken = mock(OAuth2RefreshToken.class);
OAuth2AccessTokenResponse accessTokenResponse = mock(OAuth2AccessTokenResponse.class);
when(accessTokenResponse.getAccessToken()).thenReturn(accessToken);
when(accessTokenResponse.getRefreshToken()).thenReturn(refreshToken);
OAuth2AccessTokenResponse accessTokenResponse = accessTokenResponse().refreshToken("refresh").build();
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(
this.authorizationRequest, success().build());
OAuth2AuthorizationCodeAuthenticationToken authenticationResult =
(OAuth2AuthorizationCodeAuthenticationToken) this.authenticationProvider.authenticate(
new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, this.authorizationExchange));
new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, authorizationExchange));
assertThat(authenticationResult.isAuthenticated()).isTrue();
assertThat(authenticationResult.getPrincipal()).isEqualTo(this.clientRegistration.getClientId());
assertThat(authenticationResult.getCredentials()).isEqualTo(accessToken.getTokenValue());
assertThat(authenticationResult.getCredentials())
.isEqualTo(accessTokenResponse.getAccessToken().getTokenValue());
assertThat(authenticationResult.getAuthorities()).isEqualTo(Collections.emptyList());
assertThat(authenticationResult.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(authenticationResult.getAuthorizationExchange()).isEqualTo(this.authorizationExchange);
assertThat(authenticationResult.getAccessToken()).isEqualTo(accessToken);
assertThat(authenticationResult.getRefreshToken()).isEqualTo(refreshToken);
assertThat(authenticationResult.getAuthorizationExchange()).isEqualTo(authorizationExchange);
assertThat(authenticationResult.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
assertThat(authenticationResult.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken());
}
}

View File

@ -15,30 +15,27 @@
*/
package org.springframework.security.oauth2.client.authentication;
import java.util.Collections;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
import java.util.Collections;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.success;
/**
* Tests for {@link OAuth2AuthorizationCodeAuthenticationToken}.
*
* @author Joe Grandja
*/
@RunWith(PowerMockRunner.class)
@PrepareForTest({ClientRegistration.class, OAuth2AuthorizationExchange.class, OAuth2AuthorizationResponse.class})
public class OAuth2AuthorizationCodeAuthenticationTokenTests {
private ClientRegistration clientRegistration;
private OAuth2AuthorizationExchange authorizationExchange;
@ -46,9 +43,10 @@ public class OAuth2AuthorizationCodeAuthenticationTokenTests {
@Before
public void setUp() {
this.clientRegistration = mock(ClientRegistration.class);
this.authorizationExchange = mock(OAuth2AuthorizationExchange.class);
this.accessToken = mock(OAuth2AccessToken.class);
this.clientRegistration = clientRegistration().build();
this.authorizationExchange = new OAuth2AuthorizationExchange(request().build(),
success().code("code").build());
this.accessToken = noScopes();
}
@Test
@ -65,10 +63,6 @@ public class OAuth2AuthorizationCodeAuthenticationTokenTests {
@Test
public void constructorAuthorizationRequestResponseWhenAllParametersProvidedAndValidThenCreated() {
OAuth2AuthorizationResponse authorizationResponse = mock(OAuth2AuthorizationResponse.class);
when(authorizationResponse.getCode()).thenReturn("code");
when(this.authorizationExchange.getAuthorizationResponse()).thenReturn(authorizationResponse);
OAuth2AuthorizationCodeAuthenticationToken authentication =
new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, this.authorizationExchange);

View File

@ -15,15 +15,21 @@
*/
package org.springframework.security.oauth2.client.authentication;
import java.time.Instant;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
@ -34,7 +40,6 @@ import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
@ -42,30 +47,22 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
import org.springframework.security.oauth2.core.user.OAuth2User;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.CoreMatchers.containsString;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyCollection;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.error;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.success;
/**
* Tests for {@link OAuth2LoginAuthenticationProvider}.
*
* @author Joe Grandja
*/
@PrepareForTest({ClientRegistration.class, OAuth2AuthorizationRequest.class,
OAuth2AuthorizationResponse.class, OAuth2AccessTokenResponse.class})
@RunWith(PowerMockRunner.class)
public class OAuth2LoginAuthenticationProviderTests {
private ClientRegistration clientRegistration;
private OAuth2AuthorizationRequest authorizationRequest;
@ -81,19 +78,13 @@ public class OAuth2LoginAuthenticationProviderTests {
@Before
@SuppressWarnings("unchecked")
public void setUp() throws Exception {
this.clientRegistration = mock(ClientRegistration.class);
this.authorizationRequest = mock(OAuth2AuthorizationRequest.class);
this.authorizationResponse = mock(OAuth2AuthorizationResponse.class);
this.clientRegistration = clientRegistration().build();
this.authorizationRequest = request().scope("scope1", "scope2").build();
this.authorizationResponse = success().build();
this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, this.authorizationResponse);
this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class);
this.userService = mock(OAuth2UserService.class);
this.authenticationProvider = new OAuth2LoginAuthenticationProvider(this.accessTokenResponseClient, this.userService);
when(this.authorizationRequest.getScopes()).thenReturn(new LinkedHashSet<>(Arrays.asList("scope1", "scope2")));
when(this.authorizationRequest.getState()).thenReturn("12345");
when(this.authorizationResponse.getState()).thenReturn("12345");
when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example.com");
when(this.authorizationResponse.getRedirectUri()).thenReturn("http://example.com");
}
@Test
@ -121,11 +112,13 @@ public class OAuth2LoginAuthenticationProviderTests {
@Test
public void authenticateWhenAuthorizationRequestContainsOpenidScopeThenReturnNull() {
when(this.authorizationRequest.getScopes()).thenReturn(new LinkedHashSet<>(Collections.singleton("openid")));
OAuth2AuthorizationRequest authorizationRequest = request().scope("openid").build();
OAuth2AuthorizationExchange authorizationExchange =
new OAuth2AuthorizationExchange(authorizationRequest, this.authorizationResponse);
OAuth2LoginAuthenticationToken authentication =
(OAuth2LoginAuthenticationToken) this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange));
assertThat(authentication).isNull();
}
@ -135,11 +128,13 @@ public class OAuth2LoginAuthenticationProviderTests {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString(OAuth2ErrorCodes.INVALID_REQUEST));
when(this.authorizationResponse.statusError()).thenReturn(true);
when(this.authorizationResponse.getError()).thenReturn(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST));
OAuth2AuthorizationResponse authorizationResponse =
error().errorCode(OAuth2ErrorCodes.INVALID_REQUEST).build();
OAuth2AuthorizationExchange authorizationExchange =
new OAuth2AuthorizationExchange(this.authorizationRequest, authorizationResponse);
this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange));
}
@Test
@ -147,11 +142,13 @@ public class OAuth2LoginAuthenticationProviderTests {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_state_parameter"));
when(this.authorizationRequest.getState()).thenReturn("12345");
when(this.authorizationResponse.getState()).thenReturn("67890");
OAuth2AuthorizationResponse authorizationResponse =
success().state("67890").build();
OAuth2AuthorizationExchange authorizationExchange =
new OAuth2AuthorizationExchange(this.authorizationRequest, authorizationResponse);
this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange));
}
@Test
@ -159,11 +156,13 @@ public class OAuth2LoginAuthenticationProviderTests {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_redirect_uri_parameter"));
when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example.com");
when(this.authorizationResponse.getRedirectUri()).thenReturn("http://example2.com");
OAuth2AuthorizationResponse authorizationResponse =
success().redirectUri("http://example2.com").build();
OAuth2AuthorizationExchange authorizationExchange =
new OAuth2AuthorizationExchange(this.authorizationRequest, authorizationResponse);
this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange));
}
@Test

View File

@ -15,30 +15,30 @@
*/
package org.springframework.security.oauth2.client.authentication;
import java.util.Collection;
import java.util.Collections;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.user.OAuth2User;
import java.util.Collection;
import java.util.Collections;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.success;
/**
* Tests for {@link OAuth2LoginAuthenticationToken}.
*
* @author Joe Grandja
*/
@RunWith(PowerMockRunner.class)
@PrepareForTest({ClientRegistration.class, OAuth2AuthorizationExchange.class})
public class OAuth2LoginAuthenticationTokenTests {
private OAuth2User principal;
private Collection<? extends GrantedAuthority> authorities;
@ -50,9 +50,10 @@ public class OAuth2LoginAuthenticationTokenTests {
public void setUp() {
this.principal = mock(OAuth2User.class);
this.authorities = Collections.emptyList();
this.clientRegistration = mock(ClientRegistration.class);
this.authorizationExchange = mock(OAuth2AuthorizationExchange.class);
this.accessToken = mock(OAuth2AccessToken.class);
this.clientRegistration = clientRegistration().build();
this.authorizationExchange = new OAuth2AuthorizationExchange(
request().build(), success().code("code").build());
this.accessToken = noScopes();
}
@Test(expected = IllegalArgumentException.class)

View File

@ -15,16 +15,15 @@
*/
package org.springframework.security.oauth2.client.endpoint;
import java.time.Instant;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
@ -36,27 +35,19 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExch
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
import java.time.Instant;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.Set;
import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.CoreMatchers.containsString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.success;
/**
* Tests for {@link NimbusAuthorizationCodeTokenResponseClient}.
*
* @author Joe Grandja
*/
@PowerMockIgnore("okhttp3.*")
@PrepareForTest({ClientRegistration.class, OAuth2AuthorizationRequest.class, OAuth2AuthorizationResponse.class, OAuth2AuthorizationExchange.class})
@RunWith(PowerMockRunner.class)
public class NimbusAuthorizationCodeTokenResponseClientTests {
private ClientRegistration clientRegistration;
private ClientRegistration.ProviderDetails providerDetails;
private ClientRegistration.Builder clientRegistrationBuilder;
private OAuth2AuthorizationRequest authorizationRequest;
private OAuth2AuthorizationResponse authorizationResponse;
private OAuth2AuthorizationExchange authorizationExchange;
@ -67,18 +58,11 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
@Before
public void setUp() throws Exception {
this.clientRegistration = mock(ClientRegistration.class);
this.providerDetails = mock(ClientRegistration.ProviderDetails.class);
this.authorizationRequest = mock(OAuth2AuthorizationRequest.class);
this.authorizationResponse = mock(OAuth2AuthorizationResponse.class);
this.clientRegistrationBuilder = clientRegistration()
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC);
this.authorizationRequest = request().build();
this.authorizationResponse = success().build();
this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, this.authorizationResponse);
when(this.clientRegistration.getProviderDetails()).thenReturn(this.providerDetails);
when(this.clientRegistration.getClientId()).thenReturn("client-id");
when(this.clientRegistration.getClientSecret()).thenReturn("secret");
when(this.clientRegistration.getClientAuthenticationMethod()).thenReturn(ClientAuthenticationMethod.BASIC);
when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example.com");
when(this.authorizationResponse.getCode()).thenReturn("code");
}
@Test
@ -100,12 +84,13 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
server.start();
String tokenUri = server.url("/oauth2/token").toString();
when(this.providerDetails.getTokenUri()).thenReturn(tokenUri);
this.clientRegistrationBuilder.tokenUri(tokenUri);
Instant expiresAtBefore = Instant.now().plusSeconds(3600);
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(
new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
new OAuth2AuthorizationCodeGrantRequest(
this.clientRegistrationBuilder.build(), this.authorizationExchange));
Instant expiresAtAfter = Instant.now().plusSeconds(3600);
@ -126,10 +111,13 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
this.exception.expect(IllegalArgumentException.class);
String redirectUri = "http:\\example.com";
when(this.clientRegistration.getRedirectUriTemplate()).thenReturn(redirectUri);
OAuth2AuthorizationRequest authorizationRequest = request().redirectUri(redirectUri).build();
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(
authorizationRequest, this.authorizationResponse);
this.tokenResponseClient.getTokenResponse(
new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
new OAuth2AuthorizationCodeGrantRequest(
this.clientRegistrationBuilder.build(), authorizationExchange));
}
@Test
@ -137,10 +125,11 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
this.exception.expect(IllegalArgumentException.class);
String tokenUri = "http:\\provider.com\\oauth2\\token";
when(this.providerDetails.getTokenUri()).thenReturn(tokenUri);
this.clientRegistrationBuilder.tokenUri(tokenUri);
this.tokenResponseClient.getTokenResponse(
new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
new OAuth2AuthorizationCodeGrantRequest(
this.clientRegistrationBuilder.build(), this.authorizationExchange));
}
@Test
@ -165,11 +154,12 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
server.start();
String tokenUri = server.url("/oauth2/token").toString();
when(this.providerDetails.getTokenUri()).thenReturn(tokenUri);
this.clientRegistrationBuilder.tokenUri(tokenUri);
try {
this.tokenResponseClient.getTokenResponse(
new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
new OAuth2AuthorizationCodeGrantRequest(
this.clientRegistrationBuilder.build(), this.authorizationExchange));
} finally {
server.shutdown();
}
@ -180,10 +170,11 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
this.exception.expect(OAuth2AuthorizationException.class);
String tokenUri = "http://invalid-provider.com/oauth2/token";
when(this.providerDetails.getTokenUri()).thenReturn(tokenUri);
this.clientRegistrationBuilder.tokenUri(tokenUri);
this.tokenResponseClient.getTokenResponse(
new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
new OAuth2AuthorizationCodeGrantRequest(
this.clientRegistrationBuilder.build(), this.authorizationExchange));
}
@Test
@ -203,11 +194,12 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
server.start();
String tokenUri = server.url("/oauth2/token").toString();
when(this.providerDetails.getTokenUri()).thenReturn(tokenUri);
this.clientRegistrationBuilder.tokenUri(tokenUri);
try {
this.tokenResponseClient.getTokenResponse(
new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
new OAuth2AuthorizationCodeGrantRequest(
this.clientRegistrationBuilder.build(), this.authorizationExchange));
} finally {
server.shutdown();
}
@ -225,11 +217,12 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
server.start();
String tokenUri = server.url("/oauth2/token").toString();
when(this.providerDetails.getTokenUri()).thenReturn(tokenUri);
this.clientRegistrationBuilder.tokenUri(tokenUri);
try {
this.tokenResponseClient.getTokenResponse(
new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
new OAuth2AuthorizationCodeGrantRequest(
this.clientRegistrationBuilder.build(), this.authorizationExchange));
} finally {
server.shutdown();
}
@ -254,11 +247,12 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
server.start();
String tokenUri = server.url("/oauth2/token").toString();
when(this.providerDetails.getTokenUri()).thenReturn(tokenUri);
this.clientRegistrationBuilder.tokenUri(tokenUri);
try {
this.tokenResponseClient.getTokenResponse(
new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
new OAuth2AuthorizationCodeGrantRequest(
this.clientRegistrationBuilder.build(), this.authorizationExchange));
} finally {
server.shutdown();
}
@ -280,13 +274,16 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
server.start();
String tokenUri = server.url("/oauth2/token").toString();
when(this.providerDetails.getTokenUri()).thenReturn(tokenUri);
this.clientRegistrationBuilder.tokenUri(tokenUri);
Set<String> requestedScopes = new LinkedHashSet<>(Arrays.asList("openid", "profile", "email", "address"));
when(this.authorizationRequest.getScopes()).thenReturn(requestedScopes);
OAuth2AuthorizationRequest authorizationRequest =
request().scope("openid", "profile", "email", "address").build();
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(
authorizationRequest, this.authorizationResponse);
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(
new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
new OAuth2AuthorizationCodeGrantRequest(
this.clientRegistrationBuilder.build(), authorizationExchange));
server.shutdown();
@ -308,13 +305,16 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
server.start();
String tokenUri = server.url("/oauth2/token").toString();
when(this.providerDetails.getTokenUri()).thenReturn(tokenUri);
this.clientRegistrationBuilder.tokenUri(tokenUri);
Set<String> requestedScopes = new LinkedHashSet<>(Arrays.asList("openid", "profile", "email", "address"));
when(this.authorizationRequest.getScopes()).thenReturn(requestedScopes);
OAuth2AuthorizationRequest authorizationRequest =
request().scope("openid", "profile", "email", "address").build();
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(
authorizationRequest, this.authorizationResponse);
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(
new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
new OAuth2AuthorizationCodeGrantRequest(
this.clientRegistrationBuilder.build(), authorizationExchange));
server.shutdown();

View File

@ -17,31 +17,28 @@ package org.springframework.security.oauth2.client.endpoint;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationExchanges.success;
/**
* Tests for {@link OAuth2AuthorizationCodeGrantRequest}.
*
* @author Joe Grandja
*/
@RunWith(PowerMockRunner.class)
@PrepareForTest({ClientRegistration.class, OAuth2AuthorizationExchange.class})
public class OAuth2AuthorizationCodeGrantRequestTests {
private ClientRegistration clientRegistration;
private OAuth2AuthorizationExchange authorizationExchange;
@Before
public void setUp() {
this.clientRegistration = mock(ClientRegistration.class);
this.authorizationExchange = mock(OAuth2AuthorizationExchange.class);
this.clientRegistration = clientRegistration().build();
this.authorizationExchange = success();
}
@Test(expected = IllegalArgumentException.class)

View File

@ -15,16 +15,22 @@
*/
package org.springframework.security.oauth2.client.oidc.authentication;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
@ -36,7 +42,6 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
@ -47,33 +52,27 @@ import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.springframework.test.util.ReflectionTestUtils;
import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.CoreMatchers.containsString;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyCollection;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.error;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.success;
/**
* Tests for {@link OidcAuthorizationCodeAuthenticationProvider}.
*
* @author Joe Grandja
*/
@PrepareForTest({ClientRegistration.class, OAuth2AuthorizationRequest.class, OAuth2AuthorizationResponse.class,
OAuth2AccessTokenResponse.class, OidcAuthorizationCodeAuthenticationProvider.class})
@RunWith(PowerMockRunner.class)
public class OidcAuthorizationCodeAuthenticationProviderTests {
private ClientRegistration clientRegistration;
private ClientRegistration.ProviderDetails providerDetails;
private OAuth2AuthorizationRequest authorizationRequest;
private OAuth2AuthorizationResponse authorizationResponse;
private OAuth2AuthorizationExchange authorizationExchange;
@ -88,26 +87,16 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
@Before
@SuppressWarnings("unchecked")
public void setUp() throws Exception {
this.clientRegistration = mock(ClientRegistration.class);
this.providerDetails = mock(ClientRegistration.ProviderDetails.class);
this.authorizationRequest = mock(OAuth2AuthorizationRequest.class);
this.authorizationResponse = mock(OAuth2AuthorizationResponse.class);
this.clientRegistration = clientRegistration().clientId("client1").build();
this.authorizationRequest = request().scope("openid", "profile", "email").build();
this.authorizationResponse = success().build();
this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, this.authorizationResponse);
this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class);
this.accessTokenResponse = this.accessTokenSuccessResponse();
this.userService = mock(OAuth2UserService.class);
this.authenticationProvider = PowerMockito.spy(
new OidcAuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient, this.userService));
this.authenticationProvider =
new OidcAuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient, this.userService);
when(this.clientRegistration.getRegistrationId()).thenReturn("client-registration-id-1");
when(this.clientRegistration.getClientId()).thenReturn("client1");
when(this.clientRegistration.getProviderDetails()).thenReturn(this.providerDetails);
when(this.providerDetails.getJwkSetUri()).thenReturn("https://provider.com/oauth2/keys");
when(this.authorizationRequest.getScopes()).thenReturn(new LinkedHashSet<>(Arrays.asList("openid", "profile", "email")));
when(this.authorizationRequest.getState()).thenReturn("12345");
when(this.authorizationResponse.getState()).thenReturn("12345");
when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example.com");
when(this.authorizationResponse.getRedirectUri()).thenReturn("http://example.com");
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(this.accessTokenResponse);
}
@ -136,11 +125,13 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
@Test
public void authenticateWhenAuthorizationRequestDoesNotContainOpenidScopeThenReturnNull() {
when(this.authorizationRequest.getScopes()).thenReturn(new LinkedHashSet<>(Collections.singleton("scope1")));
OAuth2AuthorizationRequest authorizationRequest = request().scope("scope1").build();
OAuth2AuthorizationExchange authorizationExchange =
new OAuth2AuthorizationExchange(authorizationRequest, this.authorizationResponse);
OAuth2LoginAuthenticationToken authentication =
(OAuth2LoginAuthenticationToken) this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange));
assertThat(authentication).isNull();
}
@ -150,11 +141,12 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString(OAuth2ErrorCodes.INVALID_SCOPE));
when(this.authorizationResponse.statusError()).thenReturn(true);
when(this.authorizationResponse.getError()).thenReturn(new OAuth2Error(OAuth2ErrorCodes.INVALID_SCOPE));
OAuth2AuthorizationResponse authorizationResponse = error().errorCode(OAuth2ErrorCodes.INVALID_SCOPE).build();
OAuth2AuthorizationExchange authorizationExchange =
new OAuth2AuthorizationExchange(this.authorizationRequest, authorizationResponse);
this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange));
}
@Test
@ -162,11 +154,12 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_state_parameter"));
when(this.authorizationRequest.getState()).thenReturn("34567");
when(this.authorizationResponse.getState()).thenReturn("89012");
OAuth2AuthorizationResponse authorizationResponse = success().state("89012").build();
OAuth2AuthorizationExchange authorizationExchange =
new OAuth2AuthorizationExchange(this.authorizationRequest, authorizationResponse);
this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange));
}
@Test
@ -174,11 +167,12 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_redirect_uri_parameter"));
when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example1.com");
when(this.authorizationResponse.getRedirectUri()).thenReturn("http://example2.com");
OAuth2AuthorizationResponse authorizationResponse = success().redirectUri("http://example2.com").build();
OAuth2AuthorizationExchange authorizationExchange =
new OAuth2AuthorizationExchange(this.authorizationRequest, authorizationResponse);
this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange));
}
@Test
@ -201,10 +195,10 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("missing_signature_verifier"));
when(this.providerDetails.getJwkSetUri()).thenReturn(null);
ClientRegistration clientRegistration = clientRegistration().jwkSetUri(null).build();
this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
new OAuth2LoginAuthenticationToken(clientRegistration, this.authorizationExchange));
}
@Test
@ -434,7 +428,8 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
JwtDecoder jwtDecoder = mock(JwtDecoder.class);
when(jwtDecoder.decode(anyString())).thenReturn(idToken);
PowerMockito.doReturn(jwtDecoder).when(this.authenticationProvider, "getJwtDecoder", any(ClientRegistration.class));
ReflectionTestUtils.setField(this.authenticationProvider,
"jwtDecoders", Collections.singletonMap("registration-id", jwtDecoder));
}
private OAuth2AccessTokenResponse accessTokenSuccessResponse() {

View File

@ -15,6 +15,14 @@
*/
package org.springframework.security.oauth2.client.oidc.userinfo;
import java.time.Instant;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
@ -23,17 +31,13 @@ import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService;
import org.springframework.security.oauth2.core.AuthenticationMethod;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
@ -43,31 +47,19 @@ import org.springframework.security.oauth2.core.oidc.StandardClaimNames;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.hamcrest.CoreMatchers.containsString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.scopes;
/**
* Tests for {@link OidcUserService}.
*
* @author Joe Grandja
*/
@PowerMockIgnore({"okhttp3.*", "okio.Buffer"})
@PrepareForTest(ClientRegistration.class)
@RunWith(PowerMockRunner.class)
public class OidcUserServiceTests {
private ClientRegistration clientRegistration;
private ClientRegistration.ProviderDetails providerDetails;
private ClientRegistration.ProviderDetails.UserInfoEndpoint userInfoEndpoint;
private ClientRegistration.Builder clientRegistrationBuilder;
private OAuth2AccessToken accessToken;
private OidcIdToken idToken;
private OidcUserService userService = new OidcUserService();
@ -80,26 +72,17 @@ public class OidcUserServiceTests {
public void setup() throws Exception {
this.server = new MockWebServer();
this.server.start();
this.clientRegistration = mock(ClientRegistration.class);
this.providerDetails = mock(ClientRegistration.ProviderDetails.class);
this.userInfoEndpoint = mock(ClientRegistration.ProviderDetails.UserInfoEndpoint.class);
when(this.clientRegistration.getProviderDetails()).thenReturn(this.providerDetails);
when(this.providerDetails.getUserInfoEndpoint()).thenReturn(this.userInfoEndpoint);
when(this.clientRegistration.getAuthorizationGrantType()).thenReturn(AuthorizationGrantType.AUTHORIZATION_CODE);
this.clientRegistrationBuilder = clientRegistration()
.userInfoUri(null)
.userInfoAuthenticationMethod(AuthenticationMethod.HEADER)
.userNameAttributeName(StandardClaimNames.SUB);
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn(StandardClaimNames.SUB);
this.accessToken = scopes(OidcScopes.OPENID, OidcScopes.PROFILE);
this.accessToken = mock(OAuth2AccessToken.class);
Set<String> authorizedScopes = new LinkedHashSet<>(Arrays.asList(OidcScopes.OPENID, OidcScopes.PROFILE));
when(this.accessToken.getScopes()).thenReturn(authorizedScopes);
this.idToken = mock(OidcIdToken.class);
Map<String, Object> idTokenClaims = new HashMap<>();
idTokenClaims.put(IdTokenClaimNames.ISS, "https://provider.com");
idTokenClaims.put(IdTokenClaimNames.SUB, "subject1");
when(this.idToken.getClaims()).thenReturn(idTokenClaims);
when(this.idToken.getSubject()).thenReturn("subject1");
this.idToken = new OidcIdToken("access-token", Instant.MIN, Instant.MAX, idTokenClaims);
this.userService.setOauth2UserService(new DefaultOAuth2UserService());
}
@ -123,22 +106,23 @@ public class OidcUserServiceTests {
@Test
public void loadUserWhenUserInfoUriIsNullThenUserInfoEndpointNotRequested() {
when(this.userInfoEndpoint.getUri()).thenReturn(null);
OidcUser user = this.userService.loadUser(
new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
new OidcUserRequest(this.clientRegistrationBuilder.build(), this.accessToken, this.idToken));
assertThat(user.getUserInfo()).isNull();
}
@Test
public void loadUserWhenAuthorizedScopesDoesNotContainUserInfoScopesThenUserInfoEndpointNotRequested() {
Set<String> authorizedScopes = new LinkedHashSet<>(Arrays.asList("scope1", "scope2"));
when(this.accessToken.getScopes()).thenReturn(authorizedScopes);
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.userInfoUri("http://provider.com/user").build();
when(this.userInfoEndpoint.getUri()).thenReturn("http://provider.com/user");
Set<String> authorizedScopes = new LinkedHashSet<>(Arrays.asList("scope1", "scope2"));
OAuth2AccessToken accessToken = new OAuth2AccessToken(
OAuth2AccessToken.TokenType.BEARER, "access-token",
Instant.MIN, Instant.MAX, authorizedScopes);
OidcUser user = this.userService.loadUser(
new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
new OidcUserRequest(clientRegistration, accessToken, this.idToken));
assertThat(user.getUserInfo()).isNull();
}
@ -156,11 +140,11 @@ public class OidcUserServiceTests {
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.userInfoUri(userInfoUri).build();
OidcUser user = this.userService.loadUser(
new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
assertThat(user.getIdToken()).isNotNull();
assertThat(user.getUserInfo()).isNotNull();
@ -196,11 +180,11 @@ public class OidcUserServiceTests {
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn(StandardClaimNames.EMAIL);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.userInfoUri(userInfoUri)
.userNameAttributeName(StandardClaimNames.EMAIL).build();
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
}
@Test
@ -215,10 +199,10 @@ public class OidcUserServiceTests {
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.userInfoUri(userInfoUri).build();
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
}
@Test
@ -238,10 +222,10 @@ public class OidcUserServiceTests {
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.userInfoUri(userInfoUri).build();
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
}
@Test
@ -253,10 +237,10 @@ public class OidcUserServiceTests {
String userInfoUri = server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.userInfoUri(userInfoUri).build();
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
}
@Test
@ -266,10 +250,10 @@ public class OidcUserServiceTests {
String userInfoUri = "http://invalid-provider.com/user";
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.userInfoUri(userInfoUri).build();
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
}
@Test
@ -286,12 +270,12 @@ public class OidcUserServiceTests {
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn(StandardClaimNames.EMAIL);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.userInfoUri(userInfoUri)
.userNameAttributeName(StandardClaimNames.EMAIL).build();
OidcUser user = this.userService.loadUser(
new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
assertThat(user.getName()).isEqualTo("user1@example.com");
}
@ -311,10 +295,10 @@ public class OidcUserServiceTests {
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.userInfoUri(userInfoUri).build();
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
assertThat(this.server.takeRequest(1, TimeUnit.SECONDS).getHeader(HttpHeaders.ACCEPT))
.isEqualTo(MediaType.APPLICATION_JSON_VALUE);
}
@ -334,11 +318,10 @@ public class OidcUserServiceTests {
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.userInfoUri(userInfoUri).build();
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
RecordedRequest request = this.server.takeRequest();
assertThat(request.getMethod()).isEqualTo(HttpMethod.GET.name());
assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE);
@ -360,11 +343,11 @@ public class OidcUserServiceTests {
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.FORM);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.userInfoUri(userInfoUri)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM).build();
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
RecordedRequest request = this.server.takeRequest();
assertThat(request.getMethod()).isEqualTo(HttpMethod.POST.name());
assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE);

View File

@ -15,6 +15,12 @@
*/
package org.springframework.security.oauth2.client.userinfo;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import org.junit.After;
@ -22,42 +28,29 @@ import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.user.OAuth2User;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.CoreMatchers.containsString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes;
/**
* Tests for {@link CustomUserTypesOAuth2UserService}.
*
* @author Joe Grandja
*/
@PowerMockIgnore("okhttp3.*")
@PrepareForTest(ClientRegistration.class)
@RunWith(PowerMockRunner.class)
public class CustomUserTypesOAuth2UserServiceTests {
private ClientRegistration clientRegistration;
private ClientRegistration.ProviderDetails providerDetails;
private ClientRegistration.ProviderDetails.UserInfoEndpoint userInfoEndpoint;
private ClientRegistration.Builder clientRegistrationBuilder;
private OAuth2AccessToken accessToken;
private CustomUserTypesOAuth2UserService userService;
private MockWebServer server;
@ -69,14 +62,9 @@ public class CustomUserTypesOAuth2UserServiceTests {
public void setUp() throws Exception {
this.server = new MockWebServer();
this.server.start();
this.clientRegistration = mock(ClientRegistration.class);
this.providerDetails = mock(ClientRegistration.ProviderDetails.class);
this.userInfoEndpoint = mock(ClientRegistration.ProviderDetails.UserInfoEndpoint.class);
when(this.clientRegistration.getProviderDetails()).thenReturn(this.providerDetails);
when(this.providerDetails.getUserInfoEndpoint()).thenReturn(this.userInfoEndpoint);
String registrationId = "client-registration-id-1";
when(this.clientRegistration.getRegistrationId()).thenReturn(registrationId);
this.accessToken = mock(OAuth2AccessToken.class);
this.clientRegistrationBuilder = clientRegistration().registrationId(registrationId);
this.accessToken = noScopes();
Map<String, Class<? extends OAuth2User>> customUserTypes = new HashMap<>();
customUserTypes.put(registrationId, CustomOAuth2User.class);
@ -120,9 +108,10 @@ public class CustomUserTypesOAuth2UserServiceTests {
@Test
public void loadUserWhenCustomUserTypeNotFoundThenReturnNull() {
when(this.clientRegistration.getRegistrationId()).thenReturn("other-client-registration-id-1");
ClientRegistration clientRegistration =
clientRegistration().registrationId("other-client-registration-id-1").build();
OAuth2User user = this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
OAuth2User user = this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken));
assertThat(user).isNull();
}
@ -138,10 +127,10 @@ public class CustomUserTypesOAuth2UserServiceTests {
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.userInfoUri(userInfoUri).build();
OAuth2User user = this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
OAuth2User user = this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken));
assertThat(user.getName()).isEqualTo("first last");
assertThat(user.getAttributes().size()).isEqualTo(4);
@ -169,10 +158,10 @@ public class CustomUserTypesOAuth2UserServiceTests {
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.userInfoUri(userInfoUri).build();
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken));
}
@Test
@ -184,10 +173,10 @@ public class CustomUserTypesOAuth2UserServiceTests {
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.userInfoUri(userInfoUri).build();
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken));
}
@Test
@ -197,10 +186,18 @@ public class CustomUserTypesOAuth2UserServiceTests {
String userInfoUri = "http://invalid-provider.com/user";
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.userInfoUri(userInfoUri).build();
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken));
}
private ClientRegistration.Builder withRegistrationId(String registrationId) {
return ClientRegistration
.withRegistrationId(registrationId)
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
.clientId("client")
.tokenUri("/token");
}
private MockResponse jsonResponse(String json) {

View File

@ -15,6 +15,8 @@
*/
package org.springframework.security.oauth2.client.userinfo;
import java.util.concurrent.TimeUnit;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
@ -23,10 +25,7 @@ import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
@ -37,25 +36,18 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.oauth2.core.user.OAuth2UserAuthority;
import java.util.concurrent.TimeUnit;
import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.CoreMatchers.containsString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes;
/**
* Tests for {@link DefaultOAuth2UserService}.
*
* @author Joe Grandja
*/
@PowerMockIgnore({"okhttp3.*", "okio.Buffer"})
@PrepareForTest(ClientRegistration.class)
@RunWith(PowerMockRunner.class)
public class DefaultOAuth2UserServiceTests {
private ClientRegistration clientRegistration;
private ClientRegistration.ProviderDetails providerDetails;
private ClientRegistration.ProviderDetails.UserInfoEndpoint userInfoEndpoint;
private ClientRegistration.Builder clientRegistrationBuilder;
private OAuth2AccessToken accessToken;
private DefaultOAuth2UserService userService = new DefaultOAuth2UserService();
private MockWebServer server;
@ -67,12 +59,10 @@ public class DefaultOAuth2UserServiceTests {
public void setup() throws Exception {
this.server = new MockWebServer();
this.server.start();
this.clientRegistration = mock(ClientRegistration.class);
this.providerDetails = mock(ClientRegistration.ProviderDetails.class);
this.userInfoEndpoint = mock(ClientRegistration.ProviderDetails.UserInfoEndpoint.class);
when(this.clientRegistration.getProviderDetails()).thenReturn(this.providerDetails);
when(this.providerDetails.getUserInfoEndpoint()).thenReturn(this.userInfoEndpoint);
this.accessToken = mock(OAuth2AccessToken.class);
this.clientRegistrationBuilder = clientRegistration()
.userInfoUri(null)
.userNameAttributeName(null);
this.accessToken = noScopes();
}
@After
@ -103,8 +93,8 @@ public class DefaultOAuth2UserServiceTests {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("missing_user_info_uri"));
when(this.userInfoEndpoint.getUri()).thenReturn(null);
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
ClientRegistration clientRegistration = this.clientRegistrationBuilder.build();
this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken));
}
@Test
@ -112,9 +102,9 @@ public class DefaultOAuth2UserServiceTests {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("missing_user_name_attribute"));
when(this.userInfoEndpoint.getUri()).thenReturn("http://provider.com/user");
when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn(null);
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.userInfoUri("http://provider.com/user").build();
this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken));
}
@Test
@ -131,12 +121,12 @@ public class DefaultOAuth2UserServiceTests {
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name");
when(this.accessToken.getTokenValue()).thenReturn("access-token");
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.userInfoUri(userInfoUri)
.userInfoAuthenticationMethod(AuthenticationMethod.HEADER)
.userNameAttributeName("user-name").build();
OAuth2User user = this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
OAuth2User user = this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken));
assertThat(user.getName()).isEqualTo("user1");
assertThat(user.getAttributes().size()).isEqualTo(6);
@ -171,12 +161,12 @@ public class DefaultOAuth2UserServiceTests {
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name");
when(this.accessToken.getTokenValue()).thenReturn("access-token");
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.userInfoUri(userInfoUri)
.userInfoAuthenticationMethod(AuthenticationMethod.HEADER)
.userNameAttributeName("user-name").build();
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken));
}
@Test
@ -194,12 +184,12 @@ public class DefaultOAuth2UserServiceTests {
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name");
when(this.accessToken.getTokenValue()).thenReturn("access-token");
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.userInfoUri(userInfoUri)
.userInfoAuthenticationMethod(AuthenticationMethod.HEADER)
.userNameAttributeName("user-name").build();
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken));
}
@Test
@ -215,12 +205,12 @@ public class DefaultOAuth2UserServiceTests {
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name");
when(this.accessToken.getTokenValue()).thenReturn("access-token");
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.userInfoUri(userInfoUri)
.userInfoAuthenticationMethod(AuthenticationMethod.HEADER)
.userNameAttributeName("user-name").build();
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken));
}
@Test
@ -232,12 +222,12 @@ public class DefaultOAuth2UserServiceTests {
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name");
when(this.accessToken.getTokenValue()).thenReturn("access-token");
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.userInfoUri(userInfoUri)
.userInfoAuthenticationMethod(AuthenticationMethod.HEADER)
.userNameAttributeName("user-name").build();
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken));
}
@Test
@ -247,12 +237,12 @@ public class DefaultOAuth2UserServiceTests {
String userInfoUri = "http://invalid-provider.com/user";
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name");
when(this.accessToken.getTokenValue()).thenReturn("access-token");
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.userInfoUri(userInfoUri)
.userInfoAuthenticationMethod(AuthenticationMethod.HEADER)
.userNameAttributeName("user-name").build();
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken));
}
// gh-5294
@ -270,12 +260,12 @@ public class DefaultOAuth2UserServiceTests {
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name");
when(this.accessToken.getTokenValue()).thenReturn("access-token");
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.userInfoUri(userInfoUri)
.userInfoAuthenticationMethod(AuthenticationMethod.HEADER)
.userNameAttributeName("user-name").build();
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken));
assertThat(this.server.takeRequest(1, TimeUnit.SECONDS).getHeader(HttpHeaders.ACCEPT))
.isEqualTo(MediaType.APPLICATION_JSON_VALUE);
}
@ -295,12 +285,12 @@ public class DefaultOAuth2UserServiceTests {
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name");
when(this.accessToken.getTokenValue()).thenReturn("access-token");
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.userInfoUri(userInfoUri)
.userInfoAuthenticationMethod(AuthenticationMethod.HEADER)
.userNameAttributeName("user-name").build();
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken));
RecordedRequest request = this.server.takeRequest();
assertThat(request.getMethod()).isEqualTo(HttpMethod.GET.name());
assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE);
@ -322,12 +312,12 @@ public class DefaultOAuth2UserServiceTests {
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.FORM);
when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name");
when(this.accessToken.getTokenValue()).thenReturn("access-token");
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.userInfoUri(userInfoUri)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM)
.userNameAttributeName("user-name").build();
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken));
RecordedRequest request = this.server.takeRequest();
assertThat(request.getMethod()).isEqualTo(HttpMethod.POST.name());
assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE);

View File

@ -15,13 +15,17 @@
*/
package org.springframework.security.oauth2.client.web;
import java.util.HashMap;
import java.util.Map;
import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
@ -39,36 +43,31 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
import org.springframework.security.web.savedrequest.RequestCache;
import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import java.util.HashMap;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes;
import static org.springframework.security.oauth2.core.TestOAuth2RefreshTokens.refreshToken;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationExchanges.success;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request;
/**
* Tests for {@link OAuth2AuthorizationCodeGrantFilter}.
*
* @author Joe Grandja
*/
@PowerMockIgnore("javax.security.*")
@PrepareForTest({OAuth2AuthorizationRequest.class, OAuth2AuthorizationExchange.class, OAuth2AuthorizationCodeGrantFilter.class})
@RunWith(PowerMockRunner.class)
public class OAuth2AuthorizationCodeGrantFilterTests {
private ClientRegistration registration1;
private String principalName1 = "principal-1";
@ -367,19 +366,15 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
ClientRegistration registration) {
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId());
OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
when(authorizationRequest.getAdditionalParameters()).thenReturn(additionalParameters);
when(authorizationRequest.getRedirectUri()).thenReturn(request.getRequestURL().toString());
when(authorizationRequest.getState()).thenReturn("state");
OAuth2AuthorizationRequest authorizationRequest = request()
.additionalParameters(additionalParameters)
.redirectUri(request.getRequestURL().toString()).build();
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
}
private void setUpAuthenticationResult(ClientRegistration registration) {
OAuth2AuthorizationCodeAuthenticationToken authentication = mock(OAuth2AuthorizationCodeAuthenticationToken.class);
when(authentication.getClientRegistration()).thenReturn(registration);
when(authentication.getAuthorizationExchange()).thenReturn(mock(OAuth2AuthorizationExchange.class));
when(authentication.getAccessToken()).thenReturn(mock(OAuth2AccessToken.class));
when(authentication.getRefreshToken()).thenReturn(mock(OAuth2RefreshToken.class));
OAuth2AuthorizationCodeAuthenticationToken authentication =
new OAuth2AuthorizationCodeAuthenticationToken(registration, success(), noScopes(), refreshToken());
when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(authentication);
}
}

View File

@ -15,13 +15,16 @@
*/
package org.springframework.security.oauth2.client.web;
import java.util.HashMap;
import java.util.Map;
import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.AuthenticationManager;
@ -42,7 +45,6 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
@ -51,24 +53,22 @@ import org.springframework.security.web.authentication.AuthenticationFailureHand
import org.springframework.security.web.util.UrlUtils;
import org.springframework.web.util.UriComponentsBuilder;
import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.HashMap;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationExchanges.success;
/**
* Tests for {@link OAuth2LoginAuthenticationFilter}.
*
* @author Joe Grandja
*/
@PowerMockIgnore("javax.security.*")
@PrepareForTest({OAuth2AuthorizationExchange.class, OAuth2LoginAuthenticationFilter.class})
@RunWith(PowerMockRunner.class)
public class OAuth2LoginAuthenticationFilterTests {
private ClientRegistration registration1;
private ClientRegistration registration2;
@ -440,7 +440,7 @@ public class OAuth2LoginAuthenticationFilterTests {
when(this.loginAuthentication.getName()).thenReturn(this.principalName1);
when(this.loginAuthentication.getAuthorities()).thenReturn(AuthorityUtils.createAuthorityList("ROLE_USER"));
when(this.loginAuthentication.getClientRegistration()).thenReturn(registration);
when(this.loginAuthentication.getAuthorizationExchange()).thenReturn(mock(OAuth2AuthorizationExchange.class));
when(this.loginAuthentication.getAuthorizationExchange()).thenReturn(success());
when(this.loginAuthentication.getAccessToken()).thenReturn(mock(OAuth2AccessToken.class));
when(this.loginAuthentication.getRefreshToken()).thenReturn(mock(OAuth2RefreshToken.class));
when(this.loginAuthentication.isAuthenticated()).thenReturn(true);